diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst
index 7e796c69dc..4168b6712b 100644
--- a/python/docs/source/reference/pyspark.pandas/indexing.rst
+++ b/python/docs/source/reference/pyspark.pandas/indexing.rst
@@ -240,6 +240,7 @@ MultiIndex Properties
    MultiIndex.nlevels
    MultiIndex.levshape
    MultiIndex.values
+   MultiIndex.dtypes
 
 MultiIndex components
 ~~~~~~~~~~~~~~~~~~~~~
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index cff3e2689b..896ea2af27 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -375,6 +375,40 @@ class MultiIndex(Index):
     def name(self, name: Name) -> None:
         raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name")
 
+    @property
+    def dtypes(self) -> pd.Series:
+        """Return the dtypes as a Series for the underlying MultiIndex.
+
+        .. versionadded:: 3.3.0
+
+        Returns
+        -------
+        pd.Series
+            The data type of each level, indexed by the level name. A level
+            without a name is labelled ``level_<position>``, as in pandas.
+
+        Examples
+        --------
+        >>> psmidx = ps.MultiIndex.from_arrays(
+        ...     [[0, 1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8, 9]],
+        ...     names=("zero", "one"),
+        ... )
+        >>> psmidx.dtypes
+        zero    int64
+        one     int64
+        dtype: object
+        """
+        # A level without a name is reported as ``None`` rather than a label tuple,
+        # so it cannot be passed to ``len``; label it ``level_<n>`` the way pandas does.
+        labels = [
+            "level_{}".format(pos) if name is None else (name if len(name) > 1 else name[0])
+            for pos, name in enumerate(self._internal.index_names)
+        ]
+        return pd.Series(
+            [field.dtype for field in self._internal.index_fields],
+            index=pd.Index(labels),
+        )
+
     def _verify_for_rename(self, name: List[Name]) -> List[Label]:  # type: ignore[override]
         if is_list_like(name):
             if self._internal.index_level != len(name):
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1ae009cc56..800fa46ce0 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -6000,6 +6000,26 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         expected_pdf = pd.DataFrame({"A": [None, 0], "B": [4.0, 1.0], "C": [3, 3]})
         self.assert_eq(expected_pdf, psdf1.combine_first(psdf2))
 
+    def test_multi_index_dtypes(self):
+        # SPARK-36930: Support ps.MultiIndex.dtypes
+        arrays = [[1, 1, 2, 2], ["red", "blue", "red", "blue"]]
+        pmidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
+        psmidx = ps.from_pandas(pmidx)
+
+        self.assert_eq(psmidx.dtypes, pmidx.dtypes)
+
+        # multiple labels
+        pmidx = pd.MultiIndex.from_arrays(arrays, names=[("zero", "first"), ("one", "second")])
+        psmidx = ps.from_pandas(pmidx)
+
+        self.assert_eq(psmidx.dtypes, pmidx.dtypes)
+
+        # unnamed levels fall back to positional ``level_<n>`` labels
+        pmidx = pd.MultiIndex.from_arrays(arrays)
+        psmidx = ps.from_pandas(pmidx)
+
+        self.assert_eq(psmidx.dtypes, pmidx.dtypes)
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_dataframe import *  # noqa: F401