[SPARK-35738][PYTHON] Support 'y' properly in DataFrame with non-numeric columns with plots

### What changes were proposed in this pull request?

This PR proposes to port the fix https://github.com/databricks/koalas/pull/2172.

```python
ks.DataFrame({'a': [1, 2, 3], 'b':["a", "b", "c"], 'c': [4, 5, 6]}).plot(kind='hist', x='a', y='c', bins=200)
```

**Before:**

```
pyspark.sql.utils.AnalysisException: cannot resolve 'least(min(a), min(b), min(c))' due to data type mismatch: The expressions should all have the same type, got LEAST(bigint, string, bigint).;
'Aggregate [unresolvedalias(least(min(a#1L), min(b#2), min(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1)), unresolvedalias(greatest(max(a#1L), max(b#2), max(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1))]
+- Project [a#1L, b#2, c#3L]
   +- Project [__index_level_0__#0L, a#1L, b#2, c#3L, monotonically_increasing_id() AS __natural_order__#8L]
      +- LogicalRDD [__index_level_0__#0L, a#1L, b#2, c#3L], false
```

**After:**

```python
Figure({
    'data': [{'hovertemplate': 'variable=a<br>value=%{text}<br>count=%{y}',
              'name': 'a',
...
```

### Why are the changes needed?

To match the behaviour with panadas' and allow users to set `x` and `y` in the DataFrame with non-numeric columns.

### Does this PR introduce _any_ user-facing change?

No to end users since the changes is not released yet. Yes to dev as described before.

### How was this patch tested?

Manually tested, added a test and tested in notebooks:

![Screen Shot 2021-06-11 at 9 11 25 PM](https://user-images.githubusercontent.com/6477701/121686038-a47a1b80-cafb-11eb-8f8e-8d968db7ebef.png)

![Screen Shot 2021-06-11 at 9 48 58 PM](https://user-images.githubusercontent.com/6477701/121688858-e22c7380-cafe-11eb-9d0a-adcbe560030f.png)

Closes #32884 from HyukjinKwon/fix-hist-plot.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
Hyukjin Kwon 2021-06-12 14:36:46 +09:00
parent 9c7250fa73
commit 76e08a8e3d
2 changed files with 19 additions and 0 deletions

View file

@ -73,8 +73,14 @@ def plot_pie(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
import plotly.graph_objs as go
import pyspark.pandas as ps
bins = kwargs.get("bins", 10)
y = kwargs.get("y")
if y and isinstance(data, ps.DataFrame):
# Note that the results here are matched with matplotlib. x and y
# handling is different from pandas' plotly output.
data = data[y]
psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
assert len(bins) > 2, "the number of buckets must be higher than 2."
output_series = HistogramPlotBase.compute_hist(psdf, bins)

View file

@ -411,6 +411,19 @@ class DataFramePlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
bin2 = self.plot_to_base64(ax2)
self.assertEqual(bin1, bin2)
non_numeric_pdf = self.pdf1.copy()
non_numeric_pdf.c = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]
non_numeric_psdf = ps.from_pandas(non_numeric_pdf)
ax1 = non_numeric_pdf.plot.hist(
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
)
bin1 = self.plot_to_base64(ax1)
ax2 = non_numeric_psdf.plot.hist(
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
)
bin2 = self.plot_to_base64(ax2)
self.assertEqual(bin1, bin2)
pdf1 = self.pdf1
psdf1 = self.psdf1
check_hist_plot(pdf1, psdf1)