[SPARK-35738][PYTHON] Support 'y' properly in DataFrame with non-numeric columns with plots
### What changes were proposed in this pull request? This PR proposes to port the fix https://github.com/databricks/koalas/pull/2172. ```python ks.DataFrame({'a': [1, 2, 3], 'b':["a", "b", "c"], 'c': [4, 5, 6]}).plot(kind='hist', x='a', y='c', bins=200) ``` **Before:** ``` pyspark.sql.utils.AnalysisException: cannot resolve 'least(min(a), min(b), min(c))' due to data type mismatch: The expressions should all have the same type, got LEAST(bigint, string, bigint).; 'Aggregate [unresolvedalias(least(min(a#1L), min(b#2), min(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1)), unresolvedalias(greatest(max(a#1L), max(b#2), max(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1))] +- Project [a#1L, b#2, c#3L] +- Project [__index_level_0__#0L, a#1L, b#2, c#3L, monotonically_increasing_id() AS __natural_order__#8L] +- LogicalRDD [__index_level_0__#0L, a#1L, b#2, c#3L], false ``` **After:** ```python Figure({ 'data': [{'hovertemplate': 'variable=a<br>value=%{text}<br>count=%{y}', 'name': 'a', ... ``` ### Why are the changes needed? To match the behaviour with panadas' and allow users to set `x` and `y` in the DataFrame with non-numeric columns. ### Does this PR introduce _any_ user-facing change? No to end users since the changes is not released yet. Yes to dev as described before. ### How was this patch tested? Manually tested, added a test and tested in notebooks: ![Screen Shot 2021-06-11 at 9 11 25 PM](https://user-images.githubusercontent.com/6477701/121686038-a47a1b80-cafb-11eb-8f8e-8d968db7ebef.png) ![Screen Shot 2021-06-11 at 9 48 58 PM](https://user-images.githubusercontent.com/6477701/121688858-e22c7380-cafe-11eb-9d0a-adcbe560030f.png) Closes #32884 from HyukjinKwon/fix-hist-plot. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
9c7250fa73
commit
76e08a8e3d
|
@ -73,8 +73,14 @@ def plot_pie(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
|
|||
|
||||
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
|
||||
import plotly.graph_objs as go
|
||||
import pyspark.pandas as ps
|
||||
|
||||
bins = kwargs.get("bins", 10)
|
||||
y = kwargs.get("y")
|
||||
if y and isinstance(data, ps.DataFrame):
|
||||
# Note that the results here are matched with matplotlib. x and y
|
||||
# handling is different from pandas' plotly output.
|
||||
data = data[y]
|
||||
psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
|
||||
assert len(bins) > 2, "the number of buckets must be higher than 2."
|
||||
output_series = HistogramPlotBase.compute_hist(psdf, bins)
|
||||
|
|
|
@ -411,6 +411,19 @@ class DataFramePlotMatplotlibTest(PandasOnSparkTestCase, TestUtils):
|
|||
bin2 = self.plot_to_base64(ax2)
|
||||
self.assertEqual(bin1, bin2)
|
||||
|
||||
non_numeric_pdf = self.pdf1.copy()
|
||||
non_numeric_pdf.c = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]
|
||||
non_numeric_psdf = ps.from_pandas(non_numeric_pdf)
|
||||
ax1 = non_numeric_pdf.plot.hist(
|
||||
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
|
||||
)
|
||||
bin1 = self.plot_to_base64(ax1)
|
||||
ax2 = non_numeric_psdf.plot.hist(
|
||||
x=non_numeric_pdf.columns[0], y=non_numeric_pdf.columns[1], bins=3
|
||||
)
|
||||
bin2 = self.plot_to_base64(ax2)
|
||||
self.assertEqual(bin1, bin2)
|
||||
|
||||
pdf1 = self.pdf1
|
||||
psdf1 = self.psdf1
|
||||
check_hist_plot(pdf1, psdf1)
|
||||
|
|
Loading…
Reference in a new issue