diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 677d80ff8b..9d53f00225 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -64,6 +64,7 @@ Modifying and computations Index.drop_duplicates Index.min Index.max + Index.map Index.rename Index.repeat Index.take diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 6c842bc164..a43a5d1628 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -16,7 +16,7 @@ # from functools import partial -from typing import Any, Iterator, List, Optional, Tuple, Union, cast, no_type_check +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union, cast, no_type_check import warnings import pandas as pd @@ -521,6 +521,50 @@ class Index(IndexOpsMixin): result = result.copy() return result + def map( + self, mapper: Union[dict, Callable[[Any], Any], pd.Series], na_action: Optional[str] = None + ) -> "Index": + """ + Map values using input correspondence (a dict, Series, or function). + + Parameters + ---------- + mapper : function, dict, or pd.Series + Mapping correspondence. + na_action : {None, 'ignore'} + If 'ignore', propagate NA values, without passing them to the mapping correspondence. + + Returns + ------- + applied : Index, inferred + The output of the mapping function applied to the index. 
+ + Examples + -------- + >>> psidx = ps.Index([1, 2, 3]) + + >>> psidx.map({1: "one", 2: "two", 3: "three"}) + Index(['one', 'two', 'three'], dtype='object') + + >>> psidx.map(lambda id: "{id} + 1".format(id=id)) + Index(['1 + 1', '2 + 1', '3 + 1'], dtype='object') + + >>> pser = pd.Series(["one", "two", "three"], index=[1, 2, 3]) + >>> psidx.map(pser) + Index(['one', 'two', 'three'], dtype='object') + """ + if isinstance(mapper, dict): + if len(set(type(k) for k in mapper.values())) > 1: + raise TypeError( + "If the mapper is a dictionary, its values must be of the same type" + ) + + return Index( + self.to_series().pandas_on_spark.transform_batch( + lambda pser: pser.map(mapper, na_action) + ) + ).rename(self.name) + @property def values(self) -> np.ndarray: """ diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py index e2dbd33747..193c12697d 100644 --- a/python/pyspark/pandas/indexes/category.py +++ b/python/pyspark/pandas/indexes/category.py @@ -642,6 +642,13 @@ class CategoricalIndex(Index): return partial(property_or_func, self) raise AttributeError("'CategoricalIndex' object has no attribute '{}'".format(item)) + def map( + self, + mapper: Union[dict, Callable[[Any], Any], pd.Series] = None, + na_action: Optional[str] = None, + ) -> "Index": + return MissingPandasLikeCategoricalIndex.map(self, mapper, na_action) + def _test() -> None: import os diff --git a/python/pyspark/pandas/indexes/datetimes.py b/python/pyspark/pandas/indexes/datetimes.py index 6998adf99d..691d8f9592 100644 --- a/python/pyspark/pandas/indexes/datetimes.py +++ b/python/pyspark/pandas/indexes/datetimes.py @@ -16,7 +16,7 @@ # import datetime from functools import partial -from typing import Any, Optional, Union, cast, no_type_check +from typing import Any, Callable, Optional, Union, cast, no_type_check import pandas as pd from pandas.api.types import is_hashable @@ -741,6 +741,13 @@ class DatetimeIndex(Index): psdf = 
psdf.pandas_on_spark.apply_batch(pandas_at_time) return ps.Index(first_series(psdf).rename(self.name)) + def map( + self, + mapper: Union[dict, Callable[[Any], Any], pd.Series] = None, + na_action: Optional[str] = None, + ) -> "Index": + return MissingPandasLikeDatetimeIndex.map(self, mapper, na_action) + def disallow_nanoseconds(freq: Union[str, DateOffset]) -> None: if freq in ["N", "ns"]: diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 4b5ec044ff..fb0208099f 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -1165,6 +1165,13 @@ class MultiIndex(Index): def __iter__(self) -> Iterator: return MissingPandasLikeMultiIndex.__iter__(self) + def map( + self, + mapper: Union[dict, Callable[[Any], Any], pd.Series] = None, + na_action: Optional[str] = None, + ) -> "Index": + return MissingPandasLikeMultiIndex.map(self, mapper, na_action) + def _test() -> None: import os diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py index 938aea2629..90e0c3e2bc 100644 --- a/python/pyspark/pandas/missing/indexes.py +++ b/python/pyspark/pandas/missing/indexes.py @@ -58,7 +58,6 @@ class MissingPandasLikeIndex(object): is_ = _unsupported_function("is_") is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple") join = _unsupported_function("join") - map = _unsupported_function("map") putmask = _unsupported_function("putmask") ravel = _unsupported_function("ravel") reindex = _unsupported_function("reindex") @@ -118,6 +117,7 @@ class MissingPandasLikeDatetimeIndex(MissingPandasLikeIndex): to_pydatetime = _unsupported_function("to_pydatetime", cls="DatetimeIndex") mean = _unsupported_function("mean", cls="DatetimeIndex") std = _unsupported_function("std", cls="DatetimeIndex") + map = _unsupported_function("map", cls="DatetimeIndex") class MissingPandasLikeCategoricalIndex(MissingPandasLikeIndex): diff --git 
a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 65831d1866..fb1e4b2e5f 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -2319,6 +2319,80 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils): self.assertRaises(PandasNotImplementedError, lambda: psmidx.factorize()) + def test_map(self): + pidx = pd.Index([1, 2, 3]) + psidx = ps.from_pandas(pidx) + + # Apply dict + self.assert_eq( + pidx.map({1: "one", 2: "two", 3: "three"}), + psidx.map({1: "one", 2: "two", 3: "three"}), + ) + self.assert_eq( + pidx.map({1: "one", 2: "two"}), + psidx.map({1: "one", 2: "two"}), + ) + self.assert_eq( + pidx.map({1: "one", 2: "two"}, na_action="ignore"), + psidx.map({1: "one", 2: "two"}, na_action="ignore"), + ) + self.assert_eq( + pidx.map({1: 10, 2: 20}), + psidx.map({1: 10, 2: 20}), + ) + self.assert_eq( + (pidx + 1).map({1: 10, 2: 20}), + (psidx + 1).map({1: 10, 2: 20}), + ) + + # Apply lambda + self.assert_eq( + pidx.map(lambda id: id + 1), + psidx.map(lambda id: id + 1), + ) + self.assert_eq( + pidx.map(lambda id: id + 1.1), + psidx.map(lambda id: id + 1.1), + ) + self.assert_eq( + pidx.map(lambda id: "{id} + 1".format(id=id)), + psidx.map(lambda id: "{id} + 1".format(id=id)), + ) + self.assert_eq( + (pidx + 1).map(lambda id: "{id} + 1".format(id=id)), + (psidx + 1).map(lambda id: "{id} + 1".format(id=id)), + ) + + # Apply series + pser = pd.Series(["one", "two", "three"], index=[1, 2, 3]) + self.assert_eq( + pidx.map(pser), + psidx.map(pser), + ) + pser = pd.Series(["one", "two", "three"]) + self.assert_eq( + pidx.map(pser), + psidx.map(pser), + ) + self.assert_eq( + pidx.map(pser, na_action="ignore"), + psidx.map(pser, na_action="ignore"), + ) + pser = pd.Series([1, 2, 3]) + self.assert_eq( + pidx.map(pser), + psidx.map(pser), + ) + self.assert_eq( + (pidx + 1).map(pser), + (psidx + 1).map(pser), + ) + + self.assertRaises( + TypeError, + lambda: 
psidx.map({1: 1, 2: 2.0, 3: "three"}), + ) + if __name__ == "__main__": from pyspark.pandas.tests.indexes.test_base import * # noqa: F401