spark-instrumented-optimizer/python/pyspark/pandas/plot/plotly.py
Hyukjin Kwon 76e08a8e3d [SPARK-35738][PYTHON] Support 'y' properly in DataFrame with non-numeric columns with plots
### What changes were proposed in this pull request?

This PR proposes to port the fix https://github.com/databricks/koalas/pull/2172.

```python
ks.DataFrame({'a': [1, 2, 3], 'b':["a", "b", "c"], 'c': [4, 5, 6]}).plot(kind='hist', x='a', y='c', bins=200)
```

**Before:**

```
pyspark.sql.utils.AnalysisException: cannot resolve 'least(min(a), min(b), min(c))' due to data type mismatch: The expressions should all have the same type, got LEAST(bigint, string, bigint).;
'Aggregate [unresolvedalias(least(min(a#1L), min(b#2), min(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1)), unresolvedalias(greatest(max(a#1L), max(b#2), max(c#3L)), Some(org.apache.spark.sql.Column$$Lambda$1556/0x0000000800d9484042fb0cc1))]
+- Project [a#1L, b#2, c#3L]
   +- Project [__index_level_0__#0L, a#1L, b#2, c#3L, monotonically_increasing_id() AS __natural_order__#8L]
      +- LogicalRDD [__index_level_0__#0L, a#1L, b#2, c#3L], false
```

**After:**

```python
Figure({
    'data': [{'hovertemplate': 'variable=a<br>value=%{text}<br>count=%{y}',
              'name': 'a',
...
```

### Why are the changes needed?

To match pandas' behaviour and allow users to set `x` and `y` in a DataFrame that contains non-numeric columns.

### Does this PR introduce _any_ user-facing change?

No for end users, since the change has not been released yet. Yes for developers, as described above.

### How was this patch tested?

Manually tested, added a test and tested in notebooks:

![Screen Shot 2021-06-11 at 9 11 25 PM](https://user-images.githubusercontent.com/6477701/121686038-a47a1b80-cafb-11eb-8f8e-8d968db7ebef.png)

![Screen Shot 2021-06-11 at 9 48 58 PM](https://user-images.githubusercontent.com/6477701/121688858-e22c7380-cafe-11eb-9d0a-adcbe560030f.png)

Closes #32884 from HyukjinKwon/fix-hist-plot.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2021-06-12 14:36:46 +09:00

219 lines
7.5 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Union
import pandas as pd
from pyspark.pandas.plot import (
HistogramPlotBase,
name_like_string,
PandasOnSparkPlotAccessor,
BoxPlotBase,
KdePlotBase,
)
if TYPE_CHECKING:
import pyspark.pandas as ps # noqa: F401 (SPARK-34943)
def plot_pandas_on_spark(data: Union["ps.DataFrame", "ps.Series"], kind: str, **kwargs):
    """Plot ``data`` with the plotly backend for the requested ``kind``.

    The kinds "pie", "hist", "box", "kde" and "density" are routed to the
    pandas-on-Spark specific functions of this module. Every other kind is
    handed to ``plotly.plot`` after ``data`` has been passed through the
    matching entry of ``PandasOnSparkPlotAccessor.pandas_plot_data_map``.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        The pandas-on-Spark object to plot.
    kind : str
        Name of the plot type.
    **kwargs
        Forwarded unchanged to the selected plotting function.

    Returns
    -------
    The value returned by the selected plotting function.
    """
    import plotly

    # pandas-on-Spark specific plots; "kde" and "density" share one implementation.
    specialized_plots = {
        "pie": plot_pie,
        "hist": plot_histogram,
        "box": plot_box,
        "kde": plot_kde,
        "density": plot_kde,
    }
    handler = specialized_plots.get(kind)
    if handler is not None:
        return handler(data, **kwargs)

    # Other plots.
    to_plot_data = PandasOnSparkPlotAccessor.pandas_plot_data_map[kind]
    return plotly.plot(to_plot_data(data), kind, **kwargs)
def plot_pie(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Draw a pie chart with ``plotly.express.pie``.

    ``data`` is first passed through the "pie" entry of
    ``PandasOnSparkPlotAccessor.pandas_plot_data_map``; the rest of the
    function works on the pandas object that call returns.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        The pandas-on-Spark object to plot.
    **kwargs
        Forwarded to ``plotly.express.pie``. For a DataFrame, ``y`` is
        consumed here and used as the default for ``values``; when ``y`` is
        given, the frame's index is the default for ``names``.

    Returns
    -------
    The figure returned by ``plotly.express.pie``.

    Raises
    ------
    RuntimeError
        If the prepared data is neither a pandas Series nor a pandas DataFrame.
    """
    from plotly import express

    prepared = PandasOnSparkPlotAccessor.pandas_plot_data_map["pie"](data)

    if isinstance(prepared, pd.Series):
        # A Series becomes a one-column frame: that column holds the values,
        # the index supplies the slice names.
        frame = prepared.to_frame()
        return express.pie(frame, values=frame.columns[0], names=frame.index, **kwargs)

    if not isinstance(prepared, pd.DataFrame):
        raise RuntimeError("Unexpected type: [%s]" % type(prepared))

    y = kwargs.pop("y", None)
    # The index is only offered as slice names when a value column was chosen via 'y'.
    fallback_names = prepared.index if y is not None else None
    # Explicit 'values' / 'names' keyword arguments win over the derived defaults.
    values = kwargs.pop("values", y)
    names = kwargs.pop("names", fallback_names)
    return express.pie(prepared, values=values, names=names, **kwargs)
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Draw a stacked histogram as a plotly bar chart.

    Bucket boundaries and per-bucket counts are obtained from
    ``HistogramPlotBase.prepare_hist_data`` and
    ``HistogramPlotBase.compute_hist``; one ``go.Bar`` trace is emitted for
    every series yielded by ``compute_hist``.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        The pandas-on-Spark object to plot.
    **kwargs
        ``bins`` (default 10) is passed to ``prepare_hist_data``. ``y``, when
        given and ``data`` is a DataFrame, selects the column(s) to plot.
        Other keyword arguments are not used by this function.

    Returns
    -------
    plotly.graph_objs.Figure
        A figure with ``barmode="stack"``, x-axis titled "value" and y-axis
        titled "count".

    Raises
    ------
    AssertionError
        If fewer than three bucket boundaries are produced.
    """
    import plotly.graph_objs as go
    import pyspark.pandas as ps

    bins = kwargs.get("bins", 10)
    y = kwargs.get("y")
    # Test against None instead of truthiness: a column label such as 0 is
    # falsy but still a valid selection and must not be silently ignored.
    if y is not None and isinstance(data, ps.DataFrame):
        # Note that the results here are matched with matplotlib. x and y
        # handling is different from pandas' plotly output.
        data = data[y]
    psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
    assert len(bins) > 2, "the number of buckets must be higher than 2."
    output_series = HistogramPlotBase.compute_hist(psdf, bins)

    # Truncate the boundaries to 9 decimals so the hover labels look prettier.
    edges = [float("%.9f" % b) for b in bins]
    text_bins = ["[%s, %s)" % (low, high) for low, high in zip(edges[:-1], edges[1:])]
    # The last bucket is closed on the right: replace ')' with ']'.
    text_bins[-1] = text_bins[-1][:-1] + "]"

    # Bars are positioned at the midpoint of each bucket. This relies on
    # element-wise arithmetic, so 'bins' is presumably a NumPy array here.
    centers = 0.5 * (bins[:-1] + bins[1:])

    bars = []
    for series in output_series:
        label = name_like_string(series.name)
        bars.append(
            go.Bar(
                x=centers,
                y=series,
                name=label,
                text=text_bins,
                hovertemplate="variable=" + label + "<br>value=%{text}<br>count=%{y}",
            )
        )

    fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
    fig["layout"]["xaxis"]["title"] = "value"
    fig["layout"]["yaxis"]["title"] = "count"
    return fig
def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Draw a box plot of a pandas-on-Spark Series with plotly.

    The quartiles, mean, fences, whiskers and outliers are computed through
    the ``BoxPlotBase`` helpers and handed to a single ``go.Box`` trace.

    Parameters
    ----------
    data : ps.Series
        The series to plot. A DataFrame is rejected.
    **kwargs
        ``whis`` (default 1.5) and ``precision`` (default 0.01) are passed to
        ``BoxPlotBase.compute_stats``. ``boxpoints`` (default
        "suspectedoutliers") and ``notched`` (default False) are validated
        here. Anything left over is forwarded to ``go.Box``.

    Returns
    -------
    plotly.graph_objs.Figure
        A figure whose x-axis is titled with the column name and whose
        y-axis is titled "value".

    Raises
    ------
    RuntimeError
        If ``data`` is a pandas-on-Spark DataFrame.
    ValueError
        If ``boxpoints`` is neither "suspectedoutliers" nor False, or if
        ``notched`` is truthy.
    """
    import plotly.graph_objs as go
    import pyspark.pandas as ps

    if isinstance(data, ps.DataFrame):
        raise RuntimeError(
            "plotly does not support a box plot with pandas-on-Spark DataFrame. Use Series instead."
        )

    # 'whis' is a matplotlib argument, not a plotly one. plotly does not seem
    # to expose how far the whiskers reach beyond the first and third
    # quartiles; it appears to use 1.5 by default.
    whis = kwargs.pop("whis", 1.5)
    # 'precision' is pandas-on-Spark specific: it controls approx_percentile.
    precision = kwargs.pop("precision", 0.01)

    # Plotly options.
    boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
    notched = kwargs.pop("notched", False)

    if boxpoints not in ("suspectedoutliers", False):
        raise ValueError(
            "plotly plotting backend does not support 'boxpoints' set to '%s'. "
            "Set to 'suspectedoutliers' or False." % boxpoints
        )
    if notched:
        raise ValueError(
            "plotly plotting backend does not support 'notched' set to '%s'. "
            "Set to False." % notched
        )

    colname = name_like_string(data.name)
    spark_column_name = data._internal.spark_column_name_for(data._column_label)

    # Mean, median, Q1 and Q3, computed with approx_percentile and 'precision'.
    col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)
    # Flags every row as an outlier or not.
    outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
    # The whiskers: min and max values of the non-outliers.
    whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)

    fliers = None
    if boxpoints:
        found = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])
        if len(found) > 0:
            fliers = [found]

    trace = go.Box(
        name=colname,
        q1=[col_stats["q1"]],
        median=[col_stats["med"]],
        q3=[col_stats["q3"]],
        mean=[col_stats["mean"]],
        lowerfence=[whiskers[0]],
        upperfence=[whiskers[1]],
        y=fliers,
        boxpoints=boxpoints,
        notched=notched,
        **kwargs,  # this is for workarounds. Box takes different options from express.box.
    )

    fig = go.Figure()
    fig.add_trace(trace)
    fig["layout"]["xaxis"]["title"] = colname
    fig["layout"]["yaxis"]["title"] = "value"
    return fig
def plot_kde(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Draw kernel density estimates as plotly line charts.

    For every column label of the frame returned by
    ``KdePlotBase.prepare_kde_data`` a density is computed with
    ``KdePlotBase.compute_kde``. The results are stacked into one pandas
    DataFrame with the columns "Density", "names" and "index", which is
    plotted with ``plotly.express.line``.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        The pandas-on-Spark object to plot.
    **kwargs
        ``ind`` and ``bw_method`` (both default None) are consumed here and
        passed to ``KdePlotBase.get_ind`` and ``KdePlotBase.compute_kde``
        respectively. For a DataFrame, ``color`` defaults to "names". The
        remaining keyword arguments are forwarded to ``plotly.express.line``.

    Returns
    -------
    The figure returned by ``plotly.express.line`` with its x-axis title
    set to None.
    """
    from plotly import express
    import pyspark.pandas as ps

    if isinstance(data, ps.DataFrame):
        # Colour the lines by the "names" column unless the caller chose a colour.
        kwargs.setdefault("color", "names")

    psdf = KdePlotBase.prepare_kde_data(data)
    internal = psdf._internal
    sdf = internal.spark_frame

    ind = KdePlotBase.get_ind(sdf.select(*internal.data_spark_columns), kwargs.pop("ind", None))
    bw_method = kwargs.pop("bw_method", None)

    def density_frame(label):
        # One pandas frame per column: the density evaluated at 'ind',
        # tagged with the column's display name.
        return pd.DataFrame(
            {
                "Density": KdePlotBase.compute_kde(
                    sdf.select(internal.spark_column_for(label)),
                    ind=ind,
                    bw_method=bw_method,
                ),
                "names": name_like_string(label),
                "index": ind,
            }
        )

    pdf = pd.concat([density_frame(label) for label in internal.column_labels])
    fig = express.line(pdf, x="index", y="Density", **kwargs)
    fig["layout"]["xaxis"]["title"] = None
    return fig