diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index ec6b261df4..e5767891a1 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -7394,31 +7394,37 @@ defaultdict(, {'col..., 'col...})] if col in values: item = values[col] item = item.tolist() if isinstance(item, np.ndarray) else list(item) - data_spark_columns.append( - self._internal.spark_column_for(self._internal.column_labels[i]) - .isin(item) - .alias(self._internal.data_spark_column_names[i]) + + scol = self._internal.spark_column_for(self._internal.column_labels[i]).isin( + [SF.lit(v) for v in item] ) + scol = F.coalesce(scol, F.lit(False)) else: - data_spark_columns.append( - SF.lit(False).alias(self._internal.data_spark_column_names[i]) - ) + scol = SF.lit(False) + data_spark_columns.append(scol.alias(self._internal.data_spark_column_names[i])) elif is_list_like(values): values = ( cast(np.ndarray, values).tolist() if isinstance(values, np.ndarray) else list(values) ) - data_spark_columns += [ - self._internal.spark_column_for(label) - .isin(values) - .alias(self._internal.spark_column_name_for(label)) - for label in self._internal.column_labels - ] + + for label in self._internal.column_labels: + scol = self._internal.spark_column_for(label).isin([SF.lit(v) for v in values]) + scol = F.coalesce(scol, F.lit(False)) + data_spark_columns.append(scol.alias(self._internal.spark_column_name_for(label))) else: raise TypeError("Values should be iterable, Series, DataFrame or dict.") - return DataFrame(self._internal.with_new_columns(data_spark_columns)) + return DataFrame( + self._internal.with_new_columns( + data_spark_columns, + data_fields=[ + field.copy(dtype=np.dtype("bool"), spark_type=BooleanType(), nullable=False) + for field in self._internal.data_fields + ], + ) + ) @property def shape(self) -> Tuple[int, int]: diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 11da18cfa2..27c670026d 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -1954,6 +1954,41 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(TypeError, msg): psdf.isin(1) + pdf = pd.DataFrame( + { + "a": [4, 2, 3, 4, 8, 6], + "b": [1, None, 9, 4, None, 4], + "c": [None, 5, None, 3, 2, 1], + }, + ) + psdf = ps.from_pandas(pdf) + + if LooseVersion(pd.__version__) >= LooseVersion("1.2"): + self.assert_eq(psdf.isin([4, 3, 1, 1, None]), pdf.isin([4, 3, 1, 1, None])) + else: + expected = pd.DataFrame( + { + "a": [True, False, True, True, False, False], + "b": [True, False, False, True, False, True], + "c": [False, False, False, True, False, True], + } + ) + self.assert_eq(psdf.isin([4, 3, 1, 1, None]), expected) + + if LooseVersion(pd.__version__) >= LooseVersion("1.2"): + self.assert_eq( + psdf.isin({"b": [4, 3, 1, 1, None]}), pdf.isin({"b": [4, 3, 1, 1, None]}) + ) + else: + expected = pd.DataFrame( + { + "a": [False, False, False, False, False, False], + "b": [True, False, False, True, False, True], + "c": [False, False, False, False, False, False], + } + ) + self.assert_eq(psdf.isin({"b": [4, 3, 1, 1, None]}), expected) + def test_merge(self): left_pdf = pd.DataFrame( {