[SPARK-34890][PYTHON] Port/integrate Koalas main codes into PySpark

### What changes were proposed in this pull request?

As a first step of [SPARK-34849](https://issues.apache.org/jira/browse/SPARK-34849), this PR proposes porting the Koalas main code into PySpark.

This PR contains minimal changes to the existing Koalas code as follows:
1. `databricks.koalas` -> `pyspark.pandas`
2. `from databricks import koalas as ks` -> `from pyspark import pandas as pp`
3. `ks.xxx` -> `pp.xxx` (see the example below)
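
For example, existing Koalas user code maps over as follows (a minimal before/after illustration of the renaming above, not a migration guide):

```python
# Before (Koalas):
# from databricks import koalas as ks
# kdf = ks.DataFrame({"A": [1, 2, 3]})

# After (this PR): same API, new package and alias
from pyspark import pandas as pp
ppdf = pp.DataFrame({"A": [1, 2, 3]})
```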

Other than those renames:
1. Added a section to `python/mypy.ini` so that mypy errors are ignored for the ported code. See the related issue at [SPARK-34941](https://issues.apache.org/jira/browse/SPARK-34941).
2. Added comments to several lines in several files to suppress flake8 F401 warnings, as illustrated below. See the related issue at [SPARK-34943](https://issues.apache.org/jira/browse/SPARK-34943).
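
For reference, the flake8 suppressions in item 2 take the form below (a representative example; this pattern appears in the ported files later in this diff, and the `python/mypy.ini` stanza for item 1 is likewise visible further down):

```python
# F401 (unused import) is suppressed here; SPARK-34943 tracks revisiting these suppressions.
import numpy as np  # noqa: F401 (SPARK-34943)
```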

Once this PR is merged, all the features previously available in [Koalas](https://github.com/databricks/koalas) will be available in PySpark as well.

Users can access the pandas API in PySpark as below:

```python
>>> from pyspark import pandas as pp
>>> ppdf = pp.DataFrame({"A": [1, 2, 3], "B": [15, 20, 25]})
>>> ppdf
   A   B
0  1  15
1  2  20
2  3  25
```
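
Conversion to and from pandas works the same way it did in Koalas (a minimal sketch; `from_pandas` is exported from `pyspark.pandas` as shown in the `__init__.py` hunk below, and `to_pandas` is assumed to behave as it did in Koalas):

```python
>>> import pandas as pd
>>> pdf = pd.DataFrame({"A": [1, 2, 3], "B": [15, 20, 25]})
>>> ppdf = pp.from_pandas(pdf)  # pandas DataFrame -> pandas-on-Spark DataFrame
>>> ppdf.to_pandas()            # collect back into a plain pandas DataFrame
   A   B
0  1  15
1  2  20
2  3  25
```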

The existing "options and settings" in Koalas are also available in the same way:

```python
>>> from pyspark.pandas.config import set_option, reset_option, get_option
>>> ppser1 = pp.Series([1, 2, 3])
>>> ppser2 = pp.Series([3, 4, 5])
>>> ppser1 + ppser2
Traceback (most recent call last):
...
ValueError: Cannot combine the series or dataframe because it comes from a different dataframe. In order to allow this operation, enable 'compute.ops_on_diff_frames' option.

>>> set_option("compute.ops_on_diff_frames", True)
>>> ppser1 + ppser2
0    4
1    6
2    8
dtype: int64
```
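
The option can be reverted with `reset_option`, which is imported above; per the ported `config.py` shown later in this diff, the default for `compute.ops_on_diff_frames` is `False`:

```python
>>> reset_option("compute.ops_on_diff_frames")
>>> get_option("compute.ops_on_diff_frames")
False
```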

Please also refer to the [API Reference](https://koalas.readthedocs.io/en/latest/reference/index.html) and [Options and Settings](https://koalas.readthedocs.io/en/latest/user_guide/options.html) for more detail.

**NOTE** that this PR intentionally ports the main Koalas code first, almost as-is and with minimal changes, because:
- The Koalas project is fairly large. Bundling additional PySpark-specific changes into this port would make it difficult to review each individual change.
    The Koalas dev team includes multiple Spark committers who will review this work. Keeping the port minimal lets the committers review and drive the development more easily and effectively.
- The Koalas tests and documentation require major changes to fit well within PySpark, whereas the main code does not.
- We recently froze the Koalas codebase and plan to work together on the initial porting. Porting the main code first as-is unblocks the Koalas dev team to work on other items in parallel.

I promise and will make sure to:
- Rename Koalas to PySpark pandas APIs and/or pandas-on-Spark accordingly in the documentation, and in the docstrings and comments in the main code.
- Triage and remove APIs that don't make sense once Koalas is part of PySpark.

The documentation changes will be tracked in [SPARK-34885](https://issues.apache.org/jira/browse/SPARK-34885), and the test code changes in [SPARK-34886](https://issues.apache.org/jira/browse/SPARK-34886).

### Why are the changes needed?

Please refer to:
- [[DISCUSS] Support pandas API layer on PySpark](http://apache-spark-developers-list.1001551.n3.nabble.com/DISCUSS-Support-pandas-API-layer-on-PySpark-td30945.html)
- [[VOTE] SPIP: Support pandas API layer on PySpark](http://apache-spark-developers-list.1001551.n3.nabble.com/VOTE-SPIP-Support-pandas-API-layer-on-PySpark-td30996.html)

### Does this PR introduce _any_ user-facing change?

Yes. Users can now use the pandas APIs on Spark.

### How was this patch tested?

Manually tested for exposed major APIs and options as described above.

### Koalas contributors

Koalas would not have been possible without the following contributors:

ueshin
HyukjinKwon
rxin
xinrong-databricks
RainFung
charlesdong1991
harupy
floscha
beobest2
thunterdb
garawalid
LucasG0
shril
deepyaman
gioa
fwani
90jam
thoo
AbdealiJK
abishekganesh72
gliptak
DumbMachine
dvgodoy
stbof
nitlev
hjoo
gatorsmile
tomspur
icexelloss
awdavidson
guyao
akhilputhiry
scook12
patryk-oleniuk
tracek
dennyglee
athena15
gstaubli
WeichenXu123
hsubbaraj
lfdversluis
ktksq
shengjh
margaret-databricks
LSturtew
sllynn
manuzhang
jijosg
sadikovi

Closes #32036 from itholic/SPARK-34890.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
itholic authored on 2021-04-06 12:42:39 +09:00; committed by HyukjinKwon
commit caf04f9b77, parent 7cffacef18
51 changed files with 51523 additions and 0 deletions


@@ -17,6 +17,7 @@
#
# define test binaries + versions
FLAKE8_BUILD="flake8"
# TODO(SPARK-34943): minimum version should be 3.8+
MINIMUM_FLAKE8="3.5.0"
MYPY_BUILD="mypy"
PYCODESTYLE_BUILD="pycodestyle"


@@ -126,3 +126,7 @@ ignore_missing_imports = True
[mypy-psutil.*]
ignore_missing_imports = True
# TODO(SPARK-34941): Enable mypy for pandas-on-Spark
[mypy-pyspark.pandas.*]
ignore_errors = True


@@ -0,0 +1,209 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from distutils.version import LooseVersion
from pyspark.pandas.version import __version__ # noqa: F401
def assert_python_version():
import warnings
major = 3
minor = 5
deprecated_version = (major, minor)
min_supported_version = (major, minor + 1)
if sys.version_info[:2] <= deprecated_version:
warnings.warn(
"Koalas support for Python {dep_ver} is deprecated and will be dropped in "
"the future release. At that point, existing Python {dep_ver} workflows "
"that use Koalas will continue to work without modification, but Python {dep_ver} "
"users will no longer get access to the latest Koalas features and bugfixes. "
"We recommend that you upgrade to Python {min_ver} or newer.".format(
dep_ver=".".join(map(str, deprecated_version)),
min_ver=".".join(map(str, min_supported_version)),
),
FutureWarning,
)
def assert_pyspark_version():
import logging
try:
import pyspark
except ImportError:
raise ImportError(
"Unable to import pyspark - consider doing a pip install with [spark] "
"extra to install pyspark with pip"
)
else:
pyspark_ver = getattr(pyspark, "__version__")
if pyspark_ver is None or LooseVersion(pyspark_ver) < LooseVersion("2.4"):
logging.warning(
'Found pyspark version "{}" installed. pyspark>=2.4.0 is recommended.'.format(
pyspark_ver if pyspark_ver is not None else "<unknown version>"
)
)
assert_python_version()
assert_pyspark_version()
import pyspark
import pyarrow
if LooseVersion(pyspark.__version__) < LooseVersion("3.0"):
if (
LooseVersion(pyarrow.__version__) >= LooseVersion("0.15")
and "ARROW_PRE_0_15_IPC_FORMAT" not in os.environ
):
import logging
logging.warning(
"'ARROW_PRE_0_15_IPC_FORMAT' environment variable was not set. It is required to "
"set this environment variable to '1' in both driver and executor sides if you use "
"pyarrow>=0.15 and pyspark<3.0. "
"Koalas will set it for you but it does not work if there is a Spark context already "
"launched."
)
# This is required to support PyArrow 0.15 in PySpark versions lower than 3.0.
# See SPARK-29367.
os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
elif "ARROW_PRE_0_15_IPC_FORMAT" in os.environ:
raise RuntimeError(
"Please explicitly unset 'ARROW_PRE_0_15_IPC_FORMAT' environment variable in both "
"driver and executor sides. It is required to set this environment variable only "
"when you use pyarrow>=0.15 and pyspark<3.0."
)
if (
LooseVersion(pyarrow.__version__) >= LooseVersion("2.0.0")
and "PYARROW_IGNORE_TIMEZONE" not in os.environ
):
import logging
logging.warning(
"'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to "
"set this environment variable to '1' in both driver and executor sides if you use "
"pyarrow>=2.0.0. "
"Koalas will set it for you but it does not work if there is a Spark context already "
"launched."
)
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.indexes.category import CategoricalIndex
from pyspark.pandas.indexes.datetimes import DatetimeIndex
from pyspark.pandas.indexes.multi import MultiIndex
from pyspark.pandas.indexes.numeric import Float64Index, Int64Index
from pyspark.pandas.series import Series
from pyspark.pandas.groupby import NamedAgg
__all__ = [ # noqa: F405
"read_csv",
"read_parquet",
"to_datetime",
"date_range",
"from_pandas",
"get_dummies",
"DataFrame",
"Series",
"Index",
"MultiIndex",
"Int64Index",
"Float64Index",
"CategoricalIndex",
"DatetimeIndex",
"sql",
"range",
"concat",
"melt",
"get_option",
"set_option",
"reset_option",
"read_sql_table",
"read_sql_query",
"read_sql",
"options",
"option_context",
"NamedAgg",
]
def _auto_patch_spark():
import os
import logging
# Attach a usage logger.
logger_module = os.getenv("KOALAS_USAGE_LOGGER", "")
if logger_module != "":
try:
from pyspark.pandas import usage_logging
usage_logging.attach(logger_module)
except Exception as e:
logger = logging.getLogger("pyspark.pandas.usage_logger")
logger.warning(
"Tried to attach usage logger `{}`, but an exception was raised: {}".format(
logger_module, str(e)
)
)
# Autopatching is on by default.
x = os.getenv("SPARK_KOALAS_AUTOPATCH", "true")
if x.lower() in ("true", "1", "enabled"):
logger = logging.getLogger("spark")
logger.info(
"Patching spark automatically. You can disable it by setting "
"SPARK_KOALAS_AUTOPATCH=false in your environment"
)
from pyspark.sql import dataframe as df
df.DataFrame.to_koalas = DataFrame.to_koalas
def _auto_patch_pandas():
import pandas as pd
# In order to use it in test cases.
global _frame_has_class_getitem
global _series_has_class_getitem
_frame_has_class_getitem = hasattr(pd.DataFrame, "__class_getitem__")
_series_has_class_getitem = hasattr(pd.Series, "__class_getitem__")
if sys.version_info >= (3, 7):
# Just in case pandas implements '__class_getitem__' later.
if not _frame_has_class_getitem:
pd.DataFrame.__class_getitem__ = lambda params: DataFrame.__class_getitem__(params)
if not _series_has_class_getitem:
pd.Series.__class_getitem__ = lambda params: Series.__class_getitem__(params)
_auto_patch_spark()
_auto_patch_pandas()
# Import after the usage logger is attached.
from pyspark.pandas.config import get_option, options, option_context, reset_option, set_option
from pyspark.pandas.namespace import * # F405
from pyspark.pandas.sql import sql


@@ -0,0 +1,930 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Koalas specific features.
"""
import inspect
from distutils.version import LooseVersion
from typing import Any, Optional, Tuple, Union, TYPE_CHECKING, cast
import types
import numpy as np # noqa: F401
import pandas as pd
import pyspark
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructField, StructType
from pyspark.pandas.internal import (
InternalFrame,
SPARK_INDEX_NAME_FORMAT,
SPARK_DEFAULT_SERIES_NAME,
)
from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.utils import (
is_name_like_value,
is_name_like_tuple,
name_like_string,
scol_for,
verify_temp_column_name,
)
if TYPE_CHECKING:
from pyspark.pandas.frame import DataFrame # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
class KoalasFrameMethods(object):
""" Koalas specific features for DataFrame. """
def __init__(self, frame: "DataFrame"):
self._kdf = frame
def attach_id_column(self, id_type: str, column: Union[Any, Tuple]) -> "DataFrame":
"""
Attach a column to be used as identifier of rows similar to the default index.
See also `Default Index type
<https://koalas.readthedocs.io/en/latest/user_guide/options.html#default-index-type>`_.
Parameters
----------
id_type : string
The id type.
- 'sequence' : a sequence that increases one by one.
.. note:: this uses Spark's Window without specifying partition specification.
This moves all data into a single partition on a single machine and
could cause serious performance degradation.
Avoid this method with very large datasets.
- 'distributed-sequence' : a sequence that increases one by one,
by group-by and group-map approach in a distributed manner.
- 'distributed' : a monotonically increasing sequence simply by using PySpark's
monotonically_increasing_id function in a fully distributed manner.
column : string or tuple of string
The column name.
Returns
-------
DataFrame
The DataFrame attached the column.
Examples
--------
>>> df = pp.DataFrame({"x": ['a', 'b', 'c']})
>>> df.koalas.attach_id_column(id_type="sequence", column="id")
x id
0 a 0
1 b 1
2 c 2
>>> df.koalas.attach_id_column(id_type="distributed-sequence", column=0)
x 0
0 a 0
1 b 1
2 c 2
>>> df.koalas.attach_id_column(id_type="distributed", column=0.0)
... # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
x 0.0
0 a ...
1 b ...
2 c ...
For multi-index columns:
>>> df = pp.DataFrame({("x", "y"): ['a', 'b', 'c']})
>>> df.koalas.attach_id_column(id_type="sequence", column=("id-x", "id-y"))
x id-x
y id-y
0 a 0
1 b 1
2 c 2
>>> df.koalas.attach_id_column(id_type="distributed-sequence", column=(0, 1.0))
x 0
y 1.0
0 a 0
1 b 1
2 c 2
"""
from pyspark.pandas.frame import DataFrame
if id_type == "sequence":
attach_func = InternalFrame.attach_sequence_column
elif id_type == "distributed-sequence":
attach_func = InternalFrame.attach_distributed_sequence_column
elif id_type == "distributed":
attach_func = InternalFrame.attach_distributed_column
else:
raise ValueError(
"id_type should be one of 'sequence', 'distributed-sequence' and 'distributed'"
)
assert is_name_like_value(column, allow_none=False), column
if not is_name_like_tuple(column):
column = (column,)
internal = self._kdf._internal
if len(column) != internal.column_labels_level:
raise ValueError(
"The given column `{}` must be the same length as the existing columns.".format(
column
)
)
elif column in internal.column_labels:
raise ValueError(
"The given column `{}` already exists.".format(name_like_string(column))
)
# Make sure the underlying Spark column names are the form of
# `name_like_string(column_label)`.
sdf = internal.spark_frame.select(
[
scol.alias(SPARK_INDEX_NAME_FORMAT(i))
for i, scol in enumerate(internal.index_spark_columns)
]
+ [
scol.alias(name_like_string(label))
for scol, label in zip(internal.data_spark_columns, internal.column_labels)
]
)
sdf = attach_func(sdf, name_like_string(column))
return DataFrame(
InternalFrame(
spark_frame=sdf,
index_spark_columns=[
scol_for(sdf, SPARK_INDEX_NAME_FORMAT(i)) for i in range(internal.index_level)
],
index_names=internal.index_names,
index_dtypes=internal.index_dtypes,
column_labels=internal.column_labels + [column],
data_spark_columns=(
[scol_for(sdf, name_like_string(label)) for label in internal.column_labels]
+ [scol_for(sdf, name_like_string(column))]
),
data_dtypes=(internal.data_dtypes + [None]),
column_label_names=internal.column_label_names,
).resolved_copy
)
def apply_batch(self, func, args=(), **kwds) -> "DataFrame":
"""
Apply a function that takes pandas DataFrame and outputs pandas DataFrame. The pandas
DataFrame given to the function is of a batch used internally.
See also `Transform and apply a function
<https://koalas.readthedocs.io/en/latest/user_guide/transform_apply.html>`_.
.. note:: the `func` is unable to access the whole input frame. Koalas internally
splits the input series into multiple batches and calls `func` with each batch multiple
times. Therefore, operations such as global aggregations are impossible. See the example
below.
>>> # This case does not return the length of whole frame but of the batch internally
... # used.
... def length(pdf) -> pp.DataFrame[int]:
... return pd.DataFrame([len(pdf)])
...
>>> df = pp.DataFrame({'A': range(1000)})
>>> df.koalas.apply_batch(length) # doctest: +SKIP
c0
0 83
1 83
2 83
...
10 83
11 83
.. note:: this API executes the function once to infer the type which is
potentially expensive, for instance, when the dataset is created after
aggregations or sorting.
To avoid this, specify return type in ``func``, for instance, as below:
>>> def plus_one(x) -> pp.DataFrame[float, float]:
... return x + 1
If the return type is specified, the output column names become
`c0, c1, c2 ... cn`. These names are positionally mapped to the returned
DataFrame in ``func``.
To specify the column names, you can assign them in a pandas friendly style as below:
>>> def plus_one(x) -> pp.DataFrame["a": float, "b": float]:
... return x + 1
>>> pdf = pd.DataFrame({'a': [1, 2, 3], 'b': [3, 4, 5]})
>>> def plus_one(x) -> pp.DataFrame[zip(pdf.dtypes, pdf.columns)]:
... return x + 1
When the given function has the return type annotated, the original index of the
DataFrame will be lost and a default index will be attached to the result DataFrame.
Please be careful about configuring the default index. See also `Default Index Type
<https://koalas.readthedocs.io/en/latest/user_guide/options.html#default-index-type>`_.
Parameters
----------
func : function
Function to apply to each pandas frame.
args : tuple
Positional arguments to pass to `func` in addition to the
array/series.
**kwds
Additional keyword arguments to pass as keywords arguments to
`func`.
Returns
-------
DataFrame
See Also
--------
DataFrame.apply: For row/columnwise operations.
DataFrame.applymap: For elementwise operations.
DataFrame.aggregate: Only perform aggregating type operations.
DataFrame.transform: Only perform transforming type operations.
Series.koalas.transform_batch: transform the Series as each pandas chunk.
Examples
--------
>>> df = pp.DataFrame([(1, 2), (3, 4), (5, 6)], columns=['A', 'B'])
>>> df
A B
0 1 2
1 3 4
2 5 6
>>> def query_func(pdf) -> pp.DataFrame[int, int]:
... return pdf.query('A == 1')
>>> df.koalas.apply_batch(query_func)
c0 c1
0 1 2
>>> def query_func(pdf) -> pp.DataFrame["A": int, "B": int]:
... return pdf.query('A == 1')
>>> df.koalas.apply_batch(query_func)
A B
0 1 2
You can also omit the type hints so Koalas infers the return schema as below:
>>> df.koalas.apply_batch(lambda pdf: pdf.query('A == 1'))
A B
0 1 2
You can also specify extra arguments.
>>> def calculation(pdf, y, z) -> pp.DataFrame[int, int]:
... return pdf ** y + z
>>> df.koalas.apply_batch(calculation, args=(10,), z=20)
c0 c1
0 21 1044
1 59069 1048596
2 9765645 60466196
You can also use ``np.ufunc`` and built-in functions as input.
>>> df.koalas.apply_batch(np.add, args=(10,))
A B
0 11 12
1 13 14
2 15 16
>>> (df * -1).koalas.apply_batch(abs)
A B
0 1 2
1 3 4
2 5 6
"""
# TODO: codes here partially duplicate `DataFrame.apply`. Can we deduplicate?
from pyspark.pandas.groupby import GroupBy
from pyspark.pandas.frame import DataFrame
from pyspark import pandas as pp
if not isinstance(func, types.FunctionType):
assert callable(func), "the first argument should be a callable function."
f = func
func = lambda *args, **kwargs: f(*args, **kwargs)
spec = inspect.getfullargspec(func)
return_sig = spec.annotations.get("return", None)
should_infer_schema = return_sig is None
should_use_map_in_pandas = LooseVersion(pyspark.__version__) >= "3.0"
original_func = func
func = lambda o: original_func(o, *args, **kwds)
self_applied = DataFrame(self._kdf._internal.resolved_copy) # type: DataFrame
if should_infer_schema:
# Here we execute with the first 1000 to get the return type.
# If the records were less than 1000, it uses pandas API directly for a shortcut.
limit = pp.get_option("compute.shortcut_limit")
pdf = self_applied.head(limit + 1)._to_internal_pandas()
applied = func(pdf)
if not isinstance(applied, pd.DataFrame):
raise ValueError(
"The given function should return a frame; however, "
"the return type was %s." % type(applied)
)
kdf = pp.DataFrame(applied) # type: DataFrame
if len(pdf) <= limit:
return kdf
return_schema = force_decimal_precision_scale(
as_nullable_spark_type(kdf._internal.to_internal_spark_frame.schema)
)
if should_use_map_in_pandas:
output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=True
)
sdf = self_applied._internal.to_internal_spark_frame.mapInPandas(
lambda iterator: map(output_func, iterator), schema=return_schema
)
else:
sdf = GroupBy._spark_group_map_apply(
self_applied, func, (F.spark_partition_id(),), return_schema, retain_index=True
)
# If schema is inferred, we can restore indexes too.
internal = kdf._internal.with_new_sdf(sdf)
else:
return_type = infer_return_type(original_func)
is_return_dataframe = isinstance(return_type, DataFrameType)
if not is_return_dataframe:
raise TypeError(
"The given function should specify a frame as its type "
"hints; however, the return type was %s." % return_sig
)
return_schema = cast(DataFrameType, return_type).spark_type
if should_use_map_in_pandas:
output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=False
)
sdf = self_applied._internal.to_internal_spark_frame.mapInPandas(
lambda iterator: map(output_func, iterator), schema=return_schema
)
else:
sdf = GroupBy._spark_group_map_apply(
self_applied, func, (F.spark_partition_id(),), return_schema, retain_index=False
)
# Otherwise, it loses index.
internal = InternalFrame(
spark_frame=sdf,
index_spark_columns=None,
data_dtypes=cast(DataFrameType, return_type).dtypes,
)
return DataFrame(internal)
def transform_batch(self, func, *args, **kwargs) -> Union["DataFrame", "Series"]:
"""
Transform chunks with a function that takes pandas DataFrame and outputs pandas DataFrame.
The pandas DataFrame given to the function is of a batch used internally. The length of
each input and output should be the same.
See also `Transform and apply a function
<https://koalas.readthedocs.io/en/latest/user_guide/transform_apply.html>`_.
.. note:: the `func` is unable to access the whole input frame. Koalas internally
splits the input series into multiple batches and calls `func` with each batch multiple
times. Therefore, operations such as global aggregations are impossible. See the example
below.
>>> # This case does not return the length of whole frame but of the batch internally
... # used.
... def length(pdf) -> pp.DataFrame[int]:
... return pd.DataFrame([len(pdf)] * len(pdf))
...
>>> df = pp.DataFrame({'A': range(1000)})
>>> df.koalas.transform_batch(length) # doctest: +SKIP
c0
0 83
1 83
2 83
...
.. note:: this API executes the function once to infer the type which is
potentially expensive, for instance, when the dataset is created after
aggregations or sorting.
To avoid this, specify return type in ``func``, for instance, as below:
>>> def plus_one(x) -> pp.DataFrame[float, float]:
... return x + 1
If the return type is specified, the output column names become
`c0, c1, c2 ... cn`. These names are positionally mapped to the returned
DataFrame in ``func``.
To specify the column names, you can assign them in a pandas friendly style as below:
>>> def plus_one(x) -> pp.DataFrame['a': float, 'b': float]:
... return x + 1
>>> pdf = pd.DataFrame({'a': [1, 2, 3], 'b': [3, 4, 5]})
>>> def plus_one(x) -> pp.DataFrame[zip(pdf.dtypes, pdf.columns)]:
... return x + 1
When the given function returns DataFrame and has the return type annotated, the
original index of the DataFrame will be lost and then a default index will be attached
to the result. Please be careful about configuring the default index. See also
`Default Index Type
<https://koalas.readthedocs.io/en/latest/user_guide/options.html#default-index-type>`_.
Parameters
----------
func : function
Function to transform each pandas frame.
*args
Positional arguments to pass to func.
**kwargs
Keyword arguments to pass to func.
Returns
-------
DataFrame or Series
See Also
--------
DataFrame.koalas.apply_batch: For row/columnwise operations.
Series.koalas.transform_batch: transform the Series as each pandas chunk.
Examples
--------
>>> df = pp.DataFrame([(1, 2), (3, 4), (5, 6)], columns=['A', 'B'])
>>> df
A B
0 1 2
1 3 4
2 5 6
>>> def plus_one_func(pdf) -> pp.DataFrame[int, int]:
... return pdf + 1
>>> df.koalas.transform_batch(plus_one_func)
c0 c1
0 2 3
1 4 5
2 6 7
>>> def plus_one_func(pdf) -> pp.DataFrame['A': int, 'B': int]:
... return pdf + 1
>>> df.koalas.transform_batch(plus_one_func)
A B
0 2 3
1 4 5
2 6 7
>>> def plus_one_func(pdf) -> pp.Series[int]:
... return pdf.B + 1
>>> df.koalas.transform_batch(plus_one_func)
0 3
1 5
2 7
dtype: int64
You can also omit the type hints so Koalas infers the return schema as below:
>>> df.koalas.transform_batch(lambda pdf: pdf + 1)
A B
0 2 3
1 4 5
2 6 7
>>> (df * -1).koalas.transform_batch(abs)
A B
0 1 2
1 3 4
2 5 6
Note that you should not transform the index. The index information will not change.
>>> df.koalas.transform_batch(lambda pdf: pdf.B + 1)
0 3
1 5
2 7
Name: B, dtype: int64
You can also specify extra arguments as below.
>>> df.koalas.transform_batch(lambda pdf, a, b, c: pdf.B + a + b + c, 1, 2, c=3)
0 8
1 10
2 12
Name: B, dtype: int64
"""
from pyspark.pandas.groupby import GroupBy
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import first_series
from pyspark import pandas as pp
assert callable(func), "the first argument should be a callable function."
spec = inspect.getfullargspec(func)
return_sig = spec.annotations.get("return", None)
should_infer_schema = return_sig is None
original_func = func
func = lambda o: original_func(o, *args, **kwargs)
names = self._kdf._internal.to_internal_spark_frame.schema.names
should_by_pass = LooseVersion(pyspark.__version__) >= "3.0"
def pandas_concat(series):
# The input can only be a DataFrame for struct from Spark 3.0.
# This works around to make the input as a frame. See SPARK-27240
pdf = pd.concat(series, axis=1)
pdf.columns = names
return pdf
def apply_func(pdf):
return func(pdf).to_frame()
def pandas_extract(pdf, name):
# This is for output to work around a DataFrame for struct
# from Spark 3.0. See SPARK-23836
return pdf[name]
def pandas_series_func(f, by_pass):
ff = f
if by_pass:
return lambda *series: first_series(ff(*series))
else:
return lambda *series: first_series(ff(pandas_concat(series)))
def pandas_frame_func(f, field_name):
ff = f
return lambda *series: pandas_extract(ff(pandas_concat(series)), field_name)
if should_infer_schema:
# Here we execute with the first 1000 to get the return type.
# If the records were less than 1000, it uses pandas API directly for a shortcut.
limit = pp.get_option("compute.shortcut_limit")
pdf = self._kdf.head(limit + 1)._to_internal_pandas()
transformed = func(pdf)
if not isinstance(transformed, (pd.DataFrame, pd.Series)):
raise ValueError(
"The given function should return a frame; however, "
"the return type was %s." % type(transformed)
)
if len(transformed) != len(pdf):
raise ValueError("transform_batch cannot produce aggregated results")
kdf_or_kser = pp.from_pandas(transformed)
if isinstance(kdf_or_kser, pp.Series):
kser = cast(pp.Series, kdf_or_kser)
spark_return_type = force_decimal_precision_scale(
as_nullable_spark_type(kser.spark.data_type)
)
return_schema = StructType(
[StructField(SPARK_DEFAULT_SERIES_NAME, spark_return_type)]
)
output_func = GroupBy._make_pandas_df_builder_func(
self._kdf, apply_func, return_schema, retain_index=False
)
pudf = pandas_udf(
pandas_series_func(output_func, should_by_pass),
returnType=spark_return_type,
functionType=PandasUDFType.SCALAR,
)
columns = self._kdf._internal.spark_columns
# TODO: Index will be lost in this case.
internal = self._kdf._internal.copy(
column_labels=kser._internal.column_labels,
data_spark_columns=[
(pudf(F.struct(*columns)) if should_by_pass else pudf(*columns)).alias(
kser._internal.data_spark_column_names[0]
)
],
data_dtypes=kser._internal.data_dtypes,
column_label_names=kser._internal.column_label_names,
)
return first_series(DataFrame(internal))
else:
kdf = cast(DataFrame, kdf_or_kser)
if len(pdf) <= limit:
# only do the short cut when it returns a frame to avoid
# operations on different dataframes in case of series.
return kdf
# Force nullability.
return_schema = force_decimal_precision_scale(
as_nullable_spark_type(kdf._internal.to_internal_spark_frame.schema)
)
self_applied = DataFrame(self._kdf._internal.resolved_copy) # type: DataFrame
output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=True
)
columns = self_applied._internal.spark_columns
if should_by_pass:
pudf = pandas_udf(
output_func, returnType=return_schema, functionType=PandasUDFType.SCALAR
)
temp_struct_column = verify_temp_column_name(
self_applied._internal.spark_frame, "__temp_struct__"
)
applied = pudf(F.struct(*columns)).alias(temp_struct_column)
sdf = self_applied._internal.spark_frame.select(applied)
sdf = sdf.selectExpr("%s.*" % temp_struct_column)
else:
applied = []
for field in return_schema.fields:
applied.append(
pandas_udf(
pandas_frame_func(output_func, field.name),
returnType=field.dataType,
functionType=PandasUDFType.SCALAR,
)(*columns).alias(field.name)
)
sdf = self_applied._internal.spark_frame.select(*applied)
return DataFrame(kdf._internal.with_new_sdf(sdf))
else:
return_type = infer_return_type(original_func)
is_return_series = isinstance(return_type, SeriesType)
is_return_dataframe = isinstance(return_type, DataFrameType)
if not is_return_dataframe and not is_return_series:
raise TypeError(
"The given function should specify a frame or series as its type "
"hints; however, the return type was %s." % return_sig
)
if is_return_series:
spark_return_type = force_decimal_precision_scale(
as_nullable_spark_type(cast(SeriesType, return_type).spark_type)
)
return_schema = StructType(
[StructField(SPARK_DEFAULT_SERIES_NAME, spark_return_type)]
)
output_func = GroupBy._make_pandas_df_builder_func(
self._kdf, apply_func, return_schema, retain_index=False
)
pudf = pandas_udf(
pandas_series_func(output_func, should_by_pass),
returnType=spark_return_type,
functionType=PandasUDFType.SCALAR,
)
columns = self._kdf._internal.spark_columns
internal = self._kdf._internal.copy(
column_labels=[None],
data_spark_columns=[
(pudf(F.struct(*columns)) if should_by_pass else pudf(*columns)).alias(
SPARK_DEFAULT_SERIES_NAME
)
],
data_dtypes=[cast(SeriesType, return_type).dtype],
column_label_names=None,
)
return first_series(DataFrame(internal))
else:
return_schema = cast(DataFrameType, return_type).spark_type
self_applied = DataFrame(self._kdf._internal.resolved_copy)
output_func = GroupBy._make_pandas_df_builder_func(
self_applied, func, return_schema, retain_index=False
)
columns = self_applied._internal.spark_columns
if should_by_pass:
pudf = pandas_udf(
output_func, returnType=return_schema, functionType=PandasUDFType.SCALAR
)
temp_struct_column = verify_temp_column_name(
self_applied._internal.spark_frame, "__temp_struct__"
)
applied = pudf(F.struct(*columns)).alias(temp_struct_column)
sdf = self_applied._internal.spark_frame.select(applied)
sdf = sdf.selectExpr("%s.*" % temp_struct_column)
else:
applied = []
for field in return_schema.fields:
applied.append(
pandas_udf(
pandas_frame_func(output_func, field.name),
returnType=field.dataType,
functionType=PandasUDFType.SCALAR,
)(*columns).alias(field.name)
)
sdf = self_applied._internal.spark_frame.select(*applied)
internal = InternalFrame(
spark_frame=sdf,
index_spark_columns=None,
data_dtypes=cast(DataFrameType, return_type).dtypes,
)
return DataFrame(internal)
class KoalasSeriesMethods(object):
""" Koalas specific features for Series. """
def __init__(self, series: "Series"):
self._kser = series
def transform_batch(self, func, *args, **kwargs) -> "Series":
"""
Transform the data with the function that takes pandas Series and outputs pandas Series.
The pandas Series given to the function is of a batch used internally.
See also `Transform and apply a function
<https://koalas.readthedocs.io/en/latest/user_guide/transform_apply.html>`_.
.. note:: the `func` is unable to access the whole input series. Koalas internally
splits the input series into multiple batches and calls `func` with each batch multiple
times. Therefore, operations such as global aggregations are impossible. See the example
below.
>>> # This case does not return the length of whole frame but of the batch internally
... # used.
... def length(pser) -> pp.Series[int]:
... return pd.Series([len(pser)] * len(pser))
...
>>> df = pp.DataFrame({'A': range(1000)})
>>> df.A.koalas.transform_batch(length) # doctest: +SKIP
c0
0 83
1 83
2 83
...
.. note:: this API executes the function once to infer the type which is
potentially expensive, for instance, when the dataset is created after
aggregations or sorting.
To avoid this, specify return type in ``func``, for instance, as below:
>>> def plus_one(x) -> pp.Series[int]:
... return x + 1
Parameters
----------
func : function
Function to apply to each pandas frame.
*args
Positional arguments to pass to func.
**kwargs
Keyword arguments to pass to func.
Returns
-------
DataFrame
See Also
--------
DataFrame.koalas.apply_batch : Similar but it takes pandas DataFrame as its internal batch.
Examples
--------
>>> df = pp.DataFrame([(1, 2), (3, 4), (5, 6)], columns=['A', 'B'])
>>> df
A B
0 1 2
1 3 4
2 5 6
>>> def plus_one_func(pser) -> pp.Series[np.int64]:
... return pser + 1
>>> df.A.koalas.transform_batch(plus_one_func)
0 2
1 4
2 6
Name: A, dtype: int64
You can also omit the type hints so Koalas infers the return schema as below:
>>> df.A.koalas.transform_batch(lambda pser: pser + 1)
0 2
1 4
2 6
Name: A, dtype: int64
You can also specify extra arguments.
>>> def plus_one_func(pser, a, b, c=3) -> pp.Series[np.int64]:
... return pser + a + b + c
>>> df.A.koalas.transform_batch(plus_one_func, 1, b=2)
0 7
1 9
2 11
Name: A, dtype: int64
You can also use ``np.ufunc`` and built-in functions as input.
>>> df.A.koalas.transform_batch(np.add, 10)
0 11
1 13
2 15
Name: A, dtype: int64
>>> (df * -1).A.koalas.transform_batch(abs)
0 1
1 3
2 5
Name: A, dtype: int64
"""
assert callable(func), "the first argument should be a callable function."
return_sig = None
try:
spec = inspect.getfullargspec(func)
return_sig = spec.annotations.get("return", None)
except TypeError:
# Falls back to schema inference if it fails to get signature.
pass
return_type = None
if return_sig is not None:
# Extract the signature arguments from this function.
sig_return = infer_return_type(func)
if not isinstance(sig_return, SeriesType):
raise ValueError(
"Expected the return type of this function to be of type column,"
" but found type {}".format(sig_return)
)
return_type = cast(SeriesType, sig_return)
return self._transform_batch(lambda c: func(c, *args, **kwargs), return_type)
def _transform_batch(self, func, return_type: Optional[Union[SeriesType, ScalarType]]):
from pyspark.pandas.groupby import GroupBy
from pyspark.pandas.series import Series, first_series
from pyspark import pandas as pp
if not isinstance(func, types.FunctionType):
f = func
func = lambda *args, **kwargs: f(*args, **kwargs)
if return_type is None:
# TODO: In this case, it avoids the shortcut for now (but only infers schema)
# because it returns a series from a different DataFrame and it has a different
# anchor. We should fix this to allow the shortcut or only allow to infer
# schema.
limit = pp.get_option("compute.shortcut_limit")
pser = self._kser.head(limit + 1)._to_internal_pandas()
transformed = pser.transform(func)
kser = Series(transformed) # type: Series
spark_return_type = force_decimal_precision_scale(
as_nullable_spark_type(kser.spark.data_type)
)
dtype = kser.dtype
else:
spark_return_type = return_type.spark_type
dtype = return_type.dtype
kdf = self._kser.to_frame()
columns = kdf._internal.spark_column_names
def pandas_concat(series):
# The input can only be a DataFrame for struct from Spark 3.0.
# This works around to make the input as a frame. See SPARK-27240
pdf = pd.concat(series, axis=1)
pdf.columns = columns
return pdf
def apply_func(pdf):
return func(first_series(pdf)).to_frame()
return_schema = StructType([StructField(SPARK_DEFAULT_SERIES_NAME, spark_return_type)])
output_func = GroupBy._make_pandas_df_builder_func(
kdf, apply_func, return_schema, retain_index=False
)
pudf = pandas_udf(
lambda *series: first_series(output_func(pandas_concat(series))),
returnType=spark_return_type,
functionType=PandasUDFType.SCALAR,
)
return self._kser._with_new_scol(
scol=pudf(*kdf._internal.spark_columns).alias(
self._kser._internal.spark_column_names[0]
),
dtype=dtype,
)

File diff suppressed because it is too large.


@@ -0,0 +1,164 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING
import pandas as pd
from pandas.api.types import CategoricalDtype
if TYPE_CHECKING:
import pyspark.pandas as pp # noqa: F401 (SPARK-34943)
class CategoricalAccessor(object):
"""
Accessor object for categorical properties of the Series values.
Examples
--------
>>> s = pp.Series(list("abbccc"), dtype="category")
>>> s # doctest: +SKIP
0 a
1 b
2 b
3 c
4 c
5 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> s.cat.categories
Index(['a', 'b', 'c'], dtype='object')
>>> s.cat.codes
0 0
1 1
2 1
3 2
4 2
5 2
dtype: int8
"""
def __init__(self, series: "pp.Series"):
if not isinstance(series.dtype, CategoricalDtype):
raise ValueError("Cannot call CategoricalAccessor on type {}".format(series.dtype))
self._data = series
@property
def categories(self) -> pd.Index:
"""
The categories of this categorical.
Examples
--------
>>> s = pp.Series(list("abbccc"), dtype="category")
>>> s # doctest: +SKIP
0 a
1 b
2 b
3 c
4 c
5 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> s.cat.categories
Index(['a', 'b', 'c'], dtype='object')
"""
return self._data.dtype.categories
@categories.setter
def categories(self, categories) -> None:
raise NotImplementedError()
@property
def ordered(self) -> bool:
"""
Whether the categories have an ordered relationship.
Examples
--------
>>> s = pp.Series(list("abbccc"), dtype="category")
>>> s # doctest: +SKIP
0 a
1 b
2 b
3 c
4 c
5 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> s.cat.ordered
False
"""
return self._data.dtype.ordered
@property
def codes(self) -> "pp.Series":
"""
Return Series of codes as well as the index.
Examples
--------
>>> s = pp.Series(list("abbccc"), dtype="category")
>>> s # doctest: +SKIP
0 a
1 b
2 b
3 c
4 c
5 c
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> s.cat.codes
0 0
1 1
2 1
3 2
4 2
5 2
dtype: int8
"""
return self._data._with_new_scol(self._data.spark.column).rename()
def add_categories(self, new_categories, inplace: bool = False):
raise NotImplementedError()
def as_ordered(self, inplace: bool = False):
raise NotImplementedError()
def as_unordered(self, inplace: bool = False):
raise NotImplementedError()
def remove_categories(self, removals, inplace: bool = False):
raise NotImplementedError()
def remove_unused_categories(self):
raise NotImplementedError()
def rename_categories(self, new_categories, inplace: bool = False):
raise NotImplementedError()
def reorder_categories(self, new_categories, ordered: bool = None, inplace: bool = False):
raise NotImplementedError()
def set_categories(
self, new_categories, ordered: bool = None, rename: bool = False, inplace: bool = False
):
raise NotImplementedError()


@@ -0,0 +1,442 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Infrastructure of options for Koalas.
"""
from contextlib import contextmanager
import json
from typing import Union, Any, Tuple, Callable, List, Dict # noqa: F401 (SPARK-34943)
from pyspark._globals import _NoValue, _NoValueType
from pyspark.pandas.utils import default_session
__all__ = ["get_option", "set_option", "reset_option", "options", "option_context"]
class Option:
"""
Option class that defines an option with related properties.
This class holds all information relevant to one option. Also,
its instance can validate if the given value is acceptable or not.
It is currently for internal usage only.
Parameters
----------
key: str, keyword-only argument
the option name to use.
doc: str, keyword-only argument
the documentation for the current option.
default: Any, keyword-only argument
default value for this option.
types: Union[Tuple[type, ...], type], keyword-only argument
default is str. It defines the expected types for this option. It is
used with `isinstance` to validate the given value to this option.
check_func: Tuple[Callable[[Any], bool], str], keyword-only argument
default is a function that always returns `True` with an empty string.
It defines:
- a function to check the given value to this option
- the error message to show when this check is failed
When new value is set to this option, this function is called to check
if the given value is valid.
Examples
--------
>>> option = Option(
... key='option.name',
... doc="this is a test option",
... default="default",
... types=(float, int),
... check_func=(lambda v: v > 0, "should be a positive float"))
>>> option.validate('abc') # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: The value for option 'option.name' was <class 'str'>;
however, expected types are [(<class 'float'>, <class 'int'>)].
>>> option.validate(-1.1)
Traceback (most recent call last):
...
ValueError: should be a positive float
>>> option.validate(1.1)
"""
def __init__(
self,
*,
key: str,
doc: str,
default: Any,
types: Union[Tuple[type, ...], type] = str,
check_func: Tuple[Callable[[Any], bool], str] = (lambda v: True, "")
):
self.key = key
self.doc = doc
self.default = default
self.types = types
self.check_func = check_func
def validate(self, v: Any) -> None:
"""
Validate the given value and throw an exception with related information such as key.
"""
if not isinstance(v, self.types):
raise ValueError(
"The value for option '%s' was %s; however, expected types are "
"[%s]." % (self.key, type(v), str(self.types))
)
if not self.check_func[0](v):
raise ValueError(self.check_func[1])
# Available options.
#
# NOTE: if you are fixing or adding an option here, make sure you execute `show_options()` and
# copy & paste the results into show_options 'docs/source/user_guide/options.rst' as well.
# See the examples below:
# >>> from pyspark.pandas.config import show_options
# >>> show_options()
_options = [
Option(
key="display.max_rows",
doc=(
"This sets the maximum number of rows Koalas should output when printing out "
"various output. For example, this value determines the number of rows to be "
"shown at the repr() in a dataframe. Set `None` to unlimit the input length. "
"Default is 1000."
),
default=1000,
types=(int, type(None)),
check_func=(
lambda v: v is None or v >= 0,
"'display.max_rows' should be greater than or equal to 0.",
),
),
Option(
key="compute.max_rows",
doc=(
"'compute.max_rows' sets the limit of the current Koalas DataFrame. Set `None` to "
"unlimit the input length. When the limit is set, it is executed by the shortcut by "
"collecting the data into the driver, and then using the pandas API. If the limit is "
"unset, the operation is executed by PySpark. Default is 1000."
),
default=1000,
types=(int, type(None)),
check_func=(
lambda v: v is None or v >= 0,
"'compute.max_rows' should be greater than or equal to 0.",
),
),
Option(
key="compute.shortcut_limit",
doc=(
"'compute.shortcut_limit' sets the limit for a shortcut. "
"It computes specified number of rows and use its schema. When the dataframe "
"length is larger than this limit, Koalas uses PySpark to compute."
),
default=1000,
types=int,
check_func=(
lambda v: v >= 0,
"'compute.shortcut_limit' should be greater than or equal to 0.",
),
),
Option(
key="compute.ops_on_diff_frames",
doc=(
"This determines whether or not to operate between two different dataframes. "
"For example, 'combine_frames' function internally performs a join operation which "
"can be expensive in general. So, if `compute.ops_on_diff_frames` variable is not "
"True, that method throws an exception."
),
default=False,
types=bool,
),
Option(
key="compute.default_index_type",
doc=("This sets the default index type: sequence, distributed and distributed-sequence."),
default="sequence",
types=str,
check_func=(
lambda v: v in ("sequence", "distributed", "distributed-sequence"),
"Index type should be one of 'sequence', 'distributed', 'distributed-sequence'.",
),
),
Option(
key="compute.ordered_head",
doc=(
"'compute.ordered_head' sets whether or not to operate head with natural ordering. "
"Koalas does not guarantee the row ordering so `head` could return some rows from "
"distributed partitions. If 'compute.ordered_head' is set to True, Koalas performs "
"natural ordering beforehand, but it will cause a performance overhead."
),
default=False,
types=bool,
),
Option(
key="plotting.max_rows",
doc=(
"'plotting.max_rows' sets the visual limit on top-n-based plots such as `plot.bar` "
"and `plot.pie`. If it is set to 1000, the first 1000 data points will be used "
"for plotting. Default is 1000."
),
default=1000,
types=int,
check_func=(
lambda v: v >= 0,
"'plotting.max_rows' should be greater than or equal to 0.",
),
),
Option(
key="plotting.sample_ratio",
doc=(
"'plotting.sample_ratio' sets the proportion of data that will be plotted for sample-"
"based plots such as `plot.line` and `plot.area`. "
"This option defaults to 'plotting.max_rows' option."
),
default=None,
types=(float, type(None)),
check_func=(
lambda v: v is None or 1 >= v >= 0,
"'plotting.sample_ratio' should be 1.0 >= value >= 0.0.",
),
),
Option(
key="plotting.backend",
doc=(
"Backend to use for plotting. Default is plotly. "
"Supports any package that has a top-level `.plot` method. "
"Known options are: [matplotlib, plotly]."
),
default="plotly",
types=str,
),
] # type: List[Option]
_options_dict = dict(zip((option.key for option in _options), _options)) # type: Dict[str, Option]
_key_format = "koalas.{}".format
class OptionError(AttributeError, KeyError):
pass
def show_options():
"""
Make a pretty table that can be copied and pasted into public documentation.
This is currently for an internal purpose.
Examples
--------
>>> show_options() # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
================... =======... =====================...
Option Default Description
================... =======... =====================...
display.max_rows 1000 This sets the maximum...
...
================... =======... =====================...
"""
import textwrap
header = ["Option", "Default", "Description"]
row_format = "{:<31} {:<14} {:<53}"
print(row_format.format("=" * 31, "=" * 14, "=" * 53))
print(row_format.format(*header))
print(row_format.format("=" * 31, "=" * 14, "=" * 53))
for option in _options:
doc = textwrap.fill(option.doc, 53)
formatted = "".join([line + "\n" + (" " * 47) for line in doc.split("\n")]).rstrip()
print(row_format.format(option.key, repr(option.default), formatted))
print(row_format.format("=" * 31, "=" * 14, "=" * 53))
def get_option(key: str, default: Union[Any, _NoValueType] = _NoValue) -> Any:
"""
Retrieves the value of the specified option.
Parameters
----------
key : str
The key which should match a single option.
default : object
The default value if the option is not set yet. The value should be JSON serializable.
Returns
-------
result : the value of the option
Raises
------
OptionError : if no such option exists and the default is not provided
"""
_check_option(key)
if default is _NoValue:
default = _options_dict[key].default
_options_dict[key].validate(default)
return json.loads(default_session().conf.get(_key_format(key), default=json.dumps(default)))
def set_option(key: str, value: Any) -> None:
"""
Sets the value of the specified option.
Parameters
----------
key : str
The key which should match a single option.
value : object
New value of option. The value should be JSON serializable.
Returns
-------
None
"""
_check_option(key)
_options_dict[key].validate(value)
default_session().conf.set(_key_format(key), json.dumps(value))
def reset_option(key: str) -> None:
"""
Reset one option to its default value.
Pass "all" as argument to reset all options.
Parameters
----------
key : str
If specified only option will be reset.
Returns
-------
None
"""
_check_option(key)
default_session().conf.unset(_key_format(key))
@contextmanager
def option_context(*args):
"""
Context manager to temporarily set options in the `with` statement context.
You need to invoke as ``option_context(pat, val, [(pat, val), ...])``.
Examples
--------
>>> with option_context('display.max_rows', 10, 'compute.max_rows', 5):
... print(get_option('display.max_rows'), get_option('compute.max_rows'))
10 5
>>> print(get_option('display.max_rows'), get_option('compute.max_rows'))
1000 1000
"""
if len(args) == 0 or len(args) % 2 != 0:
raise ValueError("Need to invoke as option_context(pat, val, [(pat, val), ...]).")
opts = dict(zip(args[::2], args[1::2]))
orig_opts = {key: get_option(key) for key in opts}
try:
for key, value in opts.items():
set_option(key, value)
yield
finally:
for key, value in orig_opts.items():
set_option(key, value)
def _check_option(key: str) -> None:
if key not in _options_dict:
raise OptionError(
"No such option: '{}'. Available options are [{}]".format(
key, ", ".join(list(_options_dict.keys()))
)
)
class DictWrapper:
""" provide attribute-style access to a nested dict"""
def __init__(self, d, prefix=""):
object.__setattr__(self, "d", d)
object.__setattr__(self, "prefix", prefix)
def __setattr__(self, key, val):
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")
if prefix:
prefix += "."
canonical_key = prefix + key
candidates = [
k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
]
if len(candidates) == 1 and candidates[0] == canonical_key:
return set_option(canonical_key, val)
else:
raise OptionError(
"No such option: '{}'. Available options are [{}]".format(
key, ", ".join(list(_options_dict.keys()))
)
)
def __getattr__(self, key):
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")
if prefix:
prefix += "."
canonical_key = prefix + key
candidates = [
k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
]
if len(candidates) == 1 and candidates[0] == canonical_key:
return get_option(canonical_key)
elif len(candidates) == 0:
raise OptionError(
"No such option: '{}'. Available options are [{}]".format(
key, ", ".join(list(_options_dict.keys()))
)
)
else:
return DictWrapper(d, canonical_key)
def __dir__(self):
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")
if prefix == "":
candidates = d.keys()
offset = 0
else:
candidates = [k for k in d.keys() if all(x in k.split(".") for x in prefix.split("."))]
offset = len(prefix) + 1 # prefix (e.g. "compute.") to trim.
return [c[offset:] for c in candidates]
options = DictWrapper(_options_dict)


@@ -0,0 +1,850 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Date/Time related functions on Koalas Series
"""
from typing import TYPE_CHECKING
import numpy as np # noqa: F401 (SPARK-34943)
import pandas as pd # noqa: F401
import pyspark.sql.functions as F
from pyspark.sql.types import DateType, TimestampType, LongType
if TYPE_CHECKING:
import pyspark.pandas as pp # noqa: F401 (SPARK-34943)
class DatetimeMethods(object):
"""Date/Time methods for Koalas Series"""
def __init__(self, series: "pp.Series"):
if not isinstance(series.spark.data_type, (DateType, TimestampType)):
raise ValueError(
"Cannot call DatetimeMethods on type {}".format(series.spark.data_type)
)
self._data = series
# Properties
@property
def date(self) -> "pp.Series":
"""
Returns a Series of python datetime.date objects (namely, the date
part of Timestamps without timezone information).
"""
# TODO: Hit a weird exception
# syntax error in attribute name: `to_date(`start_date`)` with alias
return self._data.spark.transform(F.to_date)
@property
def time(self) -> "pp.Series":
raise NotImplementedError()
@property
def timetz(self) -> "pp.Series":
raise NotImplementedError()
@property
def year(self) -> "pp.Series":
"""
The year of the datetime.
"""
return self._data.spark.transform(lambda c: F.year(c).cast(LongType()))
@property
def month(self) -> "pp.Series":
"""
The month of the timestamp as January = 1 December = 12.
"""
return self._data.spark.transform(lambda c: F.month(c).cast(LongType()))
@property
def day(self) -> "pp.Series":
"""
The days of the datetime.
"""
return self._data.spark.transform(lambda c: F.dayofmonth(c).cast(LongType()))
@property
def hour(self) -> "pp.Series":
"""
The hours of the datetime.
"""
return self._data.spark.transform(lambda c: F.hour(c).cast(LongType()))
@property
def minute(self) -> "pp.Series":
"""
The minutes of the datetime.
"""
return self._data.spark.transform(lambda c: F.minute(c).cast(LongType()))
@property
def second(self) -> "pp.Series":
"""
The seconds of the datetime.
"""
return self._data.spark.transform(lambda c: F.second(c).cast(LongType()))
@property
def microsecond(self) -> "pp.Series":
"""
The microseconds of the datetime.
"""
def pandas_microsecond(s) -> "pp.Series[np.int64]":
return s.dt.microsecond
return self._data.koalas.transform_batch(pandas_microsecond)
@property
def nanosecond(self) -> "pp.Series":
raise NotImplementedError()
@property
def week(self) -> "pp.Series":
"""
The week ordinal of the year.
"""
return self._data.spark.transform(lambda c: F.weekofyear(c).cast(LongType()))
@property
def weekofyear(self) -> "pp.Series":
return self.week
weekofyear.__doc__ = week.__doc__
@property
def dayofweek(self) -> "pp.Series":
"""
The day of the week with Monday=0, Sunday=6.
Return the day of the week. It is assumed the week starts on
Monday, which is denoted by 0 and ends on Sunday which is denoted
by 6. This method is available on Series with datetime
values (using the `dt` accessor).
Returns
-------
Series
Containing integers indicating the day number.
See Also
--------
Series.dt.dayofweek : Alias.
Series.dt.weekday : Alias.
Series.dt.day_name : Returns the name of the day of the week.
Examples
--------
>>> s = pp.from_pandas(pd.date_range('2016-12-31', '2017-01-08', freq='D').to_series())
>>> s.dt.dayofweek
2016-12-31 5
2017-01-01 6
2017-01-02 0
2017-01-03 1
2017-01-04 2
2017-01-05 3
2017-01-06 4
2017-01-07 5
2017-01-08 6
dtype: int64
"""
def pandas_dayofweek(s) -> "pp.Series[np.int64]":
return s.dt.dayofweek
return self._data.koalas.transform_batch(pandas_dayofweek)
@property
def weekday(self) -> "pp.Series":
return self.dayofweek
weekday.__doc__ = dayofweek.__doc__
@property
def dayofyear(self) -> "pp.Series":
"""
The ordinal day of the year.
"""
def pandas_dayofyear(s) -> "pp.Series[np.int64]":
return s.dt.dayofyear
return self._data.koalas.transform_batch(pandas_dayofyear)
@property
def quarter(self) -> "pp.Series":
"""
The quarter of the date.
"""
def pandas_quarter(s) -> "pp.Series[np.int64]":
return s.dt.quarter
return self._data.koalas.transform_batch(pandas_quarter)
@property
def is_month_start(self) -> "pp.Series":
"""
Indicates whether the date is the first day of the month.
Returns
-------
Series
For Series, returns a Series with boolean values.
See Also
--------
is_month_end : Return a boolean indicating whether the date
is the last day of the month.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> s = pp.Series(pd.date_range("2018-02-27", periods=3))
>>> s
0 2018-02-27
1 2018-02-28
2 2018-03-01
dtype: datetime64[ns]
>>> s.dt.is_month_start
0 False
1 False
2 True
dtype: bool
"""
def pandas_is_month_start(s) -> "pp.Series[bool]":
return s.dt.is_month_start
return self._data.koalas.transform_batch(pandas_is_month_start)
@property
def is_month_end(self) -> "pp.Series":
"""
Indicates whether the date is the last day of the month.
Returns
-------
Series
For Series, returns a Series with boolean values.
See Also
--------
is_month_start : Return a boolean indicating whether the date
is the first day of the month.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> s = pp.Series(pd.date_range("2018-02-27", periods=3))
>>> s
0 2018-02-27
1 2018-02-28
2 2018-03-01
dtype: datetime64[ns]
>>> s.dt.is_month_end
0 False
1 True
2 False
dtype: bool
"""
def pandas_is_month_end(s) -> "pp.Series[bool]":
return s.dt.is_month_end
return self._data.koalas.transform_batch(pandas_is_month_end)
@property
def is_quarter_start(self) -> "pp.Series":
"""
Indicator for whether the date is the first day of a quarter.
Returns
-------
is_quarter_start : Series
The same type as the original data with boolean values. Series will
have the same name and index.
See Also
--------
quarter : Return the quarter of the date.
is_quarter_end : Similar property for indicating the quarter end.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> df = pp.DataFrame({'dates': pd.date_range("2017-03-30",
... periods=4)})
>>> df
dates
0 2017-03-30
1 2017-03-31
2 2017-04-01
3 2017-04-02
>>> df.dates.dt.quarter
0 1
1 1
2 2
3 2
Name: dates, dtype: int64
>>> df.dates.dt.is_quarter_start
0 False
1 False
2 True
3 False
Name: dates, dtype: bool
"""
def pandas_is_quarter_start(s) -> "pp.Series[bool]":
return s.dt.is_quarter_start
return self._data.koalas.transform_batch(pandas_is_quarter_start)
@property
def is_quarter_end(self) -> "pp.Series":
"""
Indicator for whether the date is the last day of a quarter.
Returns
-------
is_quarter_end : Series
The same type as the original data with boolean values. Series will
have the same name and index.
See Also
--------
quarter : Return the quarter of the date.
is_quarter_start : Similar property indicating the quarter start.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> df = pp.DataFrame({'dates': pd.date_range("2017-03-30",
... periods=4)})
>>> df
dates
0 2017-03-30
1 2017-03-31
2 2017-04-01
3 2017-04-02
>>> df.dates.dt.quarter
0 1
1 1
2 2
3 2
Name: dates, dtype: int64
>>> df.dates.dt.is_quarter_end
0 False
1 True
2 False
3 False
Name: dates, dtype: bool
"""
def pandas_is_quarter_end(s) -> "pp.Series[bool]":
return s.dt.is_quarter_end
return self._data.koalas.transform_batch(pandas_is_quarter_end)
@property
def is_year_start(self) -> "pp.Series":
"""
Indicate whether the date is the first day of a year.
Returns
-------
Series
The same type as the original data with boolean values. Series will
have the same name and index.
See Also
--------
is_year_end : Similar property indicating the last day of the year.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> dates = pp.Series(pd.date_range("2017-12-30", periods=3))
>>> dates
0 2017-12-30
1 2017-12-31
2 2018-01-01
dtype: datetime64[ns]
>>> dates.dt.is_year_start
0 False
1 False
2 True
dtype: bool
"""
def pandas_is_year_start(s) -> "pp.Series[bool]":
return s.dt.is_year_start
return self._data.koalas.transform_batch(pandas_is_year_start)
@property
def is_year_end(self) -> "pp.Series":
"""
Indicate whether the date is the last day of the year.
Returns
-------
Series
The same type as the original data with boolean values. Series will
have the same name and index.
See Also
--------
is_year_start : Similar property indicating the start of the year.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> dates = pp.Series(pd.date_range("2017-12-30", periods=3))
>>> dates
0 2017-12-30
1 2017-12-31
2 2018-01-01
dtype: datetime64[ns]
>>> dates.dt.is_year_end
0 False
1 True
2 False
dtype: bool
"""
def pandas_is_year_end(s) -> "pp.Series[bool]":
return s.dt.is_year_end
return self._data.koalas.transform_batch(pandas_is_year_end)
@property
def is_leap_year(self) -> "pp.Series":
"""
Boolean indicator if the date belongs to a leap year.
A leap year is a year, which has 366 days (instead of 365) including
29th of February as an intercalary day.
Leap years are years which are multiples of four with the exception
of years divisible by 100 but not by 400.
Returns
-------
Series
Booleans indicating if dates belong to a leap year.
Examples
--------
This method is available on Series with datetime values under
the ``.dt`` accessor.
>>> dates_series = pp.Series(pd.date_range("2012-01-01", "2015-01-01", freq="Y"))
>>> dates_series
0 2012-12-31
1 2013-12-31
2 2014-12-31
dtype: datetime64[ns]
>>> dates_series.dt.is_leap_year
0 True
1 False
2 False
dtype: bool
"""
def pandas_is_leap_year(s) -> "pp.Series[bool]":
return s.dt.is_leap_year
return self._data.koalas.transform_batch(pandas_is_leap_year)
@property
def daysinmonth(self) -> "pp.Series":
"""
The number of days in the month.
"""
def pandas_daysinmonth(s) -> "pp.Series[np.int64]":
return s.dt.daysinmonth
return self._data.koalas.transform_batch(pandas_daysinmonth)
@property
def days_in_month(self) -> "pp.Series":
return self.daysinmonth
days_in_month.__doc__ = daysinmonth.__doc__
# Methods
def tz_localize(self, tz) -> "pp.Series":
"""
Localize tz-naive Datetime column to tz-aware Datetime column.
"""
# Neither tz-naive or tz-aware datetime exists in Spark
raise NotImplementedError()
def tz_convert(self, tz) -> "pp.Series":
"""
Convert tz-aware Datetime column from one time zone to another.
"""
# tz-aware datetime doesn't exist in Spark
raise NotImplementedError()
def normalize(self) -> "pp.Series":
"""
Convert times to midnight.
The time component of the date-time is converted to midnight i.e.
00:00:00. This is useful in cases when the time does not matter.
Length is unaltered. The timezones are unaffected.
This method is available on Series with datetime values under
the ``.dt`` accessor, and directly on Datetime Array.
Returns
-------
Series
The same type as the original data. Series will have the same
name and index.
See Also
--------
floor : Floor the series to the specified freq.
ceil : Ceil the series to the specified freq.
round : Round the series to the specified freq.
Examples
--------
>>> series = pp.Series(pd.Series(pd.date_range('2012-1-1 12:45:31', periods=3, freq='M')))
>>> series.dt.normalize()
0 2012-01-31
1 2012-02-29
2 2012-03-31
dtype: datetime64[ns]
"""
def pandas_normalize(s) -> "pp.Series[np.datetime64]":
return s.dt.normalize()
return self._data.koalas.transform_batch(pandas_normalize)
def strftime(self, date_format) -> "pp.Series":
"""
Convert to a string Series using specified date_format.
Return a Series of formatted strings specified by date_format, which
supports the same string format as the python standard library. Details
of the string format can be found in python string format
doc.
Parameters
----------
date_format : str
Date format string (e.g. "%%Y-%%m-%%d").
Returns
-------
Series
Series of formatted strings.
See Also
--------
to_datetime : Convert the given argument to datetime.
normalize : Return series with times to midnight.
round : Round the series to the specified freq.
floor : Floor the series to the specified freq.
Examples
--------
>>> series = pp.Series(pd.date_range(pd.Timestamp("2018-03-10 09:00"),
... periods=3, freq='s'))
>>> series
0 2018-03-10 09:00:00
1 2018-03-10 09:00:01
2 2018-03-10 09:00:02
dtype: datetime64[ns]
>>> series.dt.strftime('%B %d, %Y, %r')
0 March 10, 2018, 09:00:00 AM
1 March 10, 2018, 09:00:01 AM
2 March 10, 2018, 09:00:02 AM
dtype: object
"""
def pandas_strftime(s) -> "pp.Series[str]":
return s.dt.strftime(date_format)
return self._data.koalas.transform_batch(pandas_strftime)
def round(self, freq, *args, **kwargs) -> "pp.Series":
"""
Perform round operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to round the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
nonexistent : 'shift_forward', 'shift_backward', 'NaT', timedelta, default 'raise'
A nonexistent time does not exist in a particular timezone
where clocks moved forward due to DST.
- 'shift_forward' will shift the nonexistent time forward to the
closest existing time
- 'shift_backward' will shift the nonexistent time backward to the
closest existing time
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise a NonExistentTimeError if there are
nonexistent times
.. note:: this option only works with pandas 0.24.0+
Returns
-------
Series
a Series with the same index as the original Series.
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> series = pp.Series(pd.date_range('1/1/2018 11:59:00', periods=3, freq='min'))
>>> series
0 2018-01-01 11:59:00
1 2018-01-01 12:00:00
2 2018-01-01 12:01:00
dtype: datetime64[ns]
>>> series.dt.round("H")
0 2018-01-01 12:00:00
1 2018-01-01 12:00:00
2 2018-01-01 12:00:00
dtype: datetime64[ns]
"""
def pandas_round(s) -> "pp.Series[np.datetime64]":
return s.dt.round(freq, *args, **kwargs)
return self._data.koalas.transform_batch(pandas_round)
def floor(self, freq, *args, **kwargs) -> "pp.Series":
"""
Perform floor operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to floor the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
nonexistent : 'shift_forward', 'shift_backward', 'NaT', timedelta, default 'raise'
A nonexistent time does not exist in a particular timezone
where clocks moved forward due to DST.
- 'shift_forward' will shift the nonexistent time forward to the
closest existing time
- 'shift_backward' will shift the nonexistent time backward to the
closest existing time
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise a NonExistentTimeError if there are
nonexistent times
.. note:: this option only works with pandas 0.24.0+
Returns
-------
Series
a Series with the same index as the original Series.
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> series = pp.Series(pd.date_range('1/1/2018 11:59:00', periods=3, freq='min'))
>>> series
0 2018-01-01 11:59:00
1 2018-01-01 12:00:00
2 2018-01-01 12:01:00
dtype: datetime64[ns]
>>> series.dt.floor("H")
0 2018-01-01 11:00:00
1 2018-01-01 12:00:00
2 2018-01-01 12:00:00
dtype: datetime64[ns]
"""
def pandas_floor(s) -> "pp.Series[np.datetime64]":
return s.dt.floor(freq, *args, **kwargs)
return self._data.koalas.transform_batch(pandas_floor)
def ceil(self, freq, *args, **kwargs) -> "pp.Series":
"""
Perform ceil operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to round the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
nonexistent : 'shift_forward', 'shift_backward', 'NaT', timedelta, default 'raise'
A nonexistent time does not exist in a particular timezone
where clocks moved forward due to DST.
- 'shift_forward' will shift the nonexistent time forward to the
closest existing time
- 'shift_backward' will shift the nonexistent time backward to the
closest existing time
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise a NonExistentTimeError if there are
nonexistent times
.. note:: this option only works with pandas 0.24.0+
Returns
-------
Series
a Series with the same index as the original Series.
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> series = pp.Series(pd.date_range('1/1/2018 11:59:00', periods=3, freq='min'))
>>> series
0 2018-01-01 11:59:00
1 2018-01-01 12:00:00
2 2018-01-01 12:01:00
dtype: datetime64[ns]
>>> series.dt.ceil("H")
0 2018-01-01 12:00:00
1 2018-01-01 12:00:00
2 2018-01-01 13:00:00
dtype: datetime64[ns]
"""
def pandas_ceil(s) -> "pp.Series[np.datetime64]":
return s.dt.ceil(freq, *args, **kwargs)
return self._data.koalas.transform_batch(pandas_ceil)
def month_name(self, locale=None) -> "pp.Series":
"""
Return the month names of the series with specified locale.
Parameters
----------
locale : str, optional
Locale determining the language in which to return the month name.
Default is English locale.
Returns
-------
Series
Series of month names.
Examples
--------
>>> series = pp.Series(pd.date_range(start='2018-01', freq='M', periods=3))
>>> series
0 2018-01-31
1 2018-02-28
2 2018-03-31
dtype: datetime64[ns]
>>> series.dt.month_name()
0 January
1 February
2 March
dtype: object
"""
def pandas_month_name(s) -> "pp.Series[str]":
return s.dt.month_name(locale=locale)
return self._data.koalas.transform_batch(pandas_month_name)
def day_name(self, locale=None) -> "pp.Series":
"""
Return the day names of the series with specified locale.
Parameters
----------
locale : str, optional
Locale determining the language in which to return the day name.
Default is English locale.
Returns
-------
Series
Series of day names.
Examples
--------
>>> series = pp.Series(pd.date_range(start='2018-01-01', freq='D', periods=3))
>>> series
0 2018-01-01
1 2018-01-02
2 2018-01-03
dtype: datetime64[ns]
>>> series.dt.day_name()
0 Monday
1 Tuesday
2 Wednesday
dtype: object
"""
def pandas_day_name(s) -> "pp.Series[str]":
return s.dt.day_name(locale=locale)
return self._data.koalas.transform_batch(pandas_day_name)
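
As a quick end-to-end illustration of the `dt` accessor above (the sample dates are arbitrary, and output is shown as pandas would render it):

```python
>>> import pandas as pd
>>> from pyspark import pandas as pp
>>> ppser = pp.Series(pd.date_range("2021-03-01", periods=3))  # Mon, Tue, Wed
>>> ppser.dt.dayofweek        # Monday=0 .. Sunday=6, computed via transform_batch
0    0
1    1
2    2
dtype: int64
>>> ppser.dt.month_name()
0    March
1    March
2    March
dtype: object
```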


@@ -0,0 +1,106 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Exceptions/Errors used in Koalas.
"""
class DataError(Exception):
pass
class SparkPandasIndexingError(Exception):
pass
def code_change_hint(pandas_function, spark_target_function):
if pandas_function is not None and spark_target_function is not None:
return "You are trying to use pandas function {}, use spark function {}".format(
pandas_function, spark_target_function
)
elif pandas_function is not None and spark_target_function is None:
return (
"You are trying to use pandas function {}, checkout the spark "
"user guide to find a relevant function"
).format(pandas_function)
elif pandas_function is None and spark_target_function is not None:
return "Use spark function {}".format(spark_target_function)
else: # both none
return "Checkout the spark user guide to find a relevant function"
class SparkPandasNotImplementedError(NotImplementedError):
def __init__(self, pandas_function=None, spark_target_function=None, description=""):
self.pandas_source = pandas_function
self.spark_target = spark_target_function
hint = code_change_hint(pandas_function, spark_target_function)
if len(description) > 0:
description += " " + hint
else:
description = hint
super().__init__(description)
class PandasNotImplementedError(NotImplementedError):
def __init__(
self,
class_name,
method_name=None,
arg_name=None,
property_name=None,
deprecated=False,
reason="",
):
assert (method_name is None) != (property_name is None)
self.class_name = class_name
self.method_name = method_name
self.arg_name = arg_name
if method_name is not None:
if arg_name is not None:
msg = "The method `{0}.{1}()` does not support `{2}` parameter. {3}".format(
class_name, method_name, arg_name, reason
)
else:
if deprecated:
msg = (
"The method `{0}.{1}()` is deprecated in pandas and will therefore "
+ "not be supported in Koalas. {2}"
).format(class_name, method_name, reason)
else:
if reason == "":
reason = " yet."
else:
reason = ". " + reason
msg = "The method `{0}.{1}()` is not implemented{2}".format(
class_name, method_name, reason
)
else:
if deprecated:
msg = (
"The property `{0}.{1}()` is deprecated in pandas and will therefore "
+ "not be supported in Koalas. {2}"
).format(class_name, property_name, reason)
else:
if reason == "":
reason = " yet."
else:
reason = ". " + reason
msg = "The property `{0}.{1}()` is not implemented{2}".format(
class_name, property_name, reason
)
super().__init__(msg)
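
The messages assembled above render like the following; the class and method names in this snippet are placeholders chosen only for the example:

```python
>>> from pyspark.pandas.exceptions import (
...     PandasNotImplementedError,
...     SparkPandasNotImplementedError,
... )
>>> str(PandasNotImplementedError(class_name="pp.DataFrame", method_name="memory_usage"))
'The method `pp.DataFrame.memory_usage()` is not implemented yet.'
>>> str(SparkPandasNotImplementedError(
...     pandas_function=".iloc[..., ...]", spark_target_function="select, where"))
'You are trying to use pandas function .iloc[..., ...], use spark function select, where'
```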


@@ -0,0 +1,342 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
class CachedAccessor:
"""
Custom property-like object.
A descriptor for caching accessors:
Parameters
----------
name : str
Namespace that accessor's methods, properties, etc will be accessed under, e.g. "foo" for a
dataframe accessor yields the accessor ``df.foo``
accessor: cls
Class with the extension methods.
Notes
-----
For accessor, the class's __init__ method assumes that you are registering an accessor for one
of ``Series``, ``DataFrame``, or ``Index``.
This object is not meant to be instantiated directly. Instead, use register_dataframe_accessor,
register_series_accessor, or register_index_accessor.
The Koalas accessor is modified based on pandas.core.accessor.
"""
def __init__(self, name, accessor):
self._name = name
self._accessor = accessor
def __get__(self, obj, cls):
if obj is None:
return self._accessor
accessor_obj = self._accessor(obj)
object.__setattr__(obj, self._name, accessor_obj)
return accessor_obj
def _register_accessor(name, cls):
"""
Register a custom accessor on {klass} objects.
Parameters
----------
name : str
Name under which the accessor should be registered. A warning is issued if this name
conflicts with a preexisting attribute.
Returns
-------
callable
A class decorator.
See Also
--------
register_dataframe_accessor: Register a custom accessor on DataFrame objects
register_series_accessor: Register a custom accessor on Series objects
register_index_accessor: Register a custom accessor on Index objects
Notes
-----
When accessed, your accessor will be initialized with the Koalas object the user is interacting
with. The code signature must be:
.. code-block:: python
def __init__(self, koalas_obj):
# constructor logic
...
In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
frequently used to annotate when a value's datatype is unexpected for a given method/function.
Ultimately, you can structure this however you like, but Koalas would likely do something like
this:
>>> pp.Series(['a', 'b']).dt
...
Traceback (most recent call last):
...
ValueError: Cannot call DatetimeMethods on type StringType
Note: This function is not meant to be used directly - instead, use register_dataframe_accessor,
register_series_accessor, or register_index_accessor.
"""
def decorator(accessor):
if hasattr(cls, name):
msg = (
"registration of accessor {0} under name '{1}' for type {2} is overriding "
"a preexisting attribute with the same name.".format(accessor, name, cls.__name__)
)
warnings.warn(
msg, UserWarning, stacklevel=2,
)
setattr(cls, name, CachedAccessor(name, accessor))
return accessor
return decorator
def register_dataframe_accessor(name):
"""
Register a custom accessor with a DataFrame
Parameters
----------
name : str
name used when calling the accessor after it is registered
Returns
-------
callable
A class decorator.
See Also
--------
register_series_accessor: Register a custom accessor on Series objects
register_index_accessor: Register a custom accessor on Index objects
Notes
-----
When accessed, your accessor will be initialized with the Koalas object the user is interacting
with. The accessor's init method should always ingest the object being accessed. See the
examples for the init signature.
In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
frequently used to annotate when a value's datatype is unexpected for a given method/function.
Ultimately, you can structure this however you like, but Koalas would likely do something like
this:
>>> pp.Series(['a', 'b']).dt
...
Traceback (most recent call last):
...
ValueError: Cannot call DatetimeMethods on type StringType
Examples
--------
In your library code::
from pyspark.pandas.extensions import register_dataframe_accessor
@register_dataframe_accessor("geo")
class GeoAccessor:
def __init__(self, koalas_obj):
self._obj = koalas_obj
# other constructor logic
@property
def center(self):
# return the geographic center point of this DataFrame
lat = self._obj.latitude
lon = self._obj.longitude
return (float(lon.mean()), float(lat.mean()))
def plot(self):
# plot this array's data on a map
pass
Then, in an ipython session::
>>> ## Import if the accessor is in the other file.
>>> # from my_ext_lib import GeoAccessor
>>> kdf = pp.DataFrame({"longitude": np.linspace(0,10),
... "latitude": np.linspace(0, 20)})
>>> kdf.geo.center # doctest: +SKIP
(5.0, 10.0)
>>> kdf.geo.plot() # doctest: +SKIP
"""
from pyspark.pandas import DataFrame
return _register_accessor(name, DataFrame)
def register_series_accessor(name):
"""
Register a custom accessor with a Series object
Parameters
----------
name : str
name used when calling the accessor after it is registered
Returns
-------
callable
A class decorator.
See Also
--------
register_dataframe_accessor: Register a custom accessor on DataFrame objects
register_index_accessor: Register a custom accessor on Index objects
Notes
-----
When accessed, your accessor will be initialized with the Koalas object the user is interacting
with. The code signature must be::
def __init__(self, koalas_obj):
# constructor logic
...
In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
frequently used to annotate when a value's datatype is unexpected for a given method/function.
Ultimately, you can structure this however you like, but Koalas would likely do something like
this:
>>> pp.Series(['a', 'b']).dt
...
Traceback (most recent call last):
...
ValueError: Cannot call DatetimeMethods on type StringType
Examples
--------
In your library code::
from pyspark.pandas.extensions import register_series_accessor
@register_series_accessor("geo")
class GeoAccessor:
def __init__(self, koalas_obj):
self._obj = koalas_obj
@property
def is_valid(self):
# boolean check to see if series contains valid geometry
return True
Then, in an ipython session::
>>> ## Import if the accessor is in the other file.
>>> # from my_ext_lib import GeoAccessor
>>> kdf = pp.DataFrame({"longitude": np.linspace(0,10),
... "latitude": np.linspace(0, 20)})
>>> kdf.longitude.geo.is_valid # doctest: +SKIP
True
"""
from pyspark.pandas import Series
return _register_accessor(name, Series)
def register_index_accessor(name):
"""
Register a custom accessor with an Index
Parameters
----------
name : str
name used when calling the accessor after it is registered
Returns
-------
callable
A class decorator.
See Also
--------
register_dataframe_accessor: Register a custom accessor on DataFrame objects
register_series_accessor: Register a custom accessor on Series objects
Notes
-----
When accessed, your accessor will be initialized with the Koalas object the user is interacting
with. The code signature must be::
def __init__(self, koalas_obj):
# constructor logic
...
In the pandas API, if data passed to your accessor has an incorrect dtype, it's recommended to
raise an ``AttributeError`` for consistency purposes. In Koalas, ``ValueError`` is more
frequently used to annotate when a value's datatype is unexpected for a given method/function.
Ultimately, you can structure this however you like, but Koalas would likely do something like
this:
>>> pp.Series(['a', 'b']).dt
...
Traceback (most recent call last):
...
ValueError: Cannot call DatetimeMethods on type StringType
Examples
--------
In your library code::
from pyspark.pandas.extensions import register_index_accessor
@register_index_accessor("foo")
class CustomAccessor:
def __init__(self, koalas_obj):
self._obj = koalas_obj
self.item = "baz"
@property
def bar(self):
# return item value
return self.item
Then, in an ipython session::
>>> ## Import if the accessor is in the other file.
>>> # from my_ext_lib import CustomAccessor
>>> kdf = pp.DataFrame({"longitude": np.linspace(0,10),
... "latitude": np.linspace(0, 20)})
>>> kdf.index.foo.bar # doctest: +SKIP
'baz'
"""
from pyspark.pandas import Index
return _register_accessor(name, Index)
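
A compact, runnable sketch of the registration machinery above; the accessor name ``outlier`` and its logic are invented for this example, and output is shown as pandas would render it:

```python
>>> from pyspark import pandas as pp
>>> from pyspark.pandas.extensions import register_series_accessor
>>>
>>> @register_series_accessor("outlier")
... class OutlierAccessor:
...     def __init__(self, koalas_obj):
...         self._obj = koalas_obj     # the Series the accessor was called on
...     def over(self, threshold):
...         # boolean mask of values above the threshold
...         return self._obj > threshold
...
>>> pp.Series([1, 5, 20, 3]).outlier.over(10)
0    False
1    False
2     True
3    False
dtype: bool
```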

python/pyspark/pandas/frame.py (new file, 11,976 lines added)

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.indexes.base import Index # noqa: F401
from pyspark.pandas.indexes.datetimes import DatetimeIndex # noqa: F401
from pyspark.pandas.indexes.multi import MultiIndex # noqa: F401
from pyspark.pandas.indexes.numeric import Float64Index, Int64Index # noqa: F401

File diff suppressed because it is too large.


@@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
from typing import Any
import pandas as pd
from pandas.api.types import is_hashable
from pyspark import pandas as pp
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.missing.indexes import MissingPandasLikeCategoricalIndex
from pyspark.pandas.series import Series
class CategoricalIndex(Index):
"""
Index based on an underlying `Categorical`.
CategoricalIndex can only take on a limited,
and usually fixed, number of possible values (`categories`). Also,
it might have an order, but numerical operations
(additions, divisions, ...) are not possible.
Parameters
----------
data : array-like (1-dimensional)
The values of the categorical. If `categories` are given, values not in
`categories` will be replaced with NaN.
categories : index-like, optional
The categories for the categorical. Items need to be unique.
If the categories are not given here (and also not in `dtype`), they
will be inferred from the `data`.
ordered : bool, optional
Whether or not this categorical is treated as an ordered
categorical. If not given here or in `dtype`, the resulting
categorical will be unordered.
dtype : CategoricalDtype or "category", optional
If :class:`CategoricalDtype`, cannot be used together with
`categories` or `ordered`.
copy : bool, default False
Make a copy of input ndarray.
name : object, optional
Name to be stored in the index.
See Also
--------
Index : The base Koalas Index type.
Examples
--------
>>> pp.CategoricalIndex(["a", "b", "c", "a", "b", "c"]) # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'c', 'a', 'b', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
``CategoricalIndex`` can also be instantiated from a ``Categorical``:
>>> c = pd.Categorical(["a", "b", "c", "a", "b", "c"])
>>> pp.CategoricalIndex(c) # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'c', 'a', 'b', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
Ordered ``CategoricalIndex`` can have a min and max value.
>>> ci = pp.CategoricalIndex(
... ["a", "b", "c", "a", "b", "c"], ordered=True, categories=["c", "b", "a"]
... )
>>> ci # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'c', 'a', 'b', 'c'],
categories=['c', 'b', 'a'], ordered=True, dtype='category')
From a Series:
>>> s = pp.Series(["a", "b", "c", "a", "b", "c"], index=[10, 20, 30, 40, 50, 60])
>>> pp.CategoricalIndex(s) # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'c', 'a', 'b', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
From an Index:
>>> idx = pp.Index(["a", "b", "c", "a", "b", "c"])
>>> pp.CategoricalIndex(idx) # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'c', 'a', 'b', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
"""
def __new__(cls, data=None, categories=None, ordered=None, dtype=None, copy=False, name=None):
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
if dtype is None:
dtype = "category"
return Index(data, dtype=dtype, copy=copy, name=name)
return pp.from_pandas(
pd.CategoricalIndex(
data=data, categories=categories, ordered=ordered, dtype=dtype, name=name
)
)
@property
def codes(self) -> Index:
"""
The category codes of this categorical.
Codes are an Index of integers which are the positions of the actual
values in the categories Index.
There is no setter, use the other categorical methods and the normal item
setter to change values in the categorical.
Returns
-------
Index
A non-writable view of the `codes` Index.
Examples
--------
>>> idx = pp.CategoricalIndex(list("abbccc"))
>>> idx # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
>>> idx.codes
Int64Index([0, 1, 1, 2, 2, 2], dtype='int64')
"""
return self._with_new_scol(self.spark.column).rename(None)
@property
def categories(self) -> pd.Index:
"""
The categories of this categorical.
Examples
--------
>>> idx = pp.CategoricalIndex(list("abbccc"))
>>> idx # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
>>> idx.categories
Index(['a', 'b', 'c'], dtype='object')
"""
return self.dtype.categories
@categories.setter
def categories(self, categories):
raise NotImplementedError()
@property
def ordered(self) -> bool:
"""
Whether the categories have an ordered relationship.
Examples
--------
>>> idx = pp.CategoricalIndex(list("abbccc"))
>>> idx # doctest: +NORMALIZE_WHITESPACE
CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'],
categories=['a', 'b', 'c'], ordered=False, dtype='category')
>>> idx.ordered
False
"""
return self.dtype.ordered
def __getattr__(self, item: str) -> Any:
if hasattr(MissingPandasLikeCategoricalIndex, item):
property_or_func = getattr(MissingPandasLikeCategoricalIndex, item)
if isinstance(property_or_func, property):
return property_or_func.fget(self) # type: ignore
else:
return partial(property_or_func, self)
raise AttributeError("'CategoricalIndex' object has no attribute '{}'".format(item))


@@ -0,0 +1,742 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
from functools import partial
from typing import Any, Optional, Union
import pandas as pd
from pandas.api.types import is_hashable
from pyspark._globals import _NoValue
from pyspark import pandas as pp
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.missing.indexes import MissingPandasLikeDatetimeIndex
from pyspark.pandas.series import Series, first_series
from pyspark.pandas.utils import verify_temp_column_name
class DatetimeIndex(Index):
"""
Immutable ndarray-like of datetime64 data.
Parameters
----------
data : array-like (1-dimensional), optional
Optional datetime-like data to construct index with.
freq : str or pandas offset object, optional
One of pandas date offset strings or corresponding objects. The string
'infer' can be passed in order to set the frequency of the index as the
inferred frequency upon creation.
normalize : bool, default False
Normalize start/end dates to midnight before generating date range.
closed : {'left', 'right'}, optional
Set whether to include `start` and `end` that are on the
boundary. The default includes boundary points on either end.
ambiguous : 'infer', bool-ndarray, 'NaT', default 'raise'
When clocks moved backward due to DST, ambiguous times may arise.
For example in Central European Time (UTC+01), when going from 03:00
DST to 02:00 non-DST, 02:30:00 local time occurs both at 00:30:00 UTC
and at 01:30:00 UTC. In such a situation, the `ambiguous` parameter
dictates how ambiguous times should be handled.
- 'infer' will attempt to infer fall dst-transition hours based on
order
- bool-ndarray where True signifies a DST time, False signifies a
non-DST time (note that this flag is only applicable for ambiguous
times)
- 'NaT' will return NaT where there are ambiguous times
- 'raise' will raise an AmbiguousTimeError if there are ambiguous times.
dayfirst : bool, default False
If True, parse dates in `data` with the day first order.
yearfirst : bool, default False
If True parse dates in `data` with the year first order.
dtype : numpy.dtype or str, default None
Note that the only NumPy dtype allowed is datetime64[ns].
copy : bool, default False
Make a copy of input ndarray.
name : label, default None
Name to be stored in the index.
See Also
--------
Index : The base pandas Index type.
to_datetime : Convert argument to datetime.
Examples
--------
>>> pp.DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'])
DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'], dtype='datetime64[ns]', freq=None)
From a Series:
>>> from datetime import datetime
>>> s = pp.Series([datetime(2021, 3, 1), datetime(2021, 3, 2)], index=[10, 20])
>>> pp.DatetimeIndex(s)
DatetimeIndex(['2021-03-01', '2021-03-02'], dtype='datetime64[ns]', freq=None)
From an Index:
>>> idx = pp.DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'])
>>> pp.DatetimeIndex(idx)
DatetimeIndex(['1970-01-01', '1970-01-01', '1970-01-01'], dtype='datetime64[ns]', freq=None)
"""
def __new__(
cls,
data=None,
freq=_NoValue,
normalize=False,
closed=None,
ambiguous="raise",
dayfirst=False,
yearfirst=False,
dtype=None,
copy=False,
name=None,
):
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
if dtype is None:
dtype = "datetime64[ns]"
return Index(data, dtype=dtype, copy=copy, name=name)
kwargs = dict(
data=data,
normalize=normalize,
closed=closed,
ambiguous=ambiguous,
dayfirst=dayfirst,
yearfirst=yearfirst,
dtype=dtype,
copy=copy,
name=name,
)
if freq is not _NoValue:
kwargs["freq"] = freq
return pp.from_pandas(pd.DatetimeIndex(**kwargs))
def __getattr__(self, item: str) -> Any:
if hasattr(MissingPandasLikeDatetimeIndex, item):
property_or_func = getattr(MissingPandasLikeDatetimeIndex, item)
if isinstance(property_or_func, property):
return property_or_func.fget(self) # type: ignore
else:
return partial(property_or_func, self)
raise AttributeError("'DatetimeIndex' object has no attribute '{}'".format(item))
# Properties
@property
def year(self) -> Index:
"""
The year of the datetime.
"""
return Index(self.to_series().dt.year)
@property
def month(self) -> Index:
"""
The month of the timestamp as January = 1, December = 12.
"""
return Index(self.to_series().dt.month)
@property
def day(self) -> Index:
"""
The days of the datetime.
"""
return Index(self.to_series().dt.day)
@property
def hour(self) -> Index:
"""
The hours of the datetime.
"""
return Index(self.to_series().dt.hour)
@property
def minute(self) -> Index:
"""
The minutes of the datetime.
"""
return Index(self.to_series().dt.minute)
@property
def second(self) -> Index:
"""
The seconds of the datetime.
"""
return Index(self.to_series().dt.second)
@property
def microsecond(self) -> Index:
"""
The microseconds of the datetime.
"""
return Index(self.to_series().dt.microsecond)
@property
def week(self) -> Index:
"""
The week ordinal of the year.
"""
return Index(self.to_series().dt.week)
@property
def weekofyear(self) -> Index:
return Index(self.to_series().dt.weekofyear)
weekofyear.__doc__ = week.__doc__
@property
def dayofweek(self) -> Index:
"""
The day of the week with Monday=0, Sunday=6.
Return the day of the week. It is assumed the week starts on
Monday, which is denoted by 0 and ends on Sunday which is denoted
by 6. This method is available on both Series with datetime
values (using the `dt` accessor) or DatetimeIndex.
Returns
-------
Series or Index
Containing integers indicating the day number.
See Also
--------
Series.dt.dayofweek : Alias.
Series.dt.weekday : Alias.
Series.dt.day_name : Returns the name of the day of the week.
Examples
--------
>>> idx = pp.date_range('2016-12-31', '2017-01-08', freq='D')
>>> idx.dayofweek
Int64Index([5, 6, 0, 1, 2, 3, 4, 5, 6], dtype='int64')
"""
return Index(self.to_series().dt.dayofweek)
@property
def day_of_week(self) -> Index:
return self.dayofweek
day_of_week.__doc__ = dayofweek.__doc__
@property
def weekday(self) -> Index:
return Index(self.to_series().dt.weekday)
weekday.__doc__ = dayofweek.__doc__
@property
def dayofyear(self) -> Index:
"""
The ordinal day of the year.
"""
return Index(self.to_series().dt.dayofyear)
@property
def day_of_year(self) -> Index:
return self.dayofyear
day_of_year.__doc__ = dayofyear.__doc__
@property
def quarter(self) -> Index:
"""
The quarter of the date.
"""
return Index(self.to_series().dt.quarter)
@property
def is_month_start(self) -> Index:
"""
Indicates whether the date is the first day of the month.
Returns
-------
Index
Returns an Index with boolean values
See Also
--------
is_month_end : Return a boolean indicating whether the date
is the last day of the month.
Examples
--------
>>> idx = pp.date_range("2018-02-27", periods=3)
>>> idx.is_month_start
Index([False, False, True], dtype='object')
"""
return Index(self.to_series().dt.is_month_start)
@property
def is_month_end(self) -> Index:
"""
Indicates whether the date is the last day of the month.
Returns
-------
Index
Returns an Index with boolean values.
See Also
--------
is_month_start : Return a boolean indicating whether the date
is the first day of the month.
Examples
--------
>>> idx = pp.date_range("2018-02-27", periods=3)
>>> idx.is_month_end
Index([False, True, False], dtype='object')
"""
return Index(self.to_series().dt.is_month_end)
@property
def is_quarter_start(self) -> Index:
"""
Indicator for whether the date is the first day of a quarter.
Returns
-------
is_quarter_start : Index
Returns an Index with boolean values.
See Also
--------
quarter : Return the quarter of the date.
is_quarter_end : Similar property for indicating the quarter end.
Examples
--------
>>> idx = pp.date_range('2017-03-30', periods=4)
>>> idx.is_quarter_start
Index([False, False, True, False], dtype='object')
"""
return Index(self.to_series().dt.is_quarter_start)
@property
def is_quarter_end(self) -> Index:
"""
Indicator for whether the date is the last day of a quarter.
Returns
-------
is_quarter_end : Index
Returns an Index with boolean values.
See Also
--------
quarter : Return the quarter of the date.
is_quarter_start : Similar property indicating the quarter start.
Examples
--------
>>> idx = pp.date_range('2017-03-30', periods=4)
>>> idx.is_quarter_end
Index([False, True, False, False], dtype='object')
"""
return Index(self.to_series().dt.is_quarter_end)
@property
def is_year_start(self) -> Index:
"""
Indicate whether the date is the first day of a year.
Returns
-------
Index
Returns an Index with boolean values.
See Also
--------
is_year_end : Similar property indicating the last day of the year.
Examples
--------
>>> idx = pp.date_range("2017-12-30", periods=3)
>>> idx.is_year_start
Index([False, False, True], dtype='object')
"""
return Index(self.to_series().dt.is_year_start)
@property
def is_year_end(self) -> Index:
"""
Indicate whether the date is the last day of the year.
Returns
-------
Index
Returns an Index with boolean values.
See Also
--------
is_year_start : Similar property indicating the start of the year.
Examples
--------
>>> idx = pp.date_range("2017-12-30", periods=3)
>>> idx.is_year_end
Index([False, True, False], dtype='object')
"""
return Index(self.to_series().dt.is_year_end)
@property
def is_leap_year(self) -> Index:
"""
Boolean indicator if the date belongs to a leap year.
A leap year is a year, which has 366 days (instead of 365) including
29th of February as an intercalary day.
Leap years are years which are multiples of four with the exception
of years divisible by 100 but not by 400.
Returns
-------
Index
Booleans indicating if dates belong to a leap year.
Examples
--------
>>> idx = pp.date_range("2012-01-01", "2015-01-01", freq="Y")
>>> idx.is_leap_year
Index([True, False, False], dtype='object')
"""
return Index(self.to_series().dt.is_leap_year)
@property
def daysinmonth(self) -> Index:
"""
The number of days in the month.
"""
return Index(self.to_series().dt.daysinmonth)
@property
def days_in_month(self) -> Index:
return Index(self.to_series().dt.days_in_month)
days_in_month.__doc__ = daysinmonth.__doc__
# Methods
def ceil(self, freq, *args, **kwargs) -> "DatetimeIndex":
"""
Perform ceil operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to ceil the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
Returns
-------
DatetimeIndex
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> rng = pp.date_range('1/1/2018 11:59:00', periods=3, freq='min')
>>> rng.ceil('H') # doctest: +NORMALIZE_WHITESPACE
DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00',
'2018-01-01 13:00:00'],
dtype='datetime64[ns]', freq=None)
"""
disallow_nanoseconds(freq)
return DatetimeIndex(self.to_series().dt.ceil(freq, *args, **kwargs))
def floor(self, freq, *args, **kwargs) -> "DatetimeIndex":
"""
Perform floor operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to floor the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
Returns
-------
DatetimeIndex
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> rng = pp.date_range('1/1/2018 11:59:00', periods=3, freq='min')
>>> rng.floor("H") # doctest: +NORMALIZE_WHITESPACE
DatetimeIndex(['2018-01-01 11:00:00', '2018-01-01 12:00:00',
'2018-01-01 12:00:00'],
dtype='datetime64[ns]', freq=None)
"""
disallow_nanoseconds(freq)
return DatetimeIndex(self.to_series().dt.floor(freq, *args, **kwargs))
def round(self, freq, *args, **kwargs) -> "DatetimeIndex":
"""
Perform round operation on the data to the specified freq.
Parameters
----------
freq : str or Offset
The frequency level to round the index to. Must be a fixed
frequency like 'S' (second) not 'ME' (month end).
Returns
-------
DatetimeIndex
Raises
------
ValueError if the `freq` cannot be converted.
Examples
--------
>>> rng = pp.date_range('1/1/2018 11:59:00', periods=3, freq='min')
>>> rng.round("H") # doctest: +NORMALIZE_WHITESPACE
DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00',
'2018-01-01 12:00:00'],
dtype='datetime64[ns]', freq=None)
"""
disallow_nanoseconds(freq)
return DatetimeIndex(self.to_series().dt.round(freq, *args, **kwargs))
def month_name(self, locale: Optional[str] = None) -> Index:
"""
Return the month names of the DatetimeIndex with specified locale.
Parameters
----------
locale : str, optional
Locale determining the language in which to return the month name.
Default is English locale.
Returns
-------
Index
Index of month names.
Examples
--------
>>> idx = pp.date_range(start='2018-01', freq='M', periods=3)
>>> idx.month_name()
Index(['January', 'February', 'March'], dtype='object')
"""
return Index(self.to_series().dt.month_name(locale))
def day_name(self, locale: Optional[str] = None) -> Index:
"""
Return the day names of the series with specified locale.
Parameters
----------
locale : str, optional
Locale determining the language in which to return the day name.
Default is English locale.
Returns
-------
Index
Index of day names.
Examples
--------
>>> idx = pp.date_range(start='2018-01-01', freq='D', periods=3)
>>> idx.day_name()
Index(['Monday', 'Tuesday', 'Wednesday'], dtype='object')
"""
return Index(self.to_series().dt.day_name(locale))
def normalize(self) -> "DatetimeIndex":
"""
Convert times to midnight.
The time component of the date-time is converted to midnight i.e.
00:00:00. This is useful in cases when the time does not matter.
Length is unaltered. The timezones are unaffected.
This method is available on Series with datetime values under
the ``.dt`` accessor.
Returns
-------
DatetimeIndex
The same type as the original data.
See Also
--------
floor : Floor the series to the specified freq.
ceil : Ceil the series to the specified freq.
round : Round the series to the specified freq.
Examples
--------
>>> idx = pp.date_range(start='2014-08-01 10:00', freq='H', periods=3)
>>> idx.normalize()
DatetimeIndex(['2014-08-01', '2014-08-01', '2014-08-01'], dtype='datetime64[ns]', freq=None)
"""
return DatetimeIndex(self.to_series().dt.normalize())
def strftime(self, date_format: str) -> Index:
"""
Convert to a string Index using specified date_format.
Return an Index of formatted strings specified by date_format, which
supports the same string format as the python standard library. Details
of the string format can be found in python string format
doc.
Parameters
----------
date_format : str
Date format string (e.g. "%%Y-%%m-%%d").
Returns
-------
Index
Index of formatted strings.
See Also
--------
normalize : Return series with times to midnight.
round : Round the series to the specified freq.
floor : Floor the series to the specified freq.
Examples
--------
>>> idx = pp.date_range(pd.Timestamp("2018-03-10 09:00"), periods=3, freq='s')
>>> idx.strftime('%B %d, %Y, %r') # doctest: +NORMALIZE_WHITESPACE
Index(['March 10, 2018, 09:00:00 AM', 'March 10, 2018, 09:00:01 AM',
'March 10, 2018, 09:00:02 AM'],
dtype='object')
"""
return Index(self.to_series().dt.strftime(date_format))
def indexer_between_time(
self,
start_time: Union[datetime.time, str],
end_time: Union[datetime.time, str],
include_start: bool = True,
include_end: bool = True,
) -> Index:
"""
Return index locations of values between particular times of day
(e.g., 9:00-9:30AM).
Parameters
----------
start_time, end_time : datetime.time, str
Time passed either as object (datetime.time) or as string in
appropriate format ("%H:%M", "%H%M", "%I:%M%p", "%I%M%p",
"%H:%M:%S", "%H%M%S", "%I:%M:%S%p","%I%M%S%p").
include_start : bool, default True
include_end : bool, default True
Returns
-------
values_between_time : Index of integers
Examples
--------
>>> kidx = pp.date_range("2000-01-01", periods=3, freq="T")
>>> kidx # doctest: +NORMALIZE_WHITESPACE
DatetimeIndex(['2000-01-01 00:00:00', '2000-01-01 00:01:00',
'2000-01-01 00:02:00'],
dtype='datetime64[ns]', freq=None)
>>> kidx.indexer_between_time("00:01", "00:02").sort_values()
Int64Index([1, 2], dtype='int64')
>>> kidx.indexer_between_time("00:01", "00:02", include_end=False)
Int64Index([1], dtype='int64')
>>> kidx.indexer_between_time("00:01", "00:02", include_start=False)
Int64Index([2], dtype='int64')
"""
def pandas_between_time(pdf) -> pp.DataFrame[int]:
return pdf.between_time(start_time, end_time, include_start, include_end)
kdf = self.to_frame()[[]]
id_column_name = verify_temp_column_name(kdf, "__id_column__")
kdf = kdf.koalas.attach_id_column("distributed-sequence", id_column_name)
with pp.option_context("compute.default_index_type", "distributed"):
# The attached index in the statement below will be dropped soon,
# so we enforce 'distributed' default index type
kdf = kdf.koalas.apply_batch(pandas_between_time)
return pp.Index(first_series(kdf).rename(self.name))
def indexer_at_time(self, time: Union[datetime.time, str], asof: bool = False) -> Index:
"""
Return index locations of values at particular time of day
(e.g. 9:30AM).
Parameters
----------
time : datetime.time or str
Time passed in either as object (datetime.time) or as string in
appropriate format ("%H:%M", "%H%M", "%I:%M%p", "%I%M%p",
"%H:%M:%S", "%H%M%S", "%I:%M:%S%p", "%I%M%S%p").
Returns
-------
values_at_time : Index of integers
Examples
--------
>>> kidx = pp.date_range("2000-01-01", periods=3, freq="T")
>>> kidx # doctest: +NORMALIZE_WHITESPACE
DatetimeIndex(['2000-01-01 00:00:00', '2000-01-01 00:01:00',
'2000-01-01 00:02:00'],
dtype='datetime64[ns]', freq=None)
>>> kidx.indexer_at_time("00:00")
Int64Index([0], dtype='int64')
>>> kidx.indexer_at_time("00:01")
Int64Index([1], dtype='int64')
"""
if asof:
raise NotImplementedError("'asof' argument is not supported")
def pandas_at_time(pdf) -> pp.DataFrame[int]:
return pdf.at_time(time, asof)
kdf = self.to_frame()[[]]
id_column_name = verify_temp_column_name(kdf, "__id_column__")
kdf = kdf.koalas.attach_id_column("distributed-sequence", id_column_name)
with pp.option_context("compute.default_index_type", "distributed"):
# The attached index in the statement below will be dropped soon,
# so we enforce 'distributed' default index type
kdf = kdf.koalas.apply_batch(pandas_at_time)
return pp.Index(first_series(kdf).rename(self.name))
def disallow_nanoseconds(freq):
if freq in ["N", "ns"]:
raise ValueError("nanoseconds is not supported")

File diff suppressed because it is too large.


@@ -0,0 +1,147 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
from pandas.api.types import is_hashable
from pyspark import pandas as pp
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.series import Series
class NumericIndex(Index):
"""
Provide numeric type operations.
This is an abstract class.
"""
pass
class IntegerIndex(NumericIndex):
"""
This is an abstract class for Int64Index.
"""
pass
class Int64Index(IntegerIndex):
"""
Immutable sequence used for indexing and alignment. The basic object
storing axis labels for all pandas objects. Int64Index is a special case
of `Index` with purely integer labels.
Parameters
----------
data : array-like (1-dimensional)
dtype : NumPy dtype (default: int64)
copy : bool
Make a copy of input ndarray.
name : object
Name to be stored in the index.
See Also
--------
Index : The base Koalas Index type.
Float64Index : A special case of :class:`Index` with purely float labels.
Notes
-----
An Index instance can **only** contain hashable objects.
Examples
--------
>>> pp.Int64Index([1, 2, 3])
Int64Index([1, 2, 3], dtype='int64')
From a Series:
>>> s = pp.Series([1, 2, 3], index=[10, 20, 30])
>>> pp.Int64Index(s)
Int64Index([1, 2, 3], dtype='int64')
From an Index:
>>> idx = pp.Index([1, 2, 3])
>>> pp.Int64Index(idx)
Int64Index([1, 2, 3], dtype='int64')
"""
def __new__(cls, data=None, dtype=None, copy=False, name=None):
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
if dtype is None:
dtype = "int64"
return Index(data, dtype=dtype, copy=copy, name=name)
return pp.from_pandas(pd.Int64Index(data=data, dtype=dtype, copy=copy, name=name))
class Float64Index(NumericIndex):
"""
Immutable sequence used for indexing and alignment. The basic object
storing axis labels for all pandas objects. Float64Index is a special case
of `Index` with purely float labels.
Parameters
----------
data : array-like (1-dimensional)
dtype : NumPy dtype (default: float64)
copy : bool
Make a copy of input ndarray.
name : object
Name to be stored in the index.
See Also
--------
Index : The base Koalas Index type.
Int64Index : A special case of :class:`Index` with purely integer labels.
Notes
-----
An Index instance can **only** contain hashable objects.
Examples
--------
>>> pp.Float64Index([1.0, 2.0, 3.0])
Float64Index([1.0, 2.0, 3.0], dtype='float64')
From a Series:
>>> s = pp.Series([1, 2, 3], index=[10, 20, 30])
>>> pp.Float64Index(s)
Float64Index([1.0, 2.0, 3.0], dtype='float64')
From an Index:
>>> idx = pp.Index([1, 2, 3])
>>> pp.Float64Index(idx)
Float64Index([1.0, 2.0, 3.0], dtype='float64')
"""
def __new__(cls, data=None, dtype=None, copy=False, name=None):
if not is_hashable(name):
raise TypeError("Index.name must be a hashable type")
if isinstance(data, (Series, Index)):
if dtype is None:
dtype = "float64"
return Index(data, dtype=dtype, copy=copy, name=name)
return pp.from_pandas(pd.Float64Index(data=data, dtype=dtype, copy=copy, name=name))
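As the docstrings above show, these classes mostly exist so that pandas-style constructors keep working. Passing an existing pandas-on-Spark Series or Index takes the branch that returns `Index(data, dtype=...)` and keeps the data distributed, while other inputs go through `pd.Int64Index`/`pd.Float64Index` and `pp.from_pandas`. A brief sketch, assuming a running Spark session:

```python
from pyspark import pandas as pp

# From a plain Python list: built via pandas on the driver, then converted.
idx = pp.Int64Index([1, 2, 3])

# From an existing pandas-on-Spark Series: no round trip through the driver,
# only the dtype is coerced.
s = pp.Series([1, 2, 3], index=[10, 20, 30])
fidx = pp.Float64Index(s)

print(idx)   # Int64Index([1, 2, 3], dtype='int64')
print(fidx)  # Float64Index([1.0, 2.0, 3.0], dtype='float64')
```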

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.exceptions import PandasNotImplementedError


def unsupported_function(class_name, method_name, deprecated=False, reason=""):
    def unsupported_function(*args, **kwargs):
        raise PandasNotImplementedError(
            class_name=class_name, method_name=method_name, reason=reason
        )

    def deprecated_function(*args, **kwargs):
        raise PandasNotImplementedError(
            class_name=class_name, method_name=method_name, deprecated=deprecated, reason=reason
        )

    return deprecated_function if deprecated else unsupported_function


def unsupported_property(class_name, property_name, deprecated=False, reason=""):
    @property
    def unsupported_property(self):
        raise PandasNotImplementedError(
            class_name=class_name, property_name=property_name, reason=reason
        )

    @property
    def deprecated_property(self):
        raise PandasNotImplementedError(
            class_name=class_name, property_name=property_name, deprecated=deprecated, reason=reason
        )

    return deprecated_property if deprecated else unsupported_property
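These two factories back all of the `missing` modules that follow: each unsupported pandas attribute is replaced by a stub that raises `PandasNotImplementedError` when it is called or accessed. A minimal sketch of how such a stub behaves; the `Example` class here is hypothetical and not part of this PR:

```python
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.pandas.missing import unsupported_function, unsupported_property


class Example(object):
    # A method stub: raises only when called.
    interpolate = unsupported_function("pd.DataFrame", "interpolate")
    # A property stub: raises as soon as the attribute is accessed.
    blocks = unsupported_property("pd.DataFrame", "blocks", deprecated=True)


try:
    Example().interpolate()
except PandasNotImplementedError as e:
    print(e)

try:
    Example().blocks
except PandasNotImplementedError as e:
    print(e)
```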

@ -0,0 +1,59 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
memory_usage = lambda f: f(
"memory_usage",
reason="Unlike pandas, most DataFrames are not materialized in memory in Spark "
"(and Koalas), and as a result memory_usage() does not do what you intend it "
"to do. Use Spark's web UI to monitor disk and memory usage of your application.",
)
array = lambda f: f(
"array", reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead."
)
to_pickle = lambda f: f(
"to_pickle",
reason="For storage, we encourage you to use Delta or Parquet, instead of Python pickle "
"format.",
)
to_xarray = lambda f: f(
"to_xarray",
reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.",
)
to_list = lambda f: f(
"to_list",
reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.",
)
tolist = lambda f: f(
"tolist", reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead."
)
__iter__ = lambda f: f(
"__iter__",
reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.",
)
duplicated = lambda f: f(
"duplicated",
reason="'duplicated' API returns np.ndarray and the data size is too large."
"You can just use DataFrame.deduplicated instead",
)

@ -0,0 +1,98 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion
import pandas as pd
from pyspark.pandas.missing import unsupported_function, unsupported_property, common
def _unsupported_function(method_name, deprecated=False, reason=""):
return unsupported_function(
class_name="pd.DataFrame", method_name=method_name, deprecated=deprecated, reason=reason
)
def _unsupported_property(property_name, deprecated=False, reason=""):
return unsupported_property(
class_name="pd.DataFrame", property_name=property_name, deprecated=deprecated, reason=reason
)
class _MissingPandasLikeDataFrame(object):
# Functions
asfreq = _unsupported_function("asfreq")
asof = _unsupported_function("asof")
boxplot = _unsupported_function("boxplot")
combine = _unsupported_function("combine")
combine_first = _unsupported_function("combine_first")
compare = _unsupported_function("compare")
convert_dtypes = _unsupported_function("convert_dtypes")
corrwith = _unsupported_function("corrwith")
cov = _unsupported_function("cov")
ewm = _unsupported_function("ewm")
infer_objects = _unsupported_function("infer_objects")
interpolate = _unsupported_function("interpolate")
lookup = _unsupported_function("lookup")
mode = _unsupported_function("mode")
reorder_levels = _unsupported_function("reorder_levels")
resample = _unsupported_function("resample")
set_axis = _unsupported_function("set_axis")
slice_shift = _unsupported_function("slice_shift")
to_feather = _unsupported_function("to_feather")
to_gbq = _unsupported_function("to_gbq")
to_hdf = _unsupported_function("to_hdf")
to_period = _unsupported_function("to_period")
to_sql = _unsupported_function("to_sql")
to_stata = _unsupported_function("to_stata")
to_timestamp = _unsupported_function("to_timestamp")
tshift = _unsupported_function("tshift")
tz_convert = _unsupported_function("tz_convert")
tz_localize = _unsupported_function("tz_localize")
# Deprecated functions
convert_objects = _unsupported_function("convert_objects", deprecated=True)
select = _unsupported_function("select", deprecated=True)
to_panel = _unsupported_function("to_panel", deprecated=True)
get_values = _unsupported_function("get_values", deprecated=True)
compound = _unsupported_function("compound", deprecated=True)
reindex_axis = _unsupported_function("reindex_axis", deprecated=True)
# Functions we won't support.
to_pickle = common.to_pickle(_unsupported_function)
memory_usage = common.memory_usage(_unsupported_function)
to_xarray = common.to_xarray(_unsupported_function)
if LooseVersion(pd.__version__) < LooseVersion("1.0"):
# Deprecated properties
blocks = _unsupported_property("blocks", deprecated=True)
ftypes = _unsupported_property("ftypes", deprecated=True)
is_copy = _unsupported_property("is_copy", deprecated=True)
ix = _unsupported_property("ix", deprecated=True)
# Deprecated functions
as_blocks = _unsupported_function("as_blocks", deprecated=True)
as_matrix = _unsupported_function("as_matrix", deprecated=True)
clip_lower = _unsupported_function("clip_lower", deprecated=True)
clip_upper = _unsupported_function("clip_upper", deprecated=True)
get_ftype_counts = _unsupported_function("get_ftype_counts", deprecated=True)
get_value = _unsupported_function("get_value", deprecated=True)
set_value = _unsupported_function("set_value", deprecated=True)
to_dense = _unsupported_function("to_dense", deprecated=True)
to_sparse = _unsupported_function("to_sparse", deprecated=True)
to_msgpack = _unsupported_function("to_msgpack", deprecated=True)
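In effect, every DataFrame method listed here still exists on the pandas-on-Spark DataFrame but fails fast with a clear error instead of an `AttributeError`. A small sketch, assuming a running Spark session; the exact error message may differ:

```python
from pyspark import pandas as pp
from pyspark.pandas.exceptions import PandasNotImplementedError

kdf = pp.DataFrame({"A": [1, 2, 3], "B": [15, 20, 25]})

try:
    kdf.to_feather("/tmp/example.feather")  # listed above as unsupported
except PandasNotImplementedError as e:
    print(e)  # explains that pd.DataFrame.to_feather() is not implemented yet
```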

@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.missing import unsupported_function, unsupported_property
def _unsupported_function(method_name, deprecated=False, reason=""):
return unsupported_function(
class_name="pd.groupby.GroupBy",
method_name=method_name,
deprecated=deprecated,
reason=reason,
)
def _unsupported_property(property_name, deprecated=False, reason=""):
return unsupported_property(
class_name="pd.groupby.GroupBy",
property_name=property_name,
deprecated=deprecated,
reason=reason,
)
class MissingPandasLikeDataFrameGroupBy(object):
# Properties
corr = _unsupported_property("corr")
corrwith = _unsupported_property("corrwith")
cov = _unsupported_property("cov")
dtypes = _unsupported_property("dtypes")
groups = _unsupported_property("groups")
hist = _unsupported_property("hist")
indices = _unsupported_property("indices")
mad = _unsupported_property("mad")
ngroups = _unsupported_property("ngroups")
plot = _unsupported_property("plot")
quantile = _unsupported_property("quantile")
skew = _unsupported_property("skew")
tshift = _unsupported_property("tshift")
# Deprecated properties
take = _unsupported_property("take", deprecated=True)
# Functions
boxplot = _unsupported_function("boxplot")
ngroup = _unsupported_function("ngroup")
nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
sem = _unsupported_function("sem")
class MissingPandasLikeSeriesGroupBy(object):
# Properties
corr = _unsupported_property("corr")
cov = _unsupported_property("cov")
dtype = _unsupported_property("dtype")
groups = _unsupported_property("groups")
hist = _unsupported_property("hist")
indices = _unsupported_property("indices")
is_monotonic_decreasing = _unsupported_property("is_monotonic_decreasing")
is_monotonic_increasing = _unsupported_property("is_monotonic_increasing")
mad = _unsupported_property("mad")
ngroups = _unsupported_property("ngroups")
plot = _unsupported_property("plot")
quantile = _unsupported_property("quantile")
skew = _unsupported_property("skew")
tshift = _unsupported_property("tshift")
# Deprecated properties
take = _unsupported_property("take", deprecated=True)
# Functions
agg = _unsupported_function("agg")
aggregate = _unsupported_function("aggregate")
describe = _unsupported_function("describe")
ngroup = _unsupported_function("ngroup")
nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
sem = _unsupported_function("sem")

@ -0,0 +1,218 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion
import pandas as pd
from pyspark.pandas.missing import unsupported_function, unsupported_property, common
def _unsupported_function(method_name, deprecated=False, reason="", cls="Index"):
return unsupported_function(
class_name="pd.{}".format(cls),
method_name=method_name,
deprecated=deprecated,
reason=reason,
)
def _unsupported_property(property_name, deprecated=False, reason="", cls="Index"):
return unsupported_property(
class_name="pd.{}".format(cls),
property_name=property_name,
deprecated=deprecated,
reason=reason,
)
class MissingPandasLikeIndex(object):
# Properties
nbytes = _unsupported_property("nbytes")
# Functions
argsort = _unsupported_function("argsort")
asof_locs = _unsupported_function("asof_locs")
format = _unsupported_function("format")
get_indexer = _unsupported_function("get_indexer")
get_indexer_for = _unsupported_function("get_indexer_for")
get_indexer_non_unique = _unsupported_function("get_indexer_non_unique")
get_loc = _unsupported_function("get_loc")
get_slice_bound = _unsupported_function("get_slice_bound")
get_value = _unsupported_function("get_value")
groupby = _unsupported_function("groupby")
is_ = _unsupported_function("is_")
is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple")
join = _unsupported_function("join")
map = _unsupported_function("map")
putmask = _unsupported_function("putmask")
ravel = _unsupported_function("ravel")
reindex = _unsupported_function("reindex")
searchsorted = _unsupported_function("searchsorted")
slice_indexer = _unsupported_function("slice_indexer")
slice_locs = _unsupported_function("slice_locs")
sortlevel = _unsupported_function("sortlevel")
to_flat_index = _unsupported_function("to_flat_index")
to_native_types = _unsupported_function("to_native_types")
where = _unsupported_function("where")
# Deprecated functions
is_mixed = _unsupported_function("is_mixed")
get_values = _unsupported_function("get_values", deprecated=True)
set_value = _unsupported_function("set_value")
# Properties we won't support.
array = common.array(_unsupported_property)
duplicated = common.duplicated(_unsupported_property)
# Functions we won't support.
memory_usage = common.memory_usage(_unsupported_function)
__iter__ = common.__iter__(_unsupported_function)
if LooseVersion(pd.__version__) < LooseVersion("1.0"):
# Deprecated properties
strides = _unsupported_property("strides", deprecated=True)
data = _unsupported_property("data", deprecated=True)
itemsize = _unsupported_property("itemsize", deprecated=True)
base = _unsupported_property("base", deprecated=True)
flags = _unsupported_property("flags", deprecated=True)
# Deprecated functions
get_duplicates = _unsupported_function("get_duplicates", deprecated=True)
summary = _unsupported_function("summary", deprecated=True)
contains = _unsupported_function("contains", deprecated=True)
class MissingPandasLikeDatetimeIndex(MissingPandasLikeIndex):
# Properties
nanosecond = _unsupported_property("nanosecond", cls="DatetimeIndex")
date = _unsupported_property("date", cls="DatetimeIndex")
time = _unsupported_property("time", cls="DatetimeIndex")
timetz = _unsupported_property("timetz", cls="DatetimeIndex")
tz = _unsupported_property("tz", cls="DatetimeIndex")
freq = _unsupported_property("freq", cls="DatetimeIndex")
freqstr = _unsupported_property("freqstr", cls="DatetimeIndex")
inferred_freq = _unsupported_property("inferred_freq", cls="DatetimeIndex")
# Functions
snap = _unsupported_function("snap", cls="DatetimeIndex")
tz_convert = _unsupported_function("tz_convert", cls="DatetimeIndex")
tz_localize = _unsupported_function("tz_localize", cls="DatetimeIndex")
to_period = _unsupported_function("to_period", cls="DatetimeIndex")
to_perioddelta = _unsupported_function("to_perioddelta", cls="DatetimeIndex")
to_pydatetime = _unsupported_function("to_pydatetime", cls="DatetimeIndex")
mean = _unsupported_function("mean", cls="DatetimeIndex")
std = _unsupported_function("std", cls="DatetimeIndex")
class MissingPandasLikeCategoricalIndex(MissingPandasLikeIndex):
# Functions
rename_categories = _unsupported_function("rename_categories", cls="CategoricalIndex")
reorder_categories = _unsupported_function("reorder_categories", cls="CategoricalIndex")
add_categories = _unsupported_function("add_categories", cls="CategoricalIndex")
remove_categories = _unsupported_function("remove_categories", cls="CategoricalIndex")
remove_unused_categories = _unsupported_function(
"remove_unused_categories", cls="CategoricalIndex"
)
set_categories = _unsupported_function("set_categories", cls="CategoricalIndex")
as_ordered = _unsupported_function("as_ordered", cls="CategoricalIndex")
as_unordered = _unsupported_function("as_unordered", cls="CategoricalIndex")
map = _unsupported_function("map", cls="CategoricalIndex")
class MissingPandasLikeMultiIndex(object):
# Deprecated properties
strides = _unsupported_property("strides", deprecated=True)
data = _unsupported_property("data", deprecated=True)
itemsize = _unsupported_property("itemsize", deprecated=True)
# Functions
argsort = _unsupported_function("argsort")
asof_locs = _unsupported_function("asof_locs")
equal_levels = _unsupported_function("equal_levels")
factorize = _unsupported_function("factorize")
format = _unsupported_function("format")
get_indexer = _unsupported_function("get_indexer")
get_indexer_for = _unsupported_function("get_indexer_for")
get_indexer_non_unique = _unsupported_function("get_indexer_non_unique")
get_loc = _unsupported_function("get_loc")
get_loc_level = _unsupported_function("get_loc_level")
get_locs = _unsupported_function("get_locs")
get_slice_bound = _unsupported_function("get_slice_bound")
get_value = _unsupported_function("get_value")
groupby = _unsupported_function("groupby")
is_ = _unsupported_function("is_")
is_lexsorted = _unsupported_function("is_lexsorted")
is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple")
join = _unsupported_function("join")
map = _unsupported_function("map")
putmask = _unsupported_function("putmask")
ravel = _unsupported_function("ravel")
reindex = _unsupported_function("reindex")
remove_unused_levels = _unsupported_function("remove_unused_levels")
reorder_levels = _unsupported_function("reorder_levels")
searchsorted = _unsupported_function("searchsorted")
set_codes = _unsupported_function("set_codes")
set_levels = _unsupported_function("set_levels")
slice_indexer = _unsupported_function("slice_indexer")
slice_locs = _unsupported_function("slice_locs")
sortlevel = _unsupported_function("sortlevel")
to_flat_index = _unsupported_function("to_flat_index")
to_native_types = _unsupported_function("to_native_types")
truncate = _unsupported_function("truncate")
where = _unsupported_function("where")
# Deprecated functions
is_mixed = _unsupported_function("is_mixed")
get_duplicates = _unsupported_function("get_duplicates", deprecated=True)
get_values = _unsupported_function("get_values", deprecated=True)
set_value = _unsupported_function("set_value", deprecated=True)
# Functions we won't support.
array = common.array(_unsupported_property)
duplicated = common.duplicated(_unsupported_property)
codes = _unsupported_property(
"codes",
reason="'codes' requires to collect all data into the driver which is against the "
"design principle of Koalas. Alternatively, you could call 'to_pandas()' and"
" use 'codes' property in pandas.",
)
levels = _unsupported_property(
"levels",
reason="'levels' requires to collect all data into the driver which is against the "
"design principle of Koalas. Alternatively, you could call 'to_pandas()' and"
" use 'levels' property in pandas.",
)
__iter__ = common.__iter__(_unsupported_function)
# Properties we won't support.
memory_usage = common.memory_usage(_unsupported_function)
if LooseVersion(pd.__version__) < LooseVersion("1.0"):
# Deprecated properties
base = _unsupported_property("base", deprecated=True)
labels = _unsupported_property("labels", deprecated=True)
flags = _unsupported_property("flags", deprecated=True)
# Deprecated functions
set_labels = _unsupported_function("set_labels")
summary = _unsupported_function("summary", deprecated=True)
to_hierarchical = _unsupported_function("to_hierarchical", deprecated=True)
contains = _unsupported_function("contains", deprecated=True)

@ -0,0 +1,125 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion
import pandas as pd
from pyspark.pandas.missing import unsupported_function, unsupported_property, common
def _unsupported_function(method_name, deprecated=False, reason=""):
return unsupported_function(
class_name="pd.Series", method_name=method_name, deprecated=deprecated, reason=reason
)
def _unsupported_property(property_name, deprecated=False, reason=""):
return unsupported_property(
class_name="pd.Series", property_name=property_name, deprecated=deprecated, reason=reason
)
class MissingPandasLikeSeries(object):
# Functions
asfreq = _unsupported_function("asfreq")
autocorr = _unsupported_function("autocorr")
combine = _unsupported_function("combine")
convert_dtypes = _unsupported_function("convert_dtypes")
cov = _unsupported_function("cov")
ewm = _unsupported_function("ewm")
infer_objects = _unsupported_function("infer_objects")
interpolate = _unsupported_function("interpolate")
reorder_levels = _unsupported_function("reorder_levels")
resample = _unsupported_function("resample")
searchsorted = _unsupported_function("searchsorted")
set_axis = _unsupported_function("set_axis")
slice_shift = _unsupported_function("slice_shift")
to_hdf = _unsupported_function("to_hdf")
to_period = _unsupported_function("to_period")
to_sql = _unsupported_function("to_sql")
to_timestamp = _unsupported_function("to_timestamp")
tshift = _unsupported_function("tshift")
tz_convert = _unsupported_function("tz_convert")
tz_localize = _unsupported_function("tz_localize")
view = _unsupported_function("view")
# Deprecated functions
convert_objects = _unsupported_function("convert_objects", deprecated=True)
nonzero = _unsupported_function("nonzero", deprecated=True)
reindex_axis = _unsupported_function("reindex_axis", deprecated=True)
select = _unsupported_function("select", deprecated=True)
get_values = _unsupported_function("get_values", deprecated=True)
# Properties we won't support.
array = common.array(_unsupported_property)
duplicated = common.duplicated(_unsupported_property)
nbytes = _unsupported_property(
"nbytes",
reason="'nbytes' requires to compute whole dataset. You can calculate manually it, "
"with its 'itemsize', by explicitly executing its count. Use Spark's web UI "
"to monitor disk and memory usage of your application in general.",
)
# Functions we won't support.
memory_usage = common.memory_usage(_unsupported_function)
to_pickle = common.to_pickle(_unsupported_function)
to_xarray = common.to_xarray(_unsupported_function)
__iter__ = common.__iter__(_unsupported_function)
ravel = _unsupported_function(
"ravel",
reason="If you want to collect your flattened underlying data as an NumPy array, "
"use 'to_numpy().ravel()' instead.",
)
if LooseVersion(pd.__version__) < LooseVersion("1.0"):
# Deprecated properties
blocks = _unsupported_property("blocks", deprecated=True)
ftypes = _unsupported_property("ftypes", deprecated=True)
ftype = _unsupported_property("ftype", deprecated=True)
is_copy = _unsupported_property("is_copy", deprecated=True)
ix = _unsupported_property("ix", deprecated=True)
asobject = _unsupported_property("asobject", deprecated=True)
strides = _unsupported_property("strides", deprecated=True)
imag = _unsupported_property("imag", deprecated=True)
itemsize = _unsupported_property("itemsize", deprecated=True)
data = _unsupported_property("data", deprecated=True)
base = _unsupported_property("base", deprecated=True)
flags = _unsupported_property("flags", deprecated=True)
# Deprecated functions
as_blocks = _unsupported_function("as_blocks", deprecated=True)
as_matrix = _unsupported_function("as_matrix", deprecated=True)
clip_lower = _unsupported_function("clip_lower", deprecated=True)
clip_upper = _unsupported_function("clip_upper", deprecated=True)
compress = _unsupported_function("compress", deprecated=True)
get_ftype_counts = _unsupported_function("get_ftype_counts", deprecated=True)
get_value = _unsupported_function("get_value", deprecated=True)
set_value = _unsupported_function("set_value", deprecated=True)
valid = _unsupported_function("valid", deprecated=True)
to_dense = _unsupported_function("to_dense", deprecated=True)
to_sparse = _unsupported_function("to_sparse", deprecated=True)
to_msgpack = _unsupported_function("to_msgpack", deprecated=True)
compound = _unsupported_function("compound", deprecated=True)
put = _unsupported_function("put", deprecated=True)
ptp = _unsupported_function("ptp", deprecated=True)
# Functions we won't support.
real = _unsupported_property(
"real",
reason="If you want to collect your data as an NumPy array, use 'to_numpy()' instead.",
)
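Several of these Series stubs point the user toward an explicit collect instead; for example, the `ravel` reason above recommends flattening on the driver via `to_numpy()`. A sketch of that workaround, assuming a running Spark session and a dataset small enough to collect:

```python
from pyspark import pandas as pp

kser = pp.Series([1, 2, 3])

# Series.ravel is stubbed out above; the suggested alternative is to collect
# explicitly and flatten with NumPy on the driver.
flat = kser.to_numpy().ravel()
print(flat, type(flat))  # [1 2 3] <class 'numpy.ndarray'>
```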

@ -0,0 +1,126 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.missing import unsupported_function, unsupported_property
def _unsupported_function_expanding(method_name, deprecated=False, reason=""):
return unsupported_function(
class_name="pandas.core.window.Expanding",
method_name=method_name,
deprecated=deprecated,
reason=reason,
)
def _unsupported_property_expanding(property_name, deprecated=False, reason=""):
return unsupported_property(
class_name="pandas.core.window.Expanding",
property_name=property_name,
deprecated=deprecated,
reason=reason,
)
def _unsupported_function_rolling(method_name, deprecated=False, reason=""):
return unsupported_function(
class_name="pandas.core.window.Rolling",
method_name=method_name,
deprecated=deprecated,
reason=reason,
)
def _unsupported_property_rolling(property_name, deprecated=False, reason=""):
return unsupported_property(
class_name="pandas.core.window.Rolling",
property_name=property_name,
deprecated=deprecated,
reason=reason,
)
class MissingPandasLikeExpanding(object):
agg = _unsupported_function_expanding("agg")
aggregate = _unsupported_function_expanding("aggregate")
apply = _unsupported_function_expanding("apply")
corr = _unsupported_function_expanding("corr")
cov = _unsupported_function_expanding("cov")
kurt = _unsupported_function_expanding("kurt")
median = _unsupported_function_expanding("median")
quantile = _unsupported_function_expanding("quantile")
skew = _unsupported_function_expanding("skew")
validate = _unsupported_function_expanding("validate")
exclusions = _unsupported_property_expanding("exclusions")
is_datetimelike = _unsupported_property_expanding("is_datetimelike")
is_freq_type = _unsupported_property_expanding("is_freq_type")
ndim = _unsupported_property_expanding("ndim")
class MissingPandasLikeRolling(object):
agg = _unsupported_function_rolling("agg")
aggregate = _unsupported_function_rolling("aggregate")
apply = _unsupported_function_rolling("apply")
corr = _unsupported_function_rolling("corr")
cov = _unsupported_function_rolling("cov")
kurt = _unsupported_function_rolling("kurt")
median = _unsupported_function_rolling("median")
quantile = _unsupported_function_rolling("quantile")
skew = _unsupported_function_rolling("skew")
validate = _unsupported_function_rolling("validate")
exclusions = _unsupported_property_rolling("exclusions")
is_datetimelike = _unsupported_property_rolling("is_datetimelike")
is_freq_type = _unsupported_property_rolling("is_freq_type")
ndim = _unsupported_property_rolling("ndim")
class MissingPandasLikeExpandingGroupby(object):
agg = _unsupported_function_expanding("agg")
aggregate = _unsupported_function_expanding("aggregate")
apply = _unsupported_function_expanding("apply")
corr = _unsupported_function_expanding("corr")
cov = _unsupported_function_expanding("cov")
kurt = _unsupported_function_expanding("kurt")
median = _unsupported_function_expanding("median")
quantile = _unsupported_function_expanding("quantile")
skew = _unsupported_function_expanding("skew")
validate = _unsupported_function_expanding("validate")
exclusions = _unsupported_property_expanding("exclusions")
is_datetimelike = _unsupported_property_expanding("is_datetimelike")
is_freq_type = _unsupported_property_expanding("is_freq_type")
ndim = _unsupported_property_expanding("ndim")
class MissingPandasLikeRollingGroupby(object):
agg = _unsupported_function_rolling("agg")
aggregate = _unsupported_function_rolling("aggregate")
apply = _unsupported_function_rolling("apply")
corr = _unsupported_function_rolling("corr")
cov = _unsupported_function_rolling("cov")
kurt = _unsupported_function_rolling("kurt")
median = _unsupported_function_rolling("median")
quantile = _unsupported_function_rolling("quantile")
skew = _unsupported_function_rolling("skew")
validate = _unsupported_function_rolling("validate")
exclusions = _unsupported_property_rolling("exclusions")
is_datetimelike = _unsupported_property_rolling("is_datetimelike")
is_freq_type = _unsupported_property_rolling("is_freq_type")
ndim = _unsupported_property_rolling("ndim")

@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Tuple, TYPE_CHECKING
import numpy as np
import pandas as pd
import pyspark
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.pandas.utils import column_labels_level
if TYPE_CHECKING:
import pyspark.pandas as pp # noqa: F401 (SPARK-34943)
CORRELATION_OUTPUT_COLUMN = "__correlation_output__"
def corr(kdf: "pp.DataFrame", method: str = "pearson") -> pd.DataFrame:
"""
The correlation matrix of all the numerical columns of this dataframe.
Only accepts scalar numerical values for now.
:param kdf: the Koalas dataframe.
:param method: {'pearson', 'spearman'}
* pearson : standard correlation coefficient
* spearman : Spearman rank correlation
:return: :class:`pandas.DataFrame`
>>> pp.DataFrame({'A': [0, 1], 'B': [1, 0], 'C': ['x', 'y']}).corr()
A B
A 1.0 -1.0
B -1.0 1.0
"""
assert method in ("pearson", "spearman")
ndf, column_labels = to_numeric_df(kdf)
corr = Correlation.corr(ndf, CORRELATION_OUTPUT_COLUMN, method)
pcorr = corr.toPandas()
arr = pcorr.iloc[0, 0].toArray()
if column_labels_level(column_labels) > 1:
idx = pd.MultiIndex.from_tuples(column_labels)
else:
idx = pd.Index([label[0] for label in column_labels])
return pd.DataFrame(arr, columns=idx, index=idx)
def to_numeric_df(kdf: "pp.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Tuple]]:
"""
Takes a Koalas dataframe and turns it into a Spark dataframe that contains a single numerical
vector of doubles. This dataframe has a single field named '__correlation_output__'.
TODO: index is not preserved currently
:param kdf: the Koalas dataframe.
:return: a pair of dataframe, list of strings (the name of the columns
that were converted to numerical types)
>>> to_numeric_df(pp.DataFrame({'A': [0, 1], 'B': [1, 0], 'C': ['x', 'y']}))
(DataFrame[__correlation_output__: vector], [('A',), ('B',)])
"""
# TODO, it should be more robust.
accepted_types = {
np.dtype(dt)
for dt in [np.int8, np.int16, np.int32, np.int64, np.float32, np.float64, np.bool_]
}
numeric_column_labels = [
label for label in kdf._internal.column_labels if kdf[label].dtype in accepted_types
]
numeric_df = kdf._internal.spark_frame.select(
*[kdf._internal.spark_column_for(idx) for idx in numeric_column_labels]
)
va = VectorAssembler(inputCols=numeric_df.columns, outputCol=CORRELATION_OUTPUT_COLUMN)
v = va.transform(numeric_df).select(CORRELATION_OUTPUT_COLUMN)
return v, numeric_column_labels
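`DataFrame.corr` is backed by this module: the frame is reduced to its numeric columns, assembled into a single vector column, and handed to `pyspark.ml.stat.Correlation`; the resulting matrix comes back as a plain pandas DataFrame. A brief example, assuming a running Spark session:

```python
from pyspark import pandas as pp

kdf = pp.DataFrame({"A": [0, 1], "B": [1, 0], "C": ["x", "y"]})

# The string column "C" is dropped by to_numeric_df(); the result is a local
# pandas DataFrame, so it is safe to print directly.
print(kdf.corr())                   # Pearson by default
print(kdf.corr(method="spearman"))  # Spearman rank correlation
```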

@ -0,0 +1,192 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
MLflow-related functions to load models and apply them to Koalas dataframes.
"""
from mlflow import pyfunc
from pyspark.sql.types import DataType
import pandas as pd
import numpy as np
from typing import Any
from pyspark.pandas.utils import lazy_property, default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import first_series
from pyspark.pandas.typedef import as_spark_type
__all__ = ["PythonModelWrapper", "load_model"]
class PythonModelWrapper(object):
"""
A wrapper around MLflow's Python object model.
This wrapper acts as a predictor on Koalas DataFrames.
"""
def __init__(self, model_uri, return_type_hint):
self._model_uri = model_uri # type: str
self._return_type_hint = return_type_hint
@lazy_property
def _return_type(self) -> DataType:
hint = self._return_type_hint
# The logic is simple for now, because it corresponds to the default
# case: continuous predictions
# TODO: do something smarter, for example when there is a sklearn.Classifier (it should
# return an integer or a categorical)
# We can do the same for pytorch/tensorflow/keras models by looking at the output types.
# However, this is probably better done in mlflow than here.
if hint == "infer" or not hint:
hint = np.float64
return as_spark_type(hint)
@lazy_property
def _model(self) -> Any:
"""
The return object has to follow the API of mlflow.pyfunc.PythonModel.
"""
return pyfunc.load_model(model_uri=self._model_uri)
@lazy_property
def _model_udf(self):
spark = default_session()
return pyfunc.spark_udf(spark, model_uri=self._model_uri, result_type=self._return_type)
def __str__(self):
return "PythonModelWrapper({})".format(str(self._model))
def __repr__(self):
return "PythonModelWrapper({})".format(repr(self._model))
def predict(self, data):
"""
Returns a prediction on the data.
If the data is a koalas DataFrame, the return is a Koalas Series.
If the data is a pandas Dataframe, the return is the expected output of the underlying
pyfunc object (typically a pandas Series or a numpy array).
"""
if isinstance(data, pd.DataFrame):
return self._model.predict(data)
if isinstance(data, DataFrame):
return_col = self._model_udf(*data._internal.data_spark_columns)
# TODO: the columns should be named according to the mlflow spec
# However, this is only possible with spark >= 3.0
# s = F.struct(*data.columns)
# return_col = self._model_udf(s)
column_labels = [
(col,) for col in data._internal.spark_frame.select(return_col).columns
]
internal = data._internal.copy(
column_labels=column_labels, data_spark_columns=[return_col], data_dtypes=None
)
return first_series(DataFrame(internal))
def load_model(model_uri, predict_type="infer") -> PythonModelWrapper:
"""
Loads an MLflow model into a wrapper that can be used both for pandas and Koalas DataFrames.
Parameters
----------
model_uri : str
URI pointing to the model. See MLflow documentation for more details.
predict_type : a python basic type, a numpy basic type, a Spark type or 'infer'.
This is the return type that is expected when calling the predict function of the model.
If 'infer' is specified, the wrapper will attempt to determine automatically the return type
based on the model type.
Returns
-------
PythonModelWrapper
A wrapper around MLflow PythonModel objects. This wrapper is expected to adhere to the
interface of mlflow.pyfunc.PythonModel.
Examples
--------
Here is a full example that creates a model with scikit-learn and saves the model with
MLflow. The model is then loaded as a predictor that can be applied on a Koalas
Dataframe.
We first initialize our MLflow environment:
>>> from mlflow.tracking import MlflowClient, set_tracking_uri
>>> import mlflow.sklearn
>>> from tempfile import mkdtemp
>>> d = mkdtemp("koalas_mlflow")
>>> set_tracking_uri("file:%s"%d)
>>> client = MlflowClient()
>>> exp = mlflow.create_experiment("my_experiment")
>>> mlflow.set_experiment("my_experiment")
We aim at learning this numerical function using a simple linear regressor.
>>> from sklearn.linear_model import LinearRegression
>>> train = pd.DataFrame({"x1": np.arange(8), "x2": np.arange(8)**2,
... "y": np.log(2 + np.arange(8))})
>>> train_x = train[["x1", "x2"]]
>>> train_y = train[["y"]]
>>> with mlflow.start_run():
... lr = LinearRegression()
... lr.fit(train_x, train_y)
... mlflow.sklearn.log_model(lr, "model")
LinearRegression(...)
Now that our model is logged using MLflow, we load it back and apply it on a Koalas dataframe:
>>> from pyspark.pandas.mlflow import load_model
>>> run_info = client.list_run_infos(exp)[-1]
>>> model = load_model("runs:/{run_id}/model".format(run_id=run_info.run_uuid))
>>> prediction_df = pp.DataFrame({"x1": [2.0], "x2": [4.0]})
>>> prediction_df["prediction"] = model.predict(prediction_df)
>>> prediction_df
x1 x2 prediction
0 2.0 4.0 1.355551
The model also works on pandas DataFrames as expected:
>>> model.predict(prediction_df[["x1", "x2"]].to_pandas())
array([[1.35555142]])
Notes
-----
Currently, the model prediction can only be merged back with the existing dataframe.
Other columns have to be manually joined.
For example, this code will not work:
>>> df = pp.DataFrame({"x1": [2.0], "x2": [3.0], "z": [-1]})
>>> features = df[["x1", "x2"]]
>>> y = model.predict(features)
>>> # Works:
>>> features["y"] = y # doctest: +SKIP
>>> # Will fail with a message about dataframes not aligned.
>>> df["y"] = y # doctest: +SKIP
A current workaround is to use the .merge() function, using the feature values
as merging keys.
>>> features['y'] = y
>>> everything = df.merge(features, on=['x1', 'x2'])
>>> everything
x1 x2 z y
0 2.0 3.0 -1 1.376932
"""
return PythonModelWrapper(model_uri, predict_type)

File diff suppressed because it is too large.

@ -0,0 +1,210 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import OrderedDict
from typing import Callable, Any
import numpy as np
from pyspark.sql import functions as F, Column
from pyspark.sql.types import DoubleType, LongType, BooleanType
unary_np_spark_mappings = OrderedDict(
{
"abs": F.abs,
"absolute": F.abs,
"arccos": F.acos,
"arccosh": F.pandas_udf(lambda s: np.arccosh(s), DoubleType()),
"arcsin": F.asin,
"arcsinh": F.pandas_udf(lambda s: np.arcsinh(s), DoubleType()),
"arctan": F.atan,
"arctanh": F.pandas_udf(lambda s: np.arctanh(s), DoubleType()),
"bitwise_not": F.bitwiseNOT,
"cbrt": F.cbrt,
"ceil": F.ceil,
# It requires complex type which Koalas does not support yet
"conj": lambda _: NotImplemented,
"conjugate": lambda _: NotImplemented, # It requires complex type
"cos": F.cos,
"cosh": F.pandas_udf(lambda s: np.cosh(s), DoubleType()),
"deg2rad": F.pandas_udf(lambda s: np.deg2rad(s), DoubleType()),
"degrees": F.degrees,
"exp": F.exp,
"exp2": F.pandas_udf(lambda s: np.exp2(s), DoubleType()),
"expm1": F.expm1,
"fabs": F.pandas_udf(lambda s: np.fabs(s), DoubleType()),
"floor": F.floor,
"frexp": lambda _: NotImplemented, # 'frexp' output lengths become different
# and it cannot be supported via pandas UDF.
"invert": F.pandas_udf(lambda s: np.invert(s), DoubleType()),
"isfinite": lambda c: c != float("inf"),
"isinf": lambda c: c == float("inf"),
"isnan": F.isnan,
"isnat": lambda c: NotImplemented, # Koalas and PySpark does not have Nat concept.
"log": F.log,
"log10": F.log10,
"log1p": F.log1p,
"log2": F.pandas_udf(lambda s: np.log2(s), DoubleType()),
"logical_not": lambda c: ~(c.cast(BooleanType())),
"matmul": lambda _: NotImplemented, # Can return a NumPy array in pandas.
"negative": lambda c: c * -1,
"positive": lambda c: c,
"rad2deg": F.pandas_udf(lambda s: np.rad2deg(s), DoubleType()),
"radians": F.radians,
"reciprocal": F.pandas_udf(lambda s: np.reciprocal(s), DoubleType()),
"rint": F.pandas_udf(lambda s: np.rint(s), DoubleType()),
"sign": lambda c: F.when(c == 0, 0).when(c < 0, -1).otherwise(1),
"signbit": lambda c: F.when(c < 0, True).otherwise(False),
"sin": F.sin,
"sinh": F.pandas_udf(lambda s: np.sinh(s), DoubleType()),
"spacing": F.pandas_udf(lambda s: np.spacing(s), DoubleType()),
"sqrt": F.sqrt,
"square": F.pandas_udf(lambda s: np.square(s), DoubleType()),
"tan": F.tan,
"tanh": F.pandas_udf(lambda s: np.tanh(s), DoubleType()),
"trunc": F.pandas_udf(lambda s: np.trunc(s), DoubleType()),
}
)
binary_np_spark_mappings = OrderedDict(
{
"arctan2": F.atan2,
"bitwise_and": lambda c1, c2: c1.bitwiseAND(c2),
"bitwise_or": lambda c1, c2: c1.bitwiseOR(c2),
"bitwise_xor": lambda c1, c2: c1.bitwiseXOR(c2),
"copysign": F.pandas_udf(lambda s1, s2: np.copysign(s1, s2), DoubleType()),
"float_power": F.pandas_udf(lambda s1, s2: np.float_power(s1, s2), DoubleType()),
"floor_divide": F.pandas_udf(lambda s1, s2: np.floor_divide(s1, s2), DoubleType()),
"fmax": F.pandas_udf(lambda s1, s2: np.fmax(s1, s2), DoubleType()),
"fmin": F.pandas_udf(lambda s1, s2: np.fmin(s1, s2), DoubleType()),
"fmod": F.pandas_udf(lambda s1, s2: np.fmod(s1, s2), DoubleType()),
"gcd": F.pandas_udf(lambda s1, s2: np.gcd(s1, s2), DoubleType()),
"heaviside": F.pandas_udf(lambda s1, s2: np.heaviside(s1, s2), DoubleType()),
"hypot": F.hypot,
"lcm": F.pandas_udf(lambda s1, s2: np.lcm(s1, s2), DoubleType()),
"ldexp": F.pandas_udf(lambda s1, s2: np.ldexp(s1, s2), DoubleType()),
"left_shift": F.pandas_udf(lambda s1, s2: np.left_shift(s1, s2), LongType()),
"logaddexp": F.pandas_udf(lambda s1, s2: np.logaddexp(s1, s2), DoubleType()),
"logaddexp2": F.pandas_udf(lambda s1, s2: np.logaddexp2(s1, s2), DoubleType()),
"logical_and": lambda c1, c2: c1.cast(BooleanType()) & c2.cast(BooleanType()),
"logical_or": lambda c1, c2: c1.cast(BooleanType()) | c2.cast(BooleanType()),
"logical_xor": lambda c1, c2: (
# mimics xor by logical operators.
(c1.cast(BooleanType()) | c2.cast(BooleanType()))
& (~(c1.cast(BooleanType())) | ~(c2.cast(BooleanType())))
),
"maximum": F.greatest,
"minimum": F.least,
"modf": F.pandas_udf(lambda s1, s2: np.modf(s1, s2), DoubleType()),
"nextafter": F.pandas_udf(lambda s1, s2: np.nextafter(s1, s2), DoubleType()),
"right_shift": F.pandas_udf(lambda s1, s2: np.right_shift(s1, s2), LongType()),
}
)
# Copied from pandas.
# See also https://docs.scipy.org/doc/numpy/reference/arrays.classes.html#standard-array-subclasses
def maybe_dispatch_ufunc_to_dunder_op(
ser_or_index, ufunc: Callable, method: str, *inputs, **kwargs: Any
):
special = {
"add",
"sub",
"mul",
"pow",
"mod",
"floordiv",
"truediv",
"divmod",
"eq",
"ne",
"lt",
"gt",
"le",
"ge",
"remainder",
"matmul",
}
aliases = {
"absolute": "abs",
"multiply": "mul",
"floor_divide": "floordiv",
"true_divide": "truediv",
"power": "pow",
"remainder": "mod",
"divide": "div",
"equal": "eq",
"not_equal": "ne",
"less": "lt",
"less_equal": "le",
"greater": "gt",
"greater_equal": "ge",
}
# For op(., Array) -> Array.__r{op}__
flipped = {
"lt": "__gt__",
"le": "__ge__",
"gt": "__lt__",
"ge": "__le__",
"eq": "__eq__",
"ne": "__ne__",
}
    op_name = ufunc.__name__
    op_name = aliases.get(op_name, op_name)

    def not_implemented(*args, **kwargs):
        return NotImplemented

    if method == "__call__" and op_name in special and kwargs.get("out") is None:
        if isinstance(inputs[0], type(ser_or_index)):
            name = "__{}__".format(op_name)
            return getattr(ser_or_index, name, not_implemented)(inputs[1])
        else:
            name = flipped.get(op_name, "__r{}__".format(op_name))
            return getattr(ser_or_index, name, not_implemented)(inputs[0])
    else:
        return NotImplemented


# See also https://docs.scipy.org/doc/numpy/reference/arrays.classes.html#standard-array-subclasses
def maybe_dispatch_ufunc_to_spark_func(
    ser_or_index, ufunc: Callable, method: str, *inputs, **kwargs: Any
):
    from pyspark.pandas.base import column_op

    op_name = ufunc.__name__

    if (
        method == "__call__"
        and (op_name in unary_np_spark_mappings or op_name in binary_np_spark_mappings)
        and kwargs.get("out") is None
    ):
        np_spark_map_func = unary_np_spark_mappings.get(op_name) or binary_np_spark_mappings.get(
            op_name
        )

        def convert_arguments(*args):
            args = [  # type: ignore
                F.lit(inp) if not isinstance(inp, Column) else inp for inp in args
            ]  # type: ignore
            return np_spark_map_func(*args)

        return column_op(convert_arguments)(*inputs)  # type: ignore
    else:
        return NotImplemented
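These mappings are what make NumPy ufuncs work directly on pandas-on-Spark objects: a supported ufunc is translated into a Spark column expression (either a builtin Spark function or the pandas UDF registered above) instead of collecting data to the driver. A minimal example, assuming a running Spark session:

```python
import numpy as np
from pyspark import pandas as pp

kser = pp.Series([1.0, 4.0, 9.0])

# np.sqrt is in unary_np_spark_mappings, so this dispatches to Spark's F.sqrt
# and stays lazy and distributed.
print(np.sqrt(kser))

# np.log2 has no builtin Spark counterpart and falls back to the pandas UDF
# registered in the mapping above.
print(np.log2(kser))
```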

@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.plot.core import * # noqa: F401

File diff suppressed because it is too large.

@ -0,0 +1,897 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion
import matplotlib as mat
import numpy as np
import pandas as pd
from matplotlib.axes._base import _process_plot_format
from pandas.core.dtypes.inference import is_list_like
from pandas.io.formats.printing import pprint_thing
from pyspark.pandas.plot import (
TopNPlotBase,
SampledPlotBase,
HistogramPlotBase,
BoxPlotBase,
unsupported_function,
KdePlotBase,
)
if LooseVersion(pd.__version__) < LooseVersion("0.25"):
from pandas.plotting._core import (
_all_kinds,
BarPlot as PandasBarPlot,
BoxPlot as PandasBoxPlot,
HistPlot as PandasHistPlot,
MPLPlot as PandasMPLPlot,
PiePlot as PandasPiePlot,
AreaPlot as PandasAreaPlot,
LinePlot as PandasLinePlot,
BarhPlot as PandasBarhPlot,
ScatterPlot as PandasScatterPlot,
KdePlot as PandasKdePlot,
)
else:
from pandas.plotting._matplotlib import (
BarPlot as PandasBarPlot,
BoxPlot as PandasBoxPlot,
HistPlot as PandasHistPlot,
PiePlot as PandasPiePlot,
AreaPlot as PandasAreaPlot,
LinePlot as PandasLinePlot,
BarhPlot as PandasBarhPlot,
ScatterPlot as PandasScatterPlot,
KdePlot as PandasKdePlot,
)
from pandas.plotting._core import PlotAccessor
from pandas.plotting._matplotlib.core import MPLPlot as PandasMPLPlot
_all_kinds = PlotAccessor._all_kinds
class KoalasBarPlot(PandasBarPlot, TopNPlotBase):
def __init__(self, data, **kwargs):
super().__init__(self.get_top_n(data), **kwargs)
def _plot(self, ax, x, y, w, start=0, log=False, **kwds):
self.set_result_text(ax)
return ax.bar(x, y, w, bottom=start, log=log, **kwds)
class KoalasBoxPlot(PandasBoxPlot, BoxPlotBase):
def boxplot(
self,
ax,
bxpstats,
notch=None,
sym=None,
vert=None,
whis=None,
positions=None,
widths=None,
patch_artist=None,
bootstrap=None,
usermedians=None,
conf_intervals=None,
meanline=None,
showmeans=None,
showcaps=None,
showbox=None,
showfliers=None,
boxprops=None,
labels=None,
flierprops=None,
medianprops=None,
meanprops=None,
capprops=None,
whiskerprops=None,
manage_ticks=None,
# manage_xticks is for compatibility of matplotlib < 3.1.0.
# Remove this when minimum version is 3.0.0
manage_xticks=None,
autorange=False,
zorder=None,
precision=None,
):
def update_dict(dictionary, rc_name, properties):
""" Loads properties in the dictionary from rc file if not already
in the dictionary"""
rc_str = "boxplot.{0}.{1}"
if dictionary is None:
dictionary = dict()
for prop_dict in properties:
dictionary.setdefault(prop_dict, mat.rcParams[rc_str.format(rc_name, prop_dict)])
return dictionary
# Common property dictionaries loading from rc
flier_props = [
"color",
"marker",
"markerfacecolor",
"markeredgecolor",
"markersize",
"linestyle",
"linewidth",
]
default_props = ["color", "linewidth", "linestyle"]
boxprops = update_dict(boxprops, "boxprops", default_props)
whiskerprops = update_dict(whiskerprops, "whiskerprops", default_props)
capprops = update_dict(capprops, "capprops", default_props)
medianprops = update_dict(medianprops, "medianprops", default_props)
meanprops = update_dict(meanprops, "meanprops", default_props)
flierprops = update_dict(flierprops, "flierprops", flier_props)
if patch_artist:
boxprops["linestyle"] = "solid"
boxprops["edgecolor"] = boxprops.pop("color")
# if non-default sym value, put it into the flier dictionary
# the logic for providing the default symbol ('b+') now lives
# in bxp in the initial value of final_flierprops
# handle all of the `sym` related logic here so we only have to pass
# on the flierprops dict.
if sym is not None:
# no-flier case, which should really be done with
# 'showfliers=False' but none-the-less deal with it to keep back
# compatibility
if sym == "":
# blow away existing dict and make one for invisible markers
flierprops = dict(linestyle="none", marker="", color="none")
# turn the fliers off just to be safe
showfliers = False
# now process the symbol string
else:
# process the symbol string
# discarded linestyle
_, marker, color = _process_plot_format(sym)
# if we have a marker, use it
if marker is not None:
flierprops["marker"] = marker
# if we have a color, use it
if color is not None:
# assume that if color is passed in the user want
# filled symbol, if the users want more control use
# flierprops
flierprops["color"] = color
flierprops["markerfacecolor"] = color
flierprops["markeredgecolor"] = color
# replace medians if necessary:
if usermedians is not None:
if len(np.ravel(usermedians)) != len(bxpstats) or np.shape(usermedians)[0] != len(
bxpstats
):
raise ValueError("usermedians length not compatible with x")
else:
# reassign medians as necessary
for stats, med in zip(bxpstats, usermedians):
if med is not None:
stats["med"] = med
if conf_intervals is not None:
if np.shape(conf_intervals)[0] != len(bxpstats):
err_mess = "conf_intervals length not compatible with x"
raise ValueError(err_mess)
else:
for stats, ci in zip(bxpstats, conf_intervals):
if ci is not None:
if len(ci) != 2:
raise ValueError("each confidence interval must " "have two values")
else:
if ci[0] is not None:
stats["cilo"] = ci[0]
if ci[1] is not None:
stats["cihi"] = ci[1]
should_manage_ticks = True
if manage_xticks is not None:
should_manage_ticks = manage_xticks
if manage_ticks is not None:
should_manage_ticks = manage_ticks
if LooseVersion(mat.__version__) < LooseVersion("3.1.0"):
extra_args = {"manage_xticks": should_manage_ticks}
else:
extra_args = {"manage_ticks": should_manage_ticks}
artists = ax.bxp(
bxpstats,
positions=positions,
widths=widths,
vert=vert,
patch_artist=patch_artist,
shownotches=notch,
showmeans=showmeans,
showcaps=showcaps,
showbox=showbox,
boxprops=boxprops,
flierprops=flierprops,
medianprops=medianprops,
meanprops=meanprops,
meanline=meanline,
showfliers=showfliers,
capprops=capprops,
whiskerprops=whiskerprops,
zorder=zorder,
**extra_args,
)
return artists
def _plot(self, ax, bxpstats, column_num=None, return_type="axes", **kwds):
bp = self.boxplot(ax, bxpstats, **kwds)
if return_type == "dict":
return bp, bp
elif return_type == "both":
return self.BP(ax=ax, lines=bp), bp
else:
return ax, bp
def _compute_plot_data(self):
colname = self.data.name
spark_column_name = self.data._internal.spark_column_name_for(self.data._column_label)
data = self.data
# Updates all props with the rc defaults from matplotlib
self.kwds.update(KoalasBoxPlot.rc_defaults(**self.kwds))
# Gets some important kwds
showfliers = self.kwds.get("showfliers", False)
whis = self.kwds.get("whis", 1.5)
labels = self.kwds.get("labels", [colname])
# This one is Koalas-specific and controls the precision for approx_percentile
precision = self.kwds.get("precision", 0.01)
# Computes mean, median, Q1 and Q3 with approx_percentile and precision
col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)
# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)
if showfliers:
fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])
else:
fliers = []
# Builds bxpstats dict
stats = []
item = {
"mean": col_stats["mean"],
"med": col_stats["med"],
"q1": col_stats["q1"],
"q3": col_stats["q3"],
"whislo": whiskers[0],
"whishi": whiskers[1],
"fliers": fliers,
"label": labels[0],
}
stats.append(item)
self.data = {labels[0]: stats}
def _make_plot(self):
bxpstats = list(self.data.values())[0]
ax = self._get_ax(0)
kwds = self.kwds.copy()
for stats in bxpstats:
if len(stats["fliers"]) > 1000:
stats["fliers"] = stats["fliers"][:1000]
ax.text(
1,
1,
"showing top 1,000 fliers only",
size=6,
ha="right",
va="bottom",
transform=ax.transAxes,
)
ret, bp = self._plot(ax, bxpstats, column_num=0, return_type=self.return_type, **kwds)
self.maybe_color_bp(bp)
self._return_obj = ret
labels = [l for l, _ in self.data.items()]
labels = [pprint_thing(l) for l in labels]
if not self.use_index:
labels = [pprint_thing(key) for key in range(len(labels))]
self._set_ticklabels(ax, labels)
@staticmethod
def rc_defaults(
notch=None,
vert=None,
whis=None,
patch_artist=None,
bootstrap=None,
meanline=None,
showmeans=None,
showcaps=None,
showbox=None,
showfliers=None,
**kwargs
):
# Missing arguments default to rcParams.
if whis is None:
whis = mat.rcParams["boxplot.whiskers"]
if bootstrap is None:
bootstrap = mat.rcParams["boxplot.bootstrap"]
if notch is None:
notch = mat.rcParams["boxplot.notch"]
if vert is None:
vert = mat.rcParams["boxplot.vertical"]
if patch_artist is None:
patch_artist = mat.rcParams["boxplot.patchartist"]
if meanline is None:
meanline = mat.rcParams["boxplot.meanline"]
if showmeans is None:
showmeans = mat.rcParams["boxplot.showmeans"]
if showcaps is None:
showcaps = mat.rcParams["boxplot.showcaps"]
if showbox is None:
showbox = mat.rcParams["boxplot.showbox"]
if showfliers is None:
showfliers = mat.rcParams["boxplot.showfliers"]
return dict(
whis=whis,
bootstrap=bootstrap,
notch=notch,
vert=vert,
patch_artist=patch_artist,
meanline=meanline,
showmeans=showmeans,
showcaps=showcaps,
showbox=showbox,
showfliers=showfliers,
)
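# NOTE (editor): illustrative sketch, not part of the ported code. A box plot is
# normally requested through the plot accessor; 'precision' is the Koalas-specific
# knob that _compute_plot_data passes down to approx_percentile.
#
#     from pyspark import pandas as pp
#     psser = pp.Series([1, 2, 3, 4, 5, 100])
#     ax = psser.plot.box(precision=0.01)  # matplotlib Axes (matplotlib backend assumed)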
class KoalasHistPlot(PandasHistPlot, HistogramPlotBase):
def _args_adjust(self):
if is_list_like(self.bottom):
self.bottom = np.array(self.bottom)
def _compute_plot_data(self):
self.data, self.bins = HistogramPlotBase.prepare_hist_data(self.data, self.bins)
def _make_plot(self):
# TODO: this logic is similar to KdePlot. Might have to deduplicate it.
# Computing 'num_colors' requires `shape`, which counts all rows.
# Use 1 for now to save the computation.
colors = self._get_colors(num_colors=1)
stacking_id = self._get_stacking_id()
output_series = HistogramPlotBase.compute_hist(self.data, self.bins)
for (i, label), y in zip(enumerate(self.data._internal.column_labels), output_series):
ax = self._get_ax(i)
kwds = self.kwds.copy()
label = pprint_thing(label if len(label) > 1 else label[0])
kwds["label"] = label
style, kwds = self._apply_style_colors(colors, kwds, i, label)
if style is not None:
kwds["style"] = style
kwds = self._make_plot_keywords(kwds, y)
artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
self._add_legend_handle(artists[0], label, index=i)
@classmethod
def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0, stacking_id=None, **kwds):
if column_num == 0:
cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
base = np.zeros(len(bins) - 1)
bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
# Since the counts were computed already, we use them as weights and just generate
# one entry for each bin
n, bins, patches = ax.hist(bins[:-1], bins=bins, bottom=bottom, weights=y, **kwds)
cls._update_stacker(ax, stacking_id, n)
return patches
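# NOTE (editor): illustrative sketch of the "counts as weights" trick used in _plot
# above, shown with plain numpy/matplotlib. Passing one sample per bin and the
# precomputed counts as weights reproduces the histogram without the raw data.
#
#     import numpy as np
#     import matplotlib.pyplot as plt
#     edges = np.array([0.0, 1.0, 2.0, 3.0])   # bin edges
#     counts = np.array([5, 2, 7])              # precomputed per-bin counts
#     fig, ax = plt.subplots()
#     ax.hist(edges[:-1], bins=edges, weights=counts)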
class KoalasPiePlot(PandasPiePlot, TopNPlotBase):
def __init__(self, data, **kwargs):
super().__init__(self.get_top_n(data), **kwargs)
def _make_plot(self):
self.set_result_text(self._get_ax(0))
super()._make_plot()
class KoalasAreaPlot(PandasAreaPlot, SampledPlotBase):
def __init__(self, data, **kwargs):
super().__init__(self.get_sampled(data), **kwargs)
def _make_plot(self):
self.set_result_text(self._get_ax(0))
super()._make_plot()
class KoalasLinePlot(PandasLinePlot, SampledPlotBase):
def __init__(self, data, **kwargs):
super().__init__(self.get_sampled(data), **kwargs)
def _make_plot(self):
self.set_result_text(self._get_ax(0))
super()._make_plot()
class KoalasBarhPlot(PandasBarhPlot, TopNPlotBase):
def __init__(self, data, **kwargs):
super().__init__(self.get_top_n(data), **kwargs)
def _make_plot(self):
self.set_result_text(self._get_ax(0))
super()._make_plot()
class KoalasScatterPlot(PandasScatterPlot, TopNPlotBase):
def __init__(self, data, x, y, **kwargs):
super().__init__(self.get_top_n(data), x, y, **kwargs)
def _make_plot(self):
self.set_result_text(self._get_ax(0))
super()._make_plot()
class KoalasKdePlot(PandasKdePlot, KdePlotBase):
def _compute_plot_data(self):
self.data = KdePlotBase.prepare_kde_data(self.data)
def _make_plot(self):
# Computing 'num_colors' requires `shape`, which counts all rows.
# Use 1 for now to save the computation.
colors = self._get_colors(num_colors=1)
stacking_id = self._get_stacking_id()
sdf = self.data._internal.spark_frame
for i, label in enumerate(self.data._internal.column_labels):
# 'y' is a Spark DataFrame that selects one column.
y = sdf.select(self.data._internal.spark_column_for(label))
ax = self._get_ax(i)
kwds = self.kwds.copy()
label = pprint_thing(label if len(label) > 1 else label[0])
kwds["label"] = label
style, kwds = self._apply_style_colors(colors, kwds, i, label)
if style is not None:
kwds["style"] = style
kwds = self._make_plot_keywords(kwds, y)
artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
self._add_legend_handle(artists[0], label, index=i)
def _get_ind(self, y):
return KdePlotBase.get_ind(y, self.ind)
@classmethod
def _plot(
cls, ax, y, style=None, bw_method=None, ind=None, column_num=None, stacking_id=None, **kwds
):
y = KdePlotBase.compute_kde(y, bw_method=bw_method, ind=ind)
lines = PandasMPLPlot._plot(ax, ind, y, style=style, **kwds)
return lines
_klasses = [
KoalasHistPlot,
KoalasBarPlot,
KoalasBoxPlot,
KoalasPiePlot,
KoalasAreaPlot,
KoalasLinePlot,
KoalasBarhPlot,
KoalasScatterPlot,
KoalasKdePlot,
]
_plot_klass = {getattr(klass, "_kind"): klass for klass in _klasses}
_common_kinds = {"area", "bar", "barh", "box", "hist", "kde", "line", "pie"}
_series_kinds = _common_kinds.union(set())
_dataframe_kinds = _common_kinds.union({"scatter", "hexbin"})
_koalas_all_kinds = _common_kinds.union(_series_kinds).union(_dataframe_kinds)
def plot_koalas(data, kind, **kwargs):
if kind not in _koalas_all_kinds:
raise ValueError("{} is not a valid plot kind".format(kind))
from pyspark.pandas import DataFrame, Series
if isinstance(data, Series):
if kind not in _series_kinds:
return unsupported_function(class_name="pd.Series", method_name=kind)()
return plot_series(data=data, kind=kind, **kwargs)
elif isinstance(data, DataFrame):
if kind not in _dataframe_kinds:
return unsupported_function(class_name="pd.DataFrame", method_name=kind)()
return plot_frame(data=data, kind=kind, **kwargs)
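# NOTE (editor): illustrative sketch, not part of the ported code. The plot
# accessors are assumed to end up in plot_koalas, which routes by input type:
#
#     from pyspark import pandas as pp
#     psdf = pp.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
#     plot_koalas(psdf, kind="line")        # dispatches to plot_frame
#     plot_koalas(psdf["a"], kind="hist")   # dispatches to plot_series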
def plot_series(
data,
kind="line",
ax=None, # Series unique
figsize=None,
use_index=True,
title=None,
grid=None,
legend=False,
style=None,
logx=False,
logy=False,
loglog=False,
xticks=None,
yticks=None,
xlim=None,
ylim=None,
rot=None,
fontsize=None,
colormap=None,
table=False,
yerr=None,
xerr=None,
label=None,
secondary_y=False, # Series unique
**kwds
):
"""
Make plots of Series using matplotlib / pylab.
Each plot kind has a corresponding method on the
``Series.plot`` accessor:
``s.plot(kind='line')`` is equivalent to
``s.plot.line()``.
Parameters
----------
data : Series
kind : str
- 'line' : line plot (default)
- 'bar' : vertical bar plot
- 'barh' : horizontal bar plot
- 'hist' : histogram
- 'box' : boxplot
- 'kde' : Kernel Density Estimation plot
- 'density' : same as 'kde'
- 'area' : area plot
- 'pie' : pie plot
ax : matplotlib axes object
If not passed, uses gca()
figsize : a tuple (width, height) in inches
use_index : boolean, default True
Use index as ticks for x axis
title : string or list
Title to use for the plot. If a string is passed, print the string at
the top of the figure. If a list is passed and `subplots` is True,
print each item in the list above the corresponding subplot.
grid : boolean, default None (matlab style default)
Axis grid lines
legend : False/True/'reverse'
Place legend on axis subplots
style : list or dict
matplotlib line style per column
logx : boolean, default False
Use log scaling on x axis
logy : boolean, default False
Use log scaling on y axis
loglog : boolean, default False
Use log scaling on both x and y axes
xticks : sequence
Values to use for the xticks
yticks : sequence
Values to use for the yticks
xlim : 2-tuple/list
ylim : 2-tuple/list
rot : int, default None
Rotation for ticks (xticks for vertical, yticks for horizontal plots)
fontsize : int, default None
Font size for xticks and yticks
colormap : str or matplotlib colormap object, default None
Colormap to select colors from. If string, load colormap with that name
from matplotlib.
colorbar : boolean, optional
If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
position : float
Specify relative alignments for bar plot layout.
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
table : boolean, Series or DataFrame, default False
If True, draw a table using the data in the DataFrame and the data will
be transposed to meet matplotlib's default layout.
If a Series or DataFrame is passed, use passed data to draw a table.
yerr : DataFrame, Series, array-like, dict and str
See :ref:`Plotting with Error Bars <visualization.errorbars>` for
detail.
xerr : same types as yerr.
label : label argument to provide to plot
secondary_y : boolean or sequence of ints, default False
If True then y-axis will be on the right
mark_right : boolean, default True
When using a secondary_y axis, automatically mark the column
labels with "(right)" in the legend
**kwds : keywords
Options to pass to matplotlib plotting method
Returns
-------
axes : :class:`matplotlib.axes.Axes` or numpy.ndarray of them
Notes
-----
- See matplotlib documentation online for more on this subject
- If `kind` = 'bar' or 'barh', you can specify relative alignments
for bar plot layout by `position` keyword.
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
"""
# function copied from pandas.plotting._core
# so it calls modified _plot below
import matplotlib.pyplot as plt
if ax is None and len(plt.get_fignums()) > 0:
with plt.rc_context():
ax = plt.gca()
ax = PandasMPLPlot._get_ax_layer(ax)
return _plot(
data,
kind=kind,
ax=ax,
figsize=figsize,
use_index=use_index,
title=title,
grid=grid,
legend=legend,
style=style,
logx=logx,
logy=logy,
loglog=loglog,
xticks=xticks,
yticks=yticks,
xlim=xlim,
ylim=ylim,
rot=rot,
fontsize=fontsize,
colormap=colormap,
table=table,
yerr=yerr,
xerr=xerr,
label=label,
secondary_y=secondary_y,
**kwds,
)
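# NOTE (editor): illustrative sketch, not part of the ported code. plot_series can
# be called directly, although the usual entry point is the Series.plot accessor:
#
#     from pyspark import pandas as pp
#     psser = pp.Series([1, 3, 2, 4])
#     ax = plot_series(psser, kind="line", title="example")  # matplotlib Axes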
def plot_frame(
data,
x=None,
y=None,
kind="line",
ax=None,
subplots=None,
sharex=None,
sharey=False,
layout=None,
figsize=None,
use_index=True,
title=None,
grid=None,
legend=True,
style=None,
logx=False,
logy=False,
loglog=False,
xticks=None,
yticks=None,
xlim=None,
ylim=None,
rot=None,
fontsize=None,
colormap=None,
table=False,
yerr=None,
xerr=None,
secondary_y=False,
sort_columns=False,
**kwds
):
"""
Make plots of DataFrames using matplotlib / pylab.
Each plot kind has a corresponding method on the
``DataFrame.plot`` accessor:
``kdf.plot(kind='line')`` is equivalent to
``kdf.plot.line()``.
Parameters
----------
data : DataFrame
kind : str
- 'line' : line plot (default)
- 'bar' : vertical bar plot
- 'barh' : horizontal bar plot
- 'hist' : histogram
- 'box' : boxplot
- 'kde' : Kernel Density Estimation plot
- 'density' : same as 'kde'
- 'area' : area plot
- 'pie' : pie plot
- 'scatter' : scatter plot
ax : matplotlib axes object
If not passed, uses gca()
x : label or position, default None
y : label, position or list of label, positions, default None
Allows plotting of one column versus another.
figsize : a tuple (width, height) in inches
use_index : boolean, default True
Use index as ticks for x axis
title : string or list
Title to use for the plot. If a string is passed, print the string at
the top of the figure. If a list is passed and `subplots` is True,
print each item in the list above the corresponding subplot.
grid : boolean, default None (matlab style default)
Axis grid lines
legend : False/True/'reverse'
Place legend on axis subplots
style : list or dict
matplotlib line style per column
logx : boolean, default False
Use log scaling on x axis
logy : boolean, default False
Use log scaling on y axis
loglog : boolean, default False
Use log scaling on both x and y axes
xticks : sequence
Values to use for the xticks
yticks : sequence
Values to use for the yticks
xlim : 2-tuple/list
ylim : 2-tuple/list
sharex: bool or None, default is None
Whether to share x axis or not.
sharey: bool, default is False
Whether to share y axis or not.
rot : int, default None
Rotation for ticks (xticks for vertical, yticks for horizontal plots)
fontsize : int, default None
Font size for xticks and yticks
colormap : str or matplotlib colormap object, default None
Colormap to select colors from. If string, load colormap with that name
from matplotlib.
colorbar : boolean, optional
If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
position : float
Specify relative alignments for bar plot layout.
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
table : boolean, Series or DataFrame, default False
If True, draw a table using the data in the DataFrame and the data will
be transposed to meet matplotlib's default layout.
If a Series or DataFrame is passed, use passed data to draw a table.
yerr : DataFrame, Series, array-like, dict and str
See :ref:`Plotting with Error Bars <visualization.errorbars>` for
detail.
xerr : same types as yerr.
label : label argument to provide to plot
secondary_y : boolean or sequence of ints, default False
If True then y-axis will be on the right
mark_right : boolean, default True
When using a secondary_y axis, automatically mark the column
labels with "(right)" in the legend
sort_columns: bool, default is False
When True, will sort values on plots.
**kwds : keywords
Options to pass to matplotlib plotting method
Returns
-------
axes : :class:`matplotlib.axes.Axes` or numpy.ndarray of them
Notes
-----
- See matplotlib documentation online for more on this subject
- If `kind` = 'bar' or 'barh', you can specify relative alignments
for bar plot layout by `position` keyword.
From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
"""
return _plot(
data,
kind=kind,
x=x,
y=y,
ax=ax,
figsize=figsize,
use_index=use_index,
title=title,
grid=grid,
legend=legend,
subplots=subplots,
style=style,
logx=logx,
logy=logy,
loglog=loglog,
xticks=xticks,
yticks=yticks,
xlim=xlim,
ylim=ylim,
rot=rot,
fontsize=fontsize,
colormap=colormap,
table=table,
yerr=yerr,
xerr=xerr,
sharex=sharex,
sharey=sharey,
secondary_y=secondary_y,
layout=layout,
sort_columns=sort_columns,
**kwds,
)
def _plot(data, x=None, y=None, subplots=False, ax=None, kind="line", **kwds):
from pyspark.pandas import DataFrame
# function copied from pandas.plotting._core
# and adapted to handle Koalas DataFrame and Series
kind = kind.lower().strip()
kind = {"density": "kde"}.get(kind, kind)
if kind in _all_kinds:
klass = _plot_klass[kind]
else:
raise ValueError("%r is not a valid plot kind" % kind)
# scatter and hexbin are inherited from PlanePlot which require x and y
if kind in ("scatter", "hexbin"):
plot_obj = klass(data, x, y, subplots=subplots, ax=ax, kind=kind, **kwds)
else:
# check data type and do preprocess before applying plot
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj.generate()
plot_obj.draw()
return plot_obj.result
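# NOTE (editor): illustrative sketch of the preprocessing in _plot above, not part
# of the ported code. 'density' is aliased to 'kde', and for frame plots a given x
# is moved into the index while y selects the plotted column(s):
#
#     from pyspark import pandas as pp
#     psdf = pp.DataFrame({"x": [1, 2, 3], "y": [2, 4, 6]})
#     psdf.plot(x="x", y="y", kind="density")  # roughly kind="kde" on psdf.set_index("x")["y"]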


@@ -0,0 +1,212 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Union
import pandas as pd
from pyspark.pandas.plot import (
HistogramPlotBase,
name_like_string,
KoalasPlotAccessor,
BoxPlotBase,
KdePlotBase,
)
if TYPE_CHECKING:
import pyspark.pandas as pp # noqa: F401 (SPARK-34943)
def plot_koalas(data: Union["pp.DataFrame", "pp.Series"], kind: str, **kwargs):
import plotly
# Koalas specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
return plot_histogram(data, **kwargs)
if kind == "box":
return plot_box(data, **kwargs)
if kind == "kde" or kind == "density":
return plot_kde(data, **kwargs)
# Other plots.
return plotly.plot(KoalasPlotAccessor.pandas_plot_data_map[kind](data), kind, **kwargs)
def plot_pie(data: Union["pp.DataFrame", "pp.Series"], **kwargs):
from plotly import express
data = KoalasPlotAccessor.pandas_plot_data_map["pie"](data)
if isinstance(data, pd.Series):
pdf = data.to_frame()
return express.pie(pdf, values=pdf.columns[0], names=pdf.index, **kwargs)
elif isinstance(data, pd.DataFrame):
values = kwargs.pop("y", None)
default_names = None
if values is not None:
default_names = data.index
return express.pie(
data,
values=kwargs.pop("values", values),
names=kwargs.pop("names", default_names),
**kwargs,
)
else:
raise RuntimeError("Unexpected type: [%s]" % type(data))
def plot_histogram(data: Union["pp.DataFrame", "pp.Series"], **kwargs):
import plotly.graph_objs as go
bins = kwargs.get("bins", 10)
kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
assert len(bins) > 2, "the number of buckets must be higher than 2."
output_series = HistogramPlotBase.compute_hist(kdf, bins)
prev = float("%.9f" % bins[0]) # to make it prettier, truncate.
text_bins = []
for b in bins[1:]:
norm_b = float("%.9f" % b)
text_bins.append("[%s, %s)" % (prev, norm_b))
prev = norm_b
text_bins[-1] = text_bins[-1][:-1] + "]"  # replace ')' with ']' for the last bucket.
bins = 0.5 * (bins[:-1] + bins[1:])
output_series = list(output_series)
bars = []
for series in output_series:
bars.append(
go.Bar(
x=bins,
y=series,
name=name_like_string(series.name),
text=text_bins,
hovertemplate=(
"variable=" + name_like_string(series.name) + "<br>value=%{text}<br>count=%{y}"
),
)
)
fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
return fig
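# NOTE (editor): illustrative sketch, not part of the ported code. plot_histogram
# returns a plotly Figure built from pre-aggregated bin counts, so only the counts
# (not the raw rows) leave Spark:
#
#     from pyspark import pandas as pp
#     fig = plot_histogram(pp.Series([1, 2, 2, 3, 10]), bins=4)
#     # fig.show()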
def plot_box(data: Union["pp.DataFrame", "pp.Series"], **kwargs):
import plotly.graph_objs as go
import pyspark.pandas as pp
if isinstance(data, pp.DataFrame):
raise RuntimeError(
"plotly does not support a box plot with Koalas DataFrame. Use Series instead."
)
# 'whis' isn't actually an argument in plotly (it is in matplotlib), and plotly
# does not seem to expose how far the whiskers reach beyond the first and
# third quartiles; it appears to use the default of 1.5.
whis = kwargs.pop("whis", 1.5)
# 'precision' is Koalas-specific and controls the precision for approx_percentile
precision = kwargs.pop("precision", 0.01)
# Plotly options
boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
notched = kwargs.pop("notched", False)
if boxpoints not in ["suspectedoutliers", False]:
raise ValueError(
"plotly plotting backend does not support 'boxpoints' set to '%s'. "
"Set to 'suspectedoutliers' or False." % boxpoints
)
if notched:
raise ValueError(
"plotly plotting backend does not support 'notched' set to '%s'. "
"Set to False." % notched
)
colname = name_like_string(data.name)
spark_column_name = data._internal.spark_column_name_for(data._column_label)
# Computes mean, median, Q1 and Q3 with approx_percentile and precision
col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)
# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)
fliers = None
if boxpoints:
fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])
fliers = [fliers] if len(fliers) > 0 else None
fig = go.Figure()
fig.add_trace(
go.Box(
name=colname,
q1=[col_stats["q1"]],
median=[col_stats["med"]],
q3=[col_stats["q3"]],
mean=[col_stats["mean"]],
lowerfence=[whiskers[0]],
upperfence=[whiskers[1]],
y=fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs, # this is for workarounds. Box takes different options from express.box.
)
)
fig["layout"]["xaxis"]["title"] = colname
fig["layout"]["yaxis"]["title"] = "value"
return fig
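# NOTE (editor): illustrative sketch, not part of the ported code. plot_box accepts
# only a Series; the quartiles, fences and (optionally) suspected outliers are
# computed in Spark and fed to a single go.Box trace:
#
#     from pyspark import pandas as pp
#     fig = plot_box(pp.Series([1, 2, 3, 4, 100]), precision=0.01)
#     # fig.show()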
def plot_kde(data: Union["pp.DataFrame", "pp.Series"], **kwargs):
from plotly import express
import pyspark.pandas as pp
if isinstance(data, pp.DataFrame) and "color" not in kwargs:
kwargs["color"] = "names"
kdf = KdePlotBase.prepare_kde_data(data)
sdf = kdf._internal.spark_frame
data_columns = kdf._internal.data_spark_columns
ind = KdePlotBase.get_ind(sdf.select(*data_columns), kwargs.pop("ind", None))
bw_method = kwargs.pop("bw_method", None)
pdfs = []
for label in kdf._internal.column_labels:
pdfs.append(
pd.DataFrame(
{
"Density": KdePlotBase.compute_kde(
sdf.select(kdf._internal.spark_column_for(label)),
ind=ind,
bw_method=bw_method,
),
"names": name_like_string(label),
"index": ind,
}
)
)
pdf = pd.concat(pdfs)
fig = express.line(pdf, x="index", y="Density", **kwargs)
fig["layout"]["xaxis"]["title"] = None
return fig
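# NOTE (editor): illustrative sketch, not part of the ported code. plot_kde
# evaluates the kernel density of each column on a shared grid ('ind') and draws
# the result with plotly express lines:
#
#     from pyspark import pandas as pp
#     fig = plot_kde(pp.DataFrame({"a": [1.0, 2.0, 2.5], "b": [3.0, 3.5, 4.0]}), bw_method=0.3)
#     # fig.show()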

File diff suppressed because it is too large.


@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

File diff suppressed because it is too large.


@@ -0,0 +1,98 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Additional Spark functions used in Koalas.
"""
from pyspark import SparkContext
from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
__all__ = ["percentile_approx"]
def percentile_approx(col, percentage, accuracy=10000):
"""
Returns the approximate percentile value of numeric column col at the given percentage.
The value of percentage must be between 0.0 and 1.0.
The accuracy parameter (default: 10000)
is a positive numeric literal which controls approximation accuracy at the cost of memory.
A higher value of accuracy yields better accuracy; 1.0/accuracy is the relative error
of the approximation.
When percentage is an array, each value of the percentage array must be between 0.0 and 1.0.
In this case, returns the approximate percentile array of column col
at the given percentage array.
Ported from Spark 3.1.
"""
sc = SparkContext._active_spark_context
if isinstance(percentage, (list, tuple)):
# A local list
percentage = sc._jvm.functions.array(
_to_seq(sc, [_create_column_from_literal(x) for x in percentage])
)
elif isinstance(percentage, Column):
# Already a Column
percentage = _to_java_column(percentage)
else:
# Probably scalar
percentage = _create_column_from_literal(percentage)
accuracy = (
_to_java_column(accuracy)
if isinstance(accuracy, Column)
else _create_column_from_literal(accuracy)
)
return _call_udf(sc, "percentile_approx", _to_java_column(col), percentage, accuracy)
def array_repeat(col, count):
"""
Collection function: creates an array containing a column repeated count times.
Ported from Spark 3.0.
"""
sc = SparkContext._active_spark_context
return Column(
sc._jvm.functions.array_repeat(
_to_java_column(col), _to_java_column(count) if isinstance(count, Column) else count
)
)
def repeat(col, n):
"""
Repeats a string column n times, and returns it as a new string column.
"""
sc = SparkContext._active_spark_context
n = _to_java_column(n) if isinstance(n, Column) else _create_column_from_literal(n)
return _call_udf(sc, "repeat", _to_java_column(col), n)
def _call_udf(sc, name, *cols):
return Column(sc._jvm.functions.callUDF(name, _make_arguments(sc, *cols)))
def _make_arguments(sc, *cols):
java_arr = sc._gateway.new_array(sc._jvm.Column, len(cols))
for i, col in enumerate(cols):
java_arr[i] = col
return java_arr


@@ -0,0 +1,124 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Helpers and utilities to deal with PySpark instances
"""
from pyspark.sql.types import DecimalType, StructType, MapType, ArrayType, StructField, DataType
def as_nullable_spark_type(dt: DataType) -> DataType:
"""
Returns a nullable schema or data types.
Examples
--------
>>> from pyspark.sql.types import *
>>> as_nullable_spark_type(StructType([
... StructField("A", IntegerType(), True),
... StructField("B", FloatType(), False)])) # doctest: +NORMALIZE_WHITESPACE
StructType(List(StructField(A,IntegerType,true),StructField(B,FloatType,true)))
>>> as_nullable_spark_type(StructType([
... StructField("A",
... StructType([
... StructField('a',
... MapType(IntegerType(),
... ArrayType(IntegerType(), False), False), False),
... StructField('b', StringType(), True)])),
... StructField("B", FloatType(), False)])) # doctest: +NORMALIZE_WHITESPACE
StructType(List(StructField(A,StructType(List(StructField(a,MapType(IntegerType,ArrayType\
(IntegerType,true),true),true),StructField(b,StringType,true))),true),\
StructField(B,FloatType,true)))
"""
if isinstance(dt, StructType):
new_fields = []
for field in dt.fields:
new_fields.append(
StructField(
field.name,
as_nullable_spark_type(field.dataType),
nullable=True,
metadata=field.metadata,
)
)
return StructType(new_fields)
elif isinstance(dt, ArrayType):
return ArrayType(as_nullable_spark_type(dt.elementType), containsNull=True)
elif isinstance(dt, MapType):
return MapType(
as_nullable_spark_type(dt.keyType),
as_nullable_spark_type(dt.valueType),
valueContainsNull=True,
)
else:
return dt
def force_decimal_precision_scale(dt: DataType, precision: int = 38, scale: int = 18) -> DataType:
"""
Returns a data type with a fixed decimal type.
The precision and scale of the decimal type are fixed with the given values.
Examples
--------
>>> from pyspark.sql.types import *
>>> force_decimal_precision_scale(StructType([
... StructField("A", DecimalType(10, 0), True),
... StructField("B", DecimalType(14, 7), False)])) # doctest: +NORMALIZE_WHITESPACE
StructType(List(StructField(A,DecimalType(38,18),true),StructField(B,DecimalType(38,18),false)))
>>> force_decimal_precision_scale(StructType([
... StructField("A",
... StructType([
... StructField('a',
... MapType(DecimalType(5, 0),
... ArrayType(DecimalType(20, 0), False), False), False),
... StructField('b', StringType(), True)])),
... StructField("B", DecimalType(30, 15), False)]),
... precision=30, scale=15) # doctest: +NORMALIZE_WHITESPACE
StructType(List(StructField(A,StructType(List(StructField(a,MapType(DecimalType(30,15),\
ArrayType(DecimalType(30,15),false),false),false),StructField(b,StringType,true))),true),\
StructField(B,DecimalType(30,15),false)))
"""
if isinstance(dt, StructType):
new_fields = []
for field in dt.fields:
new_fields.append(
StructField(
field.name,
force_decimal_precision_scale(field.dataType, precision, scale),
nullable=field.nullable,
metadata=field.metadata,
)
)
return StructType(new_fields)
elif isinstance(dt, ArrayType):
return ArrayType(
force_decimal_precision_scale(dt.elementType, precision, scale),
containsNull=dt.containsNull,
)
elif isinstance(dt, MapType):
return MapType(
force_decimal_precision_scale(dt.keyType, precision, scale),
force_decimal_precision_scale(dt.valueType, precision, scale),
valueContainsNull=dt.valueContainsNull,
)
elif isinstance(dt, DecimalType):
return DecimalType(precision=precision, scale=scale)
else:
return dt

View file

@@ -0,0 +1,302 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import _string
from typing import Dict, Any, Optional # noqa: F401 (SPARK-34943)
import inspect
import pandas as pd
from pyspark.sql import SparkSession, DataFrame as SDataFrame # noqa: F401 (SPARK-34943)
from pyspark import pandas as pp # For running doctests and reference resolution in PyCharm.
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series
__all__ = ["sql"]
from builtins import globals as builtin_globals
from builtins import locals as builtin_locals
def sql(query: str, globals=None, locals=None, **kwargs) -> DataFrame:
"""
Execute a SQL query and return the result as a Koalas DataFrame.
This function also supports embedding Python variables (locals, globals, and parameters)
in the SQL statement by wrapping them in curly braces. See examples section for details.
In addition to the locals, globals and parameters, the function will also attempt
to determine if the program currently runs in an IPython (or Jupyter) environment
and to import the variables from this environment. The variables have the same
precedence as globals.
The following variable types are supported:
* string
* int
* float
* list, tuple, range of above types
* Koalas DataFrame
* Koalas Series
* pandas DataFrame
Parameters
----------
query : str
the SQL query
globals : dict, optional
the dictionary of global variables, if explicitly set by the user
locals : dict, optional
the dictionary of local variables, if explicitly set by the user
kwargs
other variables that the user may want to set manually that can be referenced in the query
Returns
-------
Koalas DataFrame
Examples
--------
Calling a built-in SQL function.
>>> pp.sql("select * from range(10) where id > 7")
id
0 8
1 9
A query can also reference a local variable or parameter by wrapping them in curly braces:
>>> bound1 = 7
>>> pp.sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9)
id
0 8
You can also wrap a DataFrame with curly braces to query it directly. Note that when you do
that, the indexes, if any, automatically become top level columns.
>>> mydf = pp.range(10)
>>> x = range(4)
>>> pp.sql("SELECT * from {mydf} WHERE id IN {x}")
id
0 0
1 1
2 2
3 3
Queries can also be arbitrarily nested in functions:
>>> def statement():
... mydf2 = pp.DataFrame({"x": range(2)})
... return pp.sql("SELECT * from {mydf2}")
>>> statement()
x
0 0
1 1
Mixing Koalas and pandas DataFrames in a join operation. Note that the index is dropped.
>>> pp.sql('''
... SELECT m1.a, m2.b
... FROM {table1} m1 INNER JOIN {table2} m2
... ON m1.key = m2.key
... ORDER BY m1.a, m2.b''',
... table1=pp.DataFrame({"a": [1,2], "key": ["a", "b"]}),
... table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
a b
0 1 3
1 2 4
2 2 5
Also, it is possible to query using Series.
>>> myser = pp.Series({'a': [1.0, 2.0, 3.0], 'b': [15.0, 30.0, 45.0]})
>>> pp.sql("SELECT * from {myser}")
0
0 [1.0, 2.0, 3.0]
1 [15.0, 30.0, 45.0]
"""
if globals is None:
globals = _get_ipython_scope()
_globals = builtin_globals() if globals is None else dict(globals)
_locals = builtin_locals() if locals is None else dict(locals)
# The default choice is the globals
_dict = dict(_globals)
# The vars:
_scope = _get_local_scope()
_dict.update(_scope)
# Then the locals
_dict.update(_locals)
# Highest order of precedence is the explicitly passed parameters (kwargs)
_dict.update(kwargs)
return SQLProcessor(_dict, query, default_session()).execute()
_CAPTURE_SCOPES = 2
def _get_local_scope():
# Get 2 scopes above (_get_local_scope -> sql -> ...) to capture the vars there.
try:
return inspect.stack()[_CAPTURE_SCOPES][0].f_locals
except Exception as e:
# TODO (rxin, thunterdb): use a more narrow scope exception.
# See https://github.com/pyspark.pandas/pull/448
return {}
def _get_ipython_scope():
"""
Tries to extract the dictionary of variables if the program is running
in an IPython notebook environment.
"""
try:
from IPython import get_ipython
shell = get_ipython()
return shell.user_ns
except Exception as e:
# TODO (rxin, thunterdb): use a more narrow scope exception.
# See https://github.com/pyspark.pandas/pull/448
return None
# Originally from pymysql package
_escape_table = [chr(x) for x in range(128)]
_escape_table[0] = "\\0"
_escape_table[ord("\\")] = "\\\\"
_escape_table[ord("\n")] = "\\n"
_escape_table[ord("\r")] = "\\r"
_escape_table[ord("\032")] = "\\Z"
_escape_table[ord('"')] = '\\"'
_escape_table[ord("'")] = "\\'"
def escape_sql_string(value: str) -> str:
"""Escapes value without adding quotes.
>>> escape_sql_string("foo\\nbar")
'foo\\\\nbar'
>>> escape_sql_string("'abc'de")
"\\\\'abc\\\\'de"
>>> escape_sql_string('"abc"de')
'\\\\"abc\\\\"de'
"""
return value.translate(_escape_table)
class SQLProcessor(object):
def __init__(self, scope: Dict[str, Any], statement: str, session: SparkSession):
self._scope = scope
self._statement = statement
# All the temporary views created when executing this statement
# The key is the name of the variable in {}
# The value is the cached Spark Dataframe.
self._temp_views = {} # type: Dict[str, SDataFrame]
# All the other variables, converted to a normalized form.
# The normalized form is typically a string
self._cached_vars = {} # type: Dict[str, Any]
# The SQL statement after:
# - all the dataframes have been registered as temporary views
# - all the values have been normalized to equivalent SQL representations
self._normalized_statement = None # type: Optional[str]
self._session = session
def execute(self) -> DataFrame:
"""
Returns a DataFrame for which the SQL statement has been executed by
the underlying SQL engine.
>>> str0 = 'abc'
>>> pp.sql("select {str0}")
abc
0 abc
>>> str1 = 'abc"abc'
>>> str2 = "abc'abc"
>>> pp.sql("select {str0}, {str1}, {str2}")
abc abc"abc abc'abc
0 abc abc"abc abc'abc
>>> strs = ['a', 'b']
>>> pp.sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2")
cond1 cond2
0 True False
"""
blocks = _string.formatter_parser(self._statement)
# TODO: use a string builder
res = ""
try:
for (pre, inner, _, _) in blocks:
var_next = "" if inner is None else self._convert(inner)
res = res + pre + var_next
self._normalized_statement = res
sdf = self._session.sql(self._normalized_statement)
finally:
for v in self._temp_views:
self._session.catalog.dropTempView(v)
return DataFrame(sdf)
def _convert(self, key) -> Any:
"""
Given a {} key, returns an equivalent SQL representation.
This conversion performs all the necessary escaping so that the string
returned can be directly injected into the SQL statement.
"""
# Already cached?
if key in self._cached_vars:
return self._cached_vars[key]
# Analyze:
if key not in self._scope:
raise ValueError(
"The key {} in the SQL statement was not found in global,"
" local or parameters variables".format(key)
)
var = self._scope[key]
fillin = self._convert_var(var)
self._cached_vars[key] = fillin
return fillin
def _convert_var(self, var) -> Any:
"""
Converts a python object into a string that is legal SQL.
"""
if isinstance(var, (int, float)):
return str(var)
if isinstance(var, Series):
return self._convert_var(var.to_dataframe())
if isinstance(var, pd.DataFrame):
return self._convert_var(pp.DataFrame(var))
if isinstance(var, DataFrame):
df_id = "koalas_" + str(id(var))
if df_id not in self._temp_views:
sdf = var.to_spark()
sdf.createOrReplaceTempView(df_id)
self._temp_views[df_id] = sdf
return df_id
if isinstance(var, str):
return '"' + escape_sql_string(var) + '"'
if isinstance(var, list):
return "(" + ", ".join([self._convert_var(v) for v in var]) + ")"
if isinstance(var, (tuple, range)):
return self._convert_var(list(var))
raise ValueError("Unsupported variable type {}: {}".format(type(var).__name__, str(var)))

File diff suppressed because it is too large.


@@ -0,0 +1,18 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.pandas.typedef.typehints import * # noqa: F401,F405


@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np # noqa: F401
import pandas # noqa: F401
import pandas as pd # noqa: F401
from numpy import * # noqa: F401
from pandas import * # noqa: F401
from inspect import getfullargspec # noqa: F401
def resolve_string_type_hint(tpe):
import pyspark.pandas as pp
from pyspark.pandas import DataFrame, Series
locs = {
"pp": pp,
"koalas": pp,
"DataFrame": DataFrame,
"Series": Series,
}
# This is a hack to resolve the forward reference string.
exec("def func() -> %s: pass\narg_spec = getfullargspec(func)" % tpe, globals(), locs)
return locs["arg_spec"].annotations.get("return", None)


@@ -0,0 +1,521 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utilities to deal with types. This is mostly focused on python3.
"""
import datetime
import decimal
from inspect import getfullargspec, isclass
from typing import Generic, List, Optional, Tuple, TypeVar, Union # noqa: F401
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype, pandas_dtype
from pandas.api.extensions import ExtensionDtype
try:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
extension_dtypes_available = True
extension_dtypes = (Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype) # type: Tuple
try:
from pandas import BooleanDtype, StringDtype
extension_object_dtypes_available = True
extension_dtypes += (BooleanDtype, StringDtype)
except ImportError:
extension_object_dtypes_available = False
try:
from pandas import Float32Dtype, Float64Dtype
extension_float_dtypes_available = True
extension_dtypes += (Float32Dtype, Float64Dtype)
except ImportError:
extension_float_dtypes_available = False
except ImportError:
extension_dtypes_available = False
extension_object_dtypes_available = False
extension_float_dtypes_available = False
extension_dtypes = ()
import pyarrow as pa
import pyspark.sql.types as types
try:
from pyspark.sql.types import to_arrow_type, from_arrow_type
except ImportError:
from pyspark.sql.pandas.types import to_arrow_type, from_arrow_type
from pyspark import pandas as pp # For running doctests and reference resolution in PyCharm.
from pyspark.pandas.typedef.string_typehints import resolve_string_type_hint
T = TypeVar("T")
Scalar = Union[
int, float, bool, str, bytes, decimal.Decimal, datetime.date, datetime.datetime, None
]
Dtype = Union[np.dtype, ExtensionDtype]
# A column of data, with the data type.
class SeriesType(Generic[T]):
def __init__(self, dtype: Dtype, spark_type: types.DataType):
self.dtype = dtype
self.spark_type = spark_type
def __repr__(self):
return "SeriesType[{}]".format(self.spark_type)
class DataFrameType(object):
def __init__(
self, dtypes: List[Dtype], spark_types: List[types.DataType], names: List[Optional[str]]
):
from pyspark.pandas.utils import name_like_string
self.dtypes = dtypes
self.spark_type = types.StructType(
[
types.StructField(name_like_string(n) if n is not None else ("c%s" % i), t)
for i, (n, t) in enumerate(zip(names, spark_types))
]
) # type: types.StructType
def __repr__(self):
return "DataFrameType[{}]".format(self.spark_type)
# The type is a scalar type that is furthermore understood by Spark.
class ScalarType(object):
def __init__(self, dtype: Dtype, spark_type: types.DataType):
self.dtype = dtype
self.spark_type = spark_type
def __repr__(self):
return "ScalarType[{}]".format(self.spark_type)
# The type is left unspecified or we do not know about this type.
class UnknownType(object):
def __init__(self, tpe):
self.tpe = tpe
def __repr__(self):
return "UnknownType[{}]".format(self.tpe)
class NameTypeHolder(object):
name = None
tpe = None
def as_spark_type(tpe: Union[str, type, Dtype], *, raise_error: bool = True) -> types.DataType:
"""
Given a Python type, returns the equivalent spark type.
Accepts:
- the built-in types in Python
- the built-in types in numpy
- list of pairs of (field_name, type)
- dictionaries of field_name -> type
- Python3's typing system
"""
if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
pass
# ArrayType
elif tpe in (np.ndarray,):
return types.ArrayType(types.StringType())
elif hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, list): # type: ignore
element_type = as_spark_type(tpe.__args__[0], raise_error=raise_error) # type: ignore
if element_type is None:
return None
return types.ArrayType(element_type)
# BinaryType
elif tpe in (bytes, np.character, np.bytes_, np.string_):
return types.BinaryType()
# BooleanType
elif tpe in (bool, np.bool, "bool", "?"):
return types.BooleanType()
# DateType
elif tpe in (datetime.date,):
return types.DateType()
# NumericType
elif tpe in (np.int8, np.byte, "int8", "byte", "b"):
return types.ByteType()
elif tpe in (decimal.Decimal,):
# TODO: considering about the precision & scale for decimal type.
return types.DecimalType(38, 18)
elif tpe in (float, np.float, np.float64, "float", "float64", "double"):
return types.DoubleType()
elif tpe in (np.float32, "float32", "f"):
return types.FloatType()
elif tpe in (np.int32, "int32", "i"):
return types.IntegerType()
elif tpe in (int, np.int, np.int64, "int", "int64", "long"):
return types.LongType()
elif tpe in (np.int16, "int16", "short"):
return types.ShortType()
# StringType
elif tpe in (str, np.unicode_, "str", "U"):
return types.StringType()
# TimestampType
elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"):
return types.TimestampType()
# categorical types
elif isinstance(tpe, CategoricalDtype) or (isinstance(tpe, str) and tpe == "category"):
return types.LongType()
# extension types
elif extension_dtypes_available:
# IntegralType
if isinstance(tpe, Int8Dtype) or (isinstance(tpe, str) and tpe == "Int8"):
return types.ByteType()
elif isinstance(tpe, Int16Dtype) or (isinstance(tpe, str) and tpe == "Int16"):
return types.ShortType()
elif isinstance(tpe, Int32Dtype) or (isinstance(tpe, str) and tpe == "Int32"):
return types.IntegerType()
elif isinstance(tpe, Int64Dtype) or (isinstance(tpe, str) and tpe == "Int64"):
return types.LongType()
if extension_object_dtypes_available:
# BooleanType
if isinstance(tpe, BooleanDtype) or (isinstance(tpe, str) and tpe == "boolean"):
return types.BooleanType()
# StringType
elif isinstance(tpe, StringDtype) or (isinstance(tpe, str) and tpe == "string"):
return types.StringType()
if extension_float_dtypes_available:
# FractionalType
if isinstance(tpe, Float32Dtype) or (isinstance(tpe, str) and tpe == "Float32"):
return types.FloatType()
elif isinstance(tpe, Float64Dtype) or (isinstance(tpe, str) and tpe == "Float64"):
return types.DoubleType()
if raise_error:
raise TypeError("Type %s was not understood." % tpe)
else:
return None
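# NOTE (editor): a few representative mappings of as_spark_type, shown for
# illustration only (not exhaustive, not part of the ported code):
#
#     from typing import List
#     as_spark_type(int)        # -> LongType()
#     as_spark_type("float32")  # -> FloatType()
#     as_spark_type(bytes)      # -> BinaryType()
#     as_spark_type(List[int])  # -> ArrayType(LongType(), True)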
def spark_type_to_pandas_dtype(
spark_type: types.DataType, *, use_extension_dtypes: bool = False
) -> Dtype:
""" Return the given Spark DataType to pandas dtype. """
if use_extension_dtypes and extension_dtypes_available:
# IntegralType
if isinstance(spark_type, types.ByteType):
return Int8Dtype()
elif isinstance(spark_type, types.ShortType):
return Int16Dtype()
elif isinstance(spark_type, types.IntegerType):
return Int32Dtype()
elif isinstance(spark_type, types.LongType):
return Int64Dtype()
if extension_object_dtypes_available:
# BooleanType
if isinstance(spark_type, types.BooleanType):
return BooleanDtype()
# StringType
elif isinstance(spark_type, types.StringType):
return StringDtype()
# FractionalType
if extension_float_dtypes_available:
if isinstance(spark_type, types.FloatType):
return Float32Dtype()
elif isinstance(spark_type, types.DoubleType):
return Float64Dtype()
if isinstance(
spark_type,
(
types.DateType,
types.NullType,
types.ArrayType,
types.MapType,
types.StructType,
types.UserDefinedType,
),
):
return np.dtype("object")
elif isinstance(spark_type, types.TimestampType):
return np.dtype("datetime64[ns]")
else:
return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
def koalas_dtype(tpe) -> Tuple[Dtype, types.DataType]:
"""
Convert input into a pandas only dtype object or a numpy dtype object,
and its corresponding Spark DataType.
Parameters
----------
tpe : object to be converted
Returns
-------
tuple of np.dtype or a pandas dtype, and Spark DataType
Raises
------
TypeError if not a dtype
Examples
--------
>>> koalas_dtype(int)
(dtype('int64'), LongType)
>>> koalas_dtype(str)
(dtype('<U'), StringType)
>>> koalas_dtype(datetime.date)
(dtype('O'), DateType)
>>> koalas_dtype(datetime.datetime)
(dtype('<M8[ns]'), TimestampType)
>>> koalas_dtype(List[bool])
(dtype('O'), ArrayType(BooleanType,true))
"""
try:
dtype = pandas_dtype(tpe)
spark_type = as_spark_type(dtype)
except TypeError:
spark_type = as_spark_type(tpe)
dtype = spark_type_to_pandas_dtype(spark_type)
return dtype, spark_type
def infer_pd_series_spark_type(pser: pd.Series, dtype: Dtype) -> types.DataType:
"""Infer Spark DataType from pandas Series dtype.
:param pser: :class:`pandas.Series` to be inferred
:param dtype: the Series' dtype
:return: the inferred Spark data type
"""
if dtype == np.dtype("object"):
if len(pser) == 0 or pser.isnull().all():
return types.NullType()
elif hasattr(pser.iloc[0], "__UDT__"):
return pser.iloc[0].__UDT__
else:
return from_arrow_type(pa.Array.from_pandas(pser).type)
elif isinstance(dtype, CategoricalDtype):
# `pser` must already be converted to codes.
return as_spark_type(pser.dtype)
else:
return as_spark_type(dtype)
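# NOTE (editor): illustrative sketch, not part of the ported code. For the object
# dtype the Spark type is inferred from the data via pyarrow; otherwise it is
# derived from the dtype alone:
#
#     import numpy as np
#     import pandas as pd
#     infer_pd_series_spark_type(pd.Series(["a", "b"]), np.dtype("object"))  # -> StringType
#     infer_pd_series_spark_type(pd.Series([1, 2]), np.dtype("int64"))       # -> LongType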
def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
"""
Infer the return type from the return type annotation of the given function.
The returned type class indicates both dtypes (a pandas only dtype object
or a numpy dtype object) and its corresponding Spark DataType.
>>> def func() -> int:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtype
dtype('int64')
>>> inferred.spark_type
LongType
>>> def func() -> pp.Series[int]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtype
dtype('int64')
>>> inferred.spark_type
LongType
>>> def func() -> pp.DataFrame[np.float, str]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64'), dtype('<U')]
>>> inferred.spark_type
StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))
>>> def func() -> pp.DataFrame[np.float]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64')]
>>> inferred.spark_type
StructType(List(StructField(c0,DoubleType,true)))
>>> def func() -> 'int':
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtype
dtype('int64')
>>> inferred.spark_type
LongType
>>> def func() -> 'pp.Series[int]':
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtype
dtype('int64')
>>> inferred.spark_type
LongType
>>> def func() -> 'pp.DataFrame[np.float, str]':
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64'), dtype('<U')]
>>> inferred.spark_type
StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))
>>> def func() -> 'pp.DataFrame[np.float]':
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64')]
>>> inferred.spark_type
StructType(List(StructField(c0,DoubleType,true)))
>>> def func() -> pp.DataFrame['a': np.float, 'b': int]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64'), dtype('int64')]
>>> inferred.spark_type
StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))
>>> def func() -> "pp.DataFrame['a': np.float, 'b': int]":
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('float64'), dtype('int64')]
>>> inferred.spark_type
StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))
>>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
>>> def func() -> pp.DataFrame[pdf.dtypes]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('int64'), dtype('int64')]
>>> inferred.spark_type
StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))
>>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
>>> def func() -> pp.DataFrame[zip(pdf.columns, pdf.dtypes)]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('int64'), dtype('int64')]
>>> inferred.spark_type
StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))
>>> pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})
>>> def func() -> pp.DataFrame[zip(pdf.columns, pdf.dtypes)]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('int64'), dtype('int64')]
>>> inferred.spark_type
StructType(List(StructField((x, a),LongType,true),StructField((y, b),LongType,true)))
>>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical([3, 4, 5])})
>>> def func() -> pp.DataFrame[pdf.dtypes]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
>>> inferred.spark_type
StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))
>>> def func() -> pp.DataFrame[zip(pdf.columns, pdf.dtypes)]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtypes
[dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
>>> inferred.spark_type
StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))
>>> def func() -> pp.Series[pdf.b.dtype]:
... pass
>>> inferred = infer_return_type(func)
>>> inferred.dtype
CategoricalDtype(categories=[3, 4, 5], ordered=False)
>>> inferred.spark_type
LongType
"""
# We should re-import to make sure the class 'SeriesType' is not treated as a class
# within this module locally. See Series.__class_getitem__ which imports this class
# canonically.
from pyspark.pandas.typedef import SeriesType, NameTypeHolder
spec = getfullargspec(f)
tpe = spec.annotations.get("return", None)
if isinstance(tpe, str):
# This type hint can happen when given hints are string to avoid forward reference.
tpe = resolve_string_type_hint(tpe)
if hasattr(tpe, "__origin__") and (
tpe.__origin__ == pp.DataFrame or tpe.__origin__ == pp.Series
):
# When the Python version is lower than 3.7, unwrap it to a Tuple/SeriesType type hint.
tpe = tpe.__args__[0]
if hasattr(tpe, "__origin__") and issubclass(tpe.__origin__, SeriesType):
tpe = tpe.__args__[0]
if issubclass(tpe, NameTypeHolder):
tpe = tpe.tpe
dtype, spark_type = koalas_dtype(tpe)
return SeriesType(dtype, spark_type)
# Note that, DataFrame type hints will create a Tuple.
# Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
# Check if the name is Tuple.
name = getattr(tpe, "_name", getattr(tpe, "__name__", None))
if name == "Tuple":
tuple_type = tpe
if hasattr(tuple_type, "__tuple_params__"):
# Python 3.5.0 to 3.5.2 has '__tuple_params__' instead.
# See https://github.com/python/cpython/blob/v3.5.2/Lib/typing.py
parameters = getattr(tuple_type, "__tuple_params__")
else:
parameters = getattr(tuple_type, "__args__")
dtypes, spark_types = zip(
*(
koalas_dtype(p.tpe)
if isclass(p) and issubclass(p, NameTypeHolder)
else koalas_dtype(p)
for p in parameters
)
)
names = [
p.name if isclass(p) and issubclass(p, NameTypeHolder) else None for p in parameters
]
return DataFrameType(list(dtypes), list(spark_types), names)
types = koalas_dtype(tpe)
if types is None:
return UnknownType(tpe)
else:
return ScalarType(*types)
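
# --- Illustrative example (not part of the original file) ---
# A minimal sketch of how the inference above is typically used: annotate the
# return type of a function with the pandas-on-Spark generics and derive both
# the pandas dtypes and the Spark schema from it. It assumes `infer_return_type`
# is re-exported from `pyspark.pandas.typedef` as in Koalas; `plus_one` is a
# hypothetical function.
import numpy as np
import pandas as pd

from pyspark import pandas as pp
from pyspark.pandas.typedef import infer_return_type


def plus_one(pdf: pd.DataFrame) -> pp.DataFrame["id": int, "value": np.float]:
    return pdf + 1


if __name__ == "__main__":
    inferred = infer_return_type(plus_one)
    print(inferred.dtypes)      # [dtype('int64'), dtype('float64')]
    print(inferred.spark_type)  # StructType with `id` (LongType) and `value` (DoubleType)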


@@ -0,0 +1,269 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import importlib
import inspect
import threading
import time
from types import ModuleType
from typing import Union
import pandas as pd
from pyspark.pandas import config, namespace, sql
from pyspark.pandas.accessors import KoalasFrameMethods
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.datetimes import DatetimeMethods
from pyspark.pandas.groupby import DataFrameGroupBy, SeriesGroupBy
from pyspark.pandas.indexes.base import Index
from pyspark.pandas.indexes.category import CategoricalIndex
from pyspark.pandas.indexes.datetimes import DatetimeIndex
from pyspark.pandas.indexes.multi import MultiIndex
from pyspark.pandas.indexes.numeric import Float64Index, Int64Index
from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
from pyspark.pandas.missing.groupby import (
MissingPandasLikeDataFrameGroupBy,
MissingPandasLikeSeriesGroupBy,
)
from pyspark.pandas.missing.indexes import (
MissingPandasLikeCategoricalIndex,
MissingPandasLikeDatetimeIndex,
MissingPandasLikeIndex,
MissingPandasLikeMultiIndex,
)
from pyspark.pandas.missing.series import MissingPandasLikeSeries
from pyspark.pandas.missing.window import (
MissingPandasLikeExpanding,
MissingPandasLikeRolling,
MissingPandasLikeExpandingGroupby,
MissingPandasLikeRollingGroupby,
)
from pyspark.pandas.series import Series
from pyspark.pandas.spark.accessors import (
CachedSparkFrameMethods,
SparkFrameMethods,
SparkIndexOpsMethods,
)
from pyspark.pandas.strings import StringMethods
from pyspark.pandas.window import Expanding, ExpandingGroupby, Rolling, RollingGroupby
def attach(logger_module: Union[str, ModuleType]) -> None:
"""
Attach the usage logger.
Parameters
----------
logger_module : the module, or module name, that contains the usage logger.
The module needs to provide a `get_logger` function as an entry point of the plug-in,
returning the usage logger.
See Also
--------
usage_logger : the reference implementation of the usage logger.
"""
if isinstance(logger_module, str):
logger_module = importlib.import_module(logger_module)
logger = getattr(logger_module, "get_logger")()
modules = [config, namespace]
classes = [
DataFrame,
Series,
Index,
MultiIndex,
Int64Index,
Float64Index,
CategoricalIndex,
DatetimeIndex,
DataFrameGroupBy,
SeriesGroupBy,
DatetimeMethods,
StringMethods,
Expanding,
ExpandingGroupby,
Rolling,
RollingGroupby,
CachedSparkFrameMethods,
SparkFrameMethods,
SparkIndexOpsMethods,
KoalasFrameMethods,
]
try:
from pyspark.pandas import mlflow
modules.append(mlflow)
classes.append(mlflow.PythonModelWrapper)
except ImportError:
pass
sql._CAPTURE_SCOPES = 3 # type: ignore
modules.append(sql) # type: ignore
# Modules
for target_module in modules:
target_name = target_module.__name__.split(".")[-1]
for name in getattr(target_module, "__all__"):
func = getattr(target_module, name)
if not inspect.isfunction(func):
continue
setattr(target_module, name, _wrap_function(target_name, name, func, logger))
special_functions = set(
[
"__init__",
"__repr__",
"__str__",
"_repr_html_",
"__len__",
"__getitem__",
"__setitem__",
"__getattr__",
]
)
# Classes
for target_class in classes:
for name, func in inspect.getmembers(target_class, inspect.isfunction):
if name.startswith("_") and name not in special_functions:
continue
setattr(target_class, name, _wrap_function(target_class.__name__, name, func, logger))
for name, prop in inspect.getmembers(target_class, lambda o: isinstance(o, property)):
if name.startswith("_"):
continue
setattr(target_class, name, _wrap_property(target_class.__name__, name, prop, logger))
# Missings
for original, missing in [
(pd.DataFrame, _MissingPandasLikeDataFrame),
(pd.Series, MissingPandasLikeSeries),
(pd.Index, MissingPandasLikeIndex),
(pd.MultiIndex, MissingPandasLikeMultiIndex),
(pd.CategoricalIndex, MissingPandasLikeCategoricalIndex),
(pd.DatetimeIndex, MissingPandasLikeDatetimeIndex),
(pd.core.groupby.DataFrameGroupBy, MissingPandasLikeDataFrameGroupBy),
(pd.core.groupby.SeriesGroupBy, MissingPandasLikeSeriesGroupBy),
(pd.core.window.Expanding, MissingPandasLikeExpanding),
(pd.core.window.Rolling, MissingPandasLikeRolling),
(pd.core.window.ExpandingGroupby, MissingPandasLikeExpandingGroupby),
(pd.core.window.RollingGroupby, MissingPandasLikeRollingGroupby),
]:
for name, func in inspect.getmembers(missing, inspect.isfunction):
setattr(
missing,
name,
_wrap_missing_function(original.__name__, name, func, original, logger),
)
for name, prop in inspect.getmembers(missing, lambda o: isinstance(o, property)):
setattr(missing, name, _wrap_missing_property(original.__name__, name, prop, logger))
_local = threading.local()
def _wrap_function(class_name, function_name, func, logger):
signature = inspect.signature(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
if hasattr(_local, "logging") and _local.logging:
# no need to log since this should be an internal call.
return func(*args, **kwargs)
_local.logging = True
try:
start = time.perf_counter()
try:
res = func(*args, **kwargs)
logger.log_success(
class_name, function_name, time.perf_counter() - start, signature
)
return res
except Exception as ex:
logger.log_failure(
class_name, function_name, ex, time.perf_counter() - start, signature
)
raise
finally:
_local.logging = False
return wrapper
def _wrap_property(class_name, property_name, prop, logger):
@property
def wrapper(self):
if hasattr(_local, "logging") and _local.logging:
# no need to log since this should be an internal call.
return prop.fget(self)
_local.logging = True
try:
start = time.perf_counter()
try:
res = prop.fget(self)
logger.log_success(class_name, property_name, time.perf_counter() - start)
return res
except Exception as ex:
logger.log_failure(class_name, property_name, ex, time.perf_counter() - start)
raise
finally:
_local.logging = False
wrapper.__doc__ = prop.__doc__
if prop.fset is not None:
wrapper = wrapper.setter(_wrap_function(class_name, prop.fset.__name__, prop.fset, logger))
return wrapper
def _wrap_missing_function(class_name, function_name, func, original, logger):
if not hasattr(original, function_name):
return func
signature = inspect.signature(getattr(original, function_name))
is_deprecated = func.__name__ == "deprecated_function"
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
finally:
logger.log_missing(class_name, function_name, is_deprecated, signature)
return wrapper
def _wrap_missing_property(class_name, property_name, prop, logger):
is_deprecated = prop.fget.__name__ == "deprecated_property"
@property
def wrapper(self):
try:
return prop.fget(self)
finally:
logger.log_missing(class_name, property_name, is_deprecated)
return wrapper
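
# --- Illustrative example (not part of the original file) ---
# A minimal sketch of a custom usage-logger plug-in following the contract that
# `attach` documents above: the module passed to `attach` must expose a
# `get_logger` function returning an object with `log_success`, `log_failure`
# and `log_missing`. The class and module names below are hypothetical.
import logging


class MyUsageLogger(object):
    def __init__(self):
        self.logger = logging.getLogger("my.pandas.usage")

    def log_success(self, class_name, name, duration, signature=None):
        self.logger.info("%s.%s finished in %.3f ms", class_name, name, duration * 1000)

    def log_failure(self, class_name, name, ex, duration, signature=None):
        self.logger.warning(
            "%s.%s failed after %.3f ms: %s", class_name, name, duration * 1000, ex
        )

    def log_missing(self, class_name, name, is_deprecated=False, signature=None):
        self.logger.info("%s.%s is missing or deprecated", class_name, name)


def get_logger():
    return MyUsageLogger()


# If this code lived in a module named `my_usage_logger` (hypothetical), it
# could be attached with:
#
#   attach("my_usage_logger")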


@@ -0,0 +1,132 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The reference implementation of usage logger using the Python standard logging library.
"""
from inspect import Signature
import logging
from typing import Any, Optional
def get_logger() -> Any:
""" An entry point of the plug-in and return the usage logger. """
return KoalasUsageLogger()
def _format_signature(signature):
return (
"({})".format(", ".join([p.name for p in signature.parameters.values()]))
if signature is not None
else ""
)
class KoalasUsageLogger(object):
"""
The reference implementation of usage logger.
The usage logger needs to provide the following methods:
- log_success(self, class_name, name, duration, signature=None)
- log_failure(self, class_name, name, ex, duration, signature=None)
- log_missing(self, class_name, name, is_deprecated=False, signature=None)
"""
def __init__(self):
self.logger = logging.getLogger("pyspark.pandas.usage_logger")
def log_success(
self, class_name: str, name: str, duration: float, signature: Optional[Signature] = None
) -> None:
"""
Log that the function or property call finished successfully.
:param class_name: the target class name
:param name: the target function or property name
:param duration: the duration to finish the function or property call
:param signature: the signature if the target is a function, else None
"""
if self.logger.isEnabledFor(logging.INFO):
msg = (
"A {function} `{class_name}.{name}{signature}` was successfully finished "
"after {duration:.3f} ms."
).format(
class_name=class_name,
name=name,
signature=_format_signature(signature),
duration=duration * 1000,
function="function" if signature is not None else "property",
)
self.logger.info(msg)
def log_failure(
self,
class_name: str,
name: str,
ex: Exception,
duration: float,
signature: Optional[Signature] = None,
) -> None:
"""
Log that the function or property call failed.
:param class_name: the target class name
:param name: the target function or property name
:param ex: the exception causing the failure
:param duration: the duration until the function or property call fails
:param signature: the signature if the target is a function, else None
"""
if self.logger.isEnabledFor(logging.WARNING):
msg = (
"A {function} `{class_name}.{name}{signature}` was failed "
"after {duration:.3f} ms: {msg}"
).format(
class_name=class_name,
name=name,
signature=_format_signature(signature),
msg=str(ex),
duration=duration * 1000,
function="function" if signature is not None else "property",
)
self.logger.warning(msg)
def log_missing(
self,
class_name: str,
name: str,
is_deprecated: bool = False,
signature: Optional[Signature] = None,
) -> None:
"""
Log that a missing or deprecated function or property was called.
:param class_name: the target class name
:param name: the target function or property name
:param is_deprecated: True if the function or property is marked as deprecated
:param signature: the original function signature if the target is a function, else None
"""
if self.logger.isEnabledFor(logging.INFO):
msg = "A {deprecated} {function} `{class_name}.{name}{signature}` was called.".format(
class_name=class_name,
name=name,
signature=_format_signature(signature),
function="function" if signature is not None else "property",
deprecated="deprecated" if is_deprecated else "missing",
)
self.logger.info(msg)
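
# --- Illustrative example (not part of the original file) ---
# The reference logger above only emits records through the standard `logging`
# module, so a handler and level need to be configured to actually see them.
# The module paths below are assumptions inferred from the logger name above,
# since file names are not shown in this diff.
import logging

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    from pyspark.pandas.usage_logging import attach  # module path assumed, see note above
    attach("pyspark.pandas.usage_logger")  # module path assumed, see note above
    # From here on, successful pandas-on-Spark calls are logged at INFO and
    # failures at WARNING under the "pyspark.pandas.usage_logger" logger.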


@@ -0,0 +1,878 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Commonly used utils in Koalas.
"""
import functools
from collections import OrderedDict
from contextlib import contextmanager
from distutils.version import LooseVersion
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
import warnings
import pyarrow
import pyspark
from pyspark import sql as spark
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType
import pandas as pd
from pandas.api.types import is_list_like
# For running doctests and reference resolution in PyCharm.
from pyspark import pandas as pp # noqa: F401
from pyspark.pandas.typedef.typehints import (
as_spark_type,
extension_dtypes,
spark_type_to_pandas_dtype,
)
if TYPE_CHECKING:
# This is required in old Python 3.5 to prevent circular reference.
from pyspark.pandas.base import IndexOpsMixin # noqa: F401 (SPARK-34943)
from pyspark.pandas.frame import DataFrame # noqa: F401 (SPARK-34943)
from pyspark.pandas.internal import InternalFrame # noqa: F401 (SPARK-34943)
ERROR_MESSAGE_CANNOT_COMBINE = (
"Cannot combine the series or dataframe because it comes from a different dataframe. "
"In order to allow this operation, enable 'compute.ops_on_diff_frames' option."
)
if LooseVersion(pyspark.__version__) < LooseVersion("3.0"):
SPARK_CONF_ARROW_ENABLED = "spark.sql.execution.arrow.enabled"
else:
SPARK_CONF_ARROW_ENABLED = "spark.sql.execution.arrow.pyspark.enabled"
def same_anchor(
this: Union["DataFrame", "IndexOpsMixin", "InternalFrame"],
that: Union["DataFrame", "IndexOpsMixin", "InternalFrame"],
) -> bool:
"""
Check if the anchors of the given DataFrame or Series are the same or not.
"""
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.internal import InternalFrame
if isinstance(this, InternalFrame):
this_internal = this
else:
assert isinstance(this, (DataFrame, IndexOpsMixin)), type(this)
this_internal = this._internal
if isinstance(that, InternalFrame):
that_internal = that
else:
assert isinstance(that, (DataFrame, IndexOpsMixin)), type(that)
that_internal = that._internal
return (
this_internal.spark_frame is that_internal.spark_frame
and this_internal.index_level == that_internal.index_level
and all(
this_scol._jc.equals(that_scol._jc)
for this_scol, that_scol in zip(
this_internal.index_spark_columns, that_internal.index_spark_columns
)
)
)
def combine_frames(this, *args, how="full", preserve_order_column=False):
"""
This method combines `this` DataFrame with a `that` DataFrame or
Series that comes from a different DataFrame.
It returns a DataFrame that has the prefixes `this_` and `that_` to distinguish
the column names from both DataFrames.
It internally performs a join operation, which can be expensive in general.
So, if the `compute.ops_on_diff_frames` option is False,
this method throws an exception.
"""
from pyspark.pandas.config import get_option
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.internal import (
InternalFrame,
HIDDEN_COLUMNS,
NATURAL_ORDER_COLUMN_NAME,
SPARK_INDEX_NAME_FORMAT,
)
from pyspark.pandas.series import Series
if all(isinstance(arg, Series) for arg in args):
assert all(
same_anchor(arg, args[0]) for arg in args
), "Currently only one different DataFrame (from given Series) is supported"
assert not same_anchor(this, args[0]), "We don't need to combine. All Series are in this."
that = args[0]._kdf[list(args)]
elif len(args) == 1 and isinstance(args[0], DataFrame):
assert isinstance(args[0], DataFrame)
assert not same_anchor(
this, args[0]
), "We don't need to combine. `this` and `that` are same."
that = args[0]
else:
raise AssertionError("args should be single DataFrame or " "single/multiple Series")
if get_option("compute.ops_on_diff_frames"):
def resolve(internal, side):
rename = lambda col: "__{}_{}".format(side, col)
internal = internal.resolved_copy
sdf = internal.spark_frame
sdf = internal.spark_frame.select(
[
scol_for(sdf, col).alias(rename(col))
for col in sdf.columns
if col not in HIDDEN_COLUMNS
]
+ list(HIDDEN_COLUMNS)
)
return internal.copy(
spark_frame=sdf,
index_spark_columns=[
scol_for(sdf, rename(col)) for col in internal.index_spark_column_names
],
data_spark_columns=[
scol_for(sdf, rename(col)) for col in internal.data_spark_column_names
],
)
this_internal = resolve(this._internal, "this")
that_internal = resolve(that._internal, "that")
this_index_map = list(
zip(
this_internal.index_spark_column_names,
this_internal.index_names,
this_internal.index_dtypes,
)
)
that_index_map = list(
zip(
that_internal.index_spark_column_names,
that_internal.index_names,
that_internal.index_dtypes,
)
)
assert len(this_index_map) == len(that_index_map)
join_scols = []
merged_index_scols = []
# Note that the order of each element in index_map is guaranteed according to the index
# level.
this_and_that_index_map = list(zip(this_index_map, that_index_map))
this_sdf = this_internal.spark_frame.alias("this")
that_sdf = that_internal.spark_frame.alias("that")
# If the same named index is found, that's used.
index_column_names = []
index_use_extension_dtypes = []
for (
i,
((this_column, this_name, this_dtype), (that_column, that_name, that_dtype)),
) in enumerate(this_and_that_index_map):
if this_name == that_name:
# We should merge the Spark columns into one
# to mimic pandas' behavior.
this_scol = scol_for(this_sdf, this_column)
that_scol = scol_for(that_sdf, that_column)
join_scol = this_scol == that_scol
join_scols.append(join_scol)
column_name = SPARK_INDEX_NAME_FORMAT(i)
index_column_names.append(column_name)
index_use_extension_dtypes.append(
any(isinstance(dtype, extension_dtypes) for dtype in [this_dtype, that_dtype])
)
merged_index_scols.append(
F.when(this_scol.isNotNull(), this_scol).otherwise(that_scol).alias(column_name)
)
else:
raise ValueError("Index names must be exactly matched currently.")
assert len(join_scols) > 0, "cannot join with no overlapping index names"
joined_df = this_sdf.join(that_sdf, on=join_scols, how=how)
if preserve_order_column:
order_column = [scol_for(this_sdf, NATURAL_ORDER_COLUMN_NAME)]
else:
order_column = []
joined_df = joined_df.select(
merged_index_scols
+ [
scol_for(this_sdf, this_internal.spark_column_name_for(label))
for label in this_internal.column_labels
]
+ [
scol_for(that_sdf, that_internal.spark_column_name_for(label))
for label in that_internal.column_labels
]
+ order_column
)
index_spark_columns = [scol_for(joined_df, col) for col in index_column_names]
index_dtypes = [
spark_type_to_pandas_dtype(field.dataType, use_extension_dtypes=use_extension_dtypes)
for field, use_extension_dtypes in zip(
joined_df.select(index_spark_columns).schema, index_use_extension_dtypes
)
]
index_columns = set(index_column_names)
new_data_columns = [
col
for col in joined_df.columns
if col not in index_columns and col != NATURAL_ORDER_COLUMN_NAME
]
data_dtypes = this_internal.data_dtypes + that_internal.data_dtypes
level = max(this_internal.column_labels_level, that_internal.column_labels_level)
def fill_label(label):
if label is None:
return ([""] * (level - 1)) + [None]
else:
return ([""] * (level - len(label))) + list(label)
column_labels = [
tuple(["this"] + fill_label(label)) for label in this_internal.column_labels
] + [tuple(["that"] + fill_label(label)) for label in that_internal.column_labels]
column_label_names = (
[None] * (1 + level - this_internal.column_labels_level)
) + this_internal.column_label_names
return DataFrame(
InternalFrame(
spark_frame=joined_df,
index_spark_columns=index_spark_columns,
index_names=this_internal.index_names,
index_dtypes=index_dtypes,
column_labels=column_labels,
data_spark_columns=[scol_for(joined_df, col) for col in new_data_columns],
data_dtypes=data_dtypes,
column_label_names=column_label_names,
)
)
else:
raise ValueError(ERROR_MESSAGE_CANNOT_COMBINE)
def align_diff_frames(
resolve_func,
this: "DataFrame",
that: "DataFrame",
fillna: bool = True,
how: str = "full",
preserve_order_column: bool = False,
) -> "DataFrame":
"""
This method aligns two different DataFrames with a given `func`. Columns are resolved and
handled within the given `func`.
To use this, `compute.ops_on_diff_frames` should be True, for now.
:param resolve_func: Takes the aligned (joined) DataFrame, the column labels of the current
DataFrame, and the column labels of the other DataFrame. It returns an iterable that produces Series.
>>> from pyspark.pandas.config import set_option, reset_option
>>>
>>> set_option("compute.ops_on_diff_frames", True)
>>>
>>> kdf1 = pp.DataFrame({'a': [9, 8, 7, 6, 5, 4, 3, 2, 1]})
>>> kdf2 = pp.DataFrame({'a': [9, 8, 7, 6, 5, 4, 3, 2, 1]})
>>>
>>> def func(kdf, this_column_labels, that_column_labels):
... kdf # conceptually this is A + B.
...
... # Within this function, Series from A or B can be performed against `kdf`.
... this_label = this_column_labels[0] # this is ('a',) from kdf1.
... that_label = that_column_labels[0] # this is ('a',) from kdf2.
... new_series = (kdf[this_label] - kdf[that_label]).rename(str(this_label))
...
... # This new series will be placed in new DataFrame.
... yield (new_series, this_label)
>>>
>>>
>>> align_diff_frames(func, kdf1, kdf2).sort_index()
a
0 0
1 0
2 0
3 0
4 0
5 0
6 0
7 0
8 0
>>> reset_option("compute.ops_on_diff_frames")
:param this: a DataFrame to align
:param that: another DataFrame to align
:param fillna: If True, it fills missing values in non-common columns in both `this` and `that`.
Otherwise, it returns them as they are.
:param how: the join type. In addition, it affects how `resolve_func` resolves the column conflict.
- full: `resolve_func` should resolve only common columns from `this` and `that` DataFrames.
For instance, if `this` has columns A, B, C and `that` has B, C, D, `this_columns` and
`that_columns` in this function are B, C and B, C.
- left: `resolve_func` should resolve columns including `that`'s columns.
For instance, if `this` has columns A, B, C and `that` has B, C, D, `this_columns` is
B, C but `that_columns` are B, C, D.
- inner: Same as 'full' mode; however, internally performs inner join instead.
:return: Aligned DataFrame
"""
from pyspark.pandas.frame import DataFrame
assert how == "full" or how == "left" or how == "inner"
this_column_labels = this._internal.column_labels
that_column_labels = that._internal.column_labels
common_column_labels = set(this_column_labels).intersection(that_column_labels)
# 1. Perform the join given two dataframes.
combined = combine_frames(this, that, how=how, preserve_order_column=preserve_order_column)
# 2. Apply the given function to transform the columns in a batch and keep the new columns.
combined_column_labels = combined._internal.column_labels
that_columns_to_apply = []
this_columns_to_apply = []
additional_that_columns = []
columns_to_keep = []
column_labels_to_keep = []
for combined_label in combined_column_labels:
for common_label in common_column_labels:
if combined_label == tuple(["this", *common_label]):
this_columns_to_apply.append(combined_label)
break
elif combined_label == tuple(["that", *common_label]):
that_columns_to_apply.append(combined_label)
break
else:
if how == "left" and combined_label in [
tuple(["that", *label]) for label in that_column_labels
]:
# In this case, we will drop `that_columns` in `columns_to_keep` but pass
# it later to `func`. `func` should resolve it.
# Note that adding this into a separate list (`additional_that_columns`)
# is intentional so that `this_columns` and `that_columns` can be paired.
additional_that_columns.append(combined_label)
elif fillna:
columns_to_keep.append(F.lit(None).cast(DoubleType()).alias(str(combined_label)))
column_labels_to_keep.append(combined_label)
else:
columns_to_keep.append(combined._kser_for(combined_label))
column_labels_to_keep.append(combined_label)
that_columns_to_apply += additional_that_columns
# We should extract the columns to apply and transform them in a batch, in case
# new columns are added, for example.
if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
kser_set, column_labels_applied = zip(
*resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
)
columns_applied = list(kser_set)
column_labels_applied = list(column_labels_applied)
else:
columns_applied = []
column_labels_applied = []
applied = DataFrame(
combined._internal.with_new_columns(
columns_applied + columns_to_keep,
column_labels=column_labels_applied + column_labels_to_keep,
)
) # type: DataFrame
# 3. Restore the names back and deduplicate columns.
this_labels = OrderedDict()
# Add columns in the order of the original frame.
for this_label in this_column_labels:
for new_label in applied._internal.column_labels:
if new_label[1:] not in this_labels and this_label == new_label[1:]:
this_labels[new_label[1:]] = new_label
# After that, we will add the remaining columns.
other_labels = OrderedDict()
for new_label in applied._internal.column_labels:
if new_label[1:] not in this_labels:
other_labels[new_label[1:]] = new_label
kdf = applied[list(this_labels.values()) + list(other_labels.values())]
kdf.columns = kdf.columns.droplevel()
return kdf
def is_testing():
""" Indicates whether Koalas is currently running tests. """
return "KOALAS_TESTING" in os.environ
def default_session(conf=None):
if conf is None:
conf = dict()
should_use_legacy_ipc = False
if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15") and LooseVersion(
pyspark.__version__
) < LooseVersion("3.0"):
conf["spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
conf["spark.yarn.appMasterEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
conf["spark.mesos.driverEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
conf["spark.kubernetes.driverEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
should_use_legacy_ipc = True
builder = spark.SparkSession.builder.appName("Koalas")
for key, value in conf.items():
builder = builder.config(key, value)
# Currently, Koalas depends on such joins due to the 'compute.ops_on_diff_frames'
# configuration. This is needed for Spark 3.0+.
builder.config("spark.sql.analyzer.failAmbiguousSelfJoin", False)
if LooseVersion(pyspark.__version__) >= LooseVersion("3.0.1") and is_testing():
builder.config("spark.executor.allowSparkContext", False)
session = builder.getOrCreate()
if not should_use_legacy_ipc:
is_legacy_ipc_set = any(
v == "1"
for v in [
session.conf.get("spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
session.conf.get("spark.yarn.appMasterEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
session.conf.get("spark.mesos.driverEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
session.conf.get("spark.kubernetes.driverEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
]
)
if is_legacy_ipc_set:
raise RuntimeError(
"Please explicitly unset 'ARROW_PRE_0_15_IPC_FORMAT' environment variable in "
"both driver and executor sides. Check your spark.executorEnv.*, "
"spark.yarn.appMasterEnv.*, spark.mesos.driverEnv.* and "
"spark.kubernetes.driverEnv.* configurations. It is required to set this "
"environment variable only when you use pyarrow>=0.15 and pyspark<3.0."
)
return session
@contextmanager
def sql_conf(pairs, *, spark=None):
"""
A convenient context manager that sets each given Spark SQL configuration `key` to its `value`
and restores the previous value when it exits.
"""
assert isinstance(pairs, dict), "pairs should be a dictionary."
if spark is None:
spark = default_session()
keys = pairs.keys()
new_values = pairs.values()
old_values = [spark.conf.get(key, None) for key in keys]
for key, new_value in zip(keys, new_values):
spark.conf.set(key, new_value)
try:
yield
finally:
for key, old_value in zip(keys, old_values):
if old_value is None:
spark.conf.unset(key)
else:
spark.conf.set(key, old_value)
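
# --- Illustrative example (not part of the original file) ---
# A minimal sketch of `sql_conf` above, using only names defined earlier in
# this file: the configuration is changed inside the `with` block only and the
# previous value is restored (or unset) on exit.
if __name__ == "__main__":
    session = default_session()
    with sql_conf({SPARK_CONF_ARROW_ENABLED: False}, spark=session):
        print(session.conf.get(SPARK_CONF_ARROW_ENABLED))  # 'false' inside the block
    print(session.conf.get(SPARK_CONF_ARROW_ENABLED))  # previous value restored afterwards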
def validate_arguments_and_invoke_function(
pobj: Union[pd.DataFrame, pd.Series],
koalas_func: Callable,
pandas_func: Callable,
input_args: Dict,
):
"""
Invokes a pandas function.
This is created because different versions of pandas support different parameters, and as a
result when we code against the latest version, our users might get a confusing
"got an unexpected keyword argument" error if they are using an older version of pandas.
This function validates all the arguments and removes the ones that are not supported if they
are simply set to the default value (i.e. most likely the user didn't explicitly specify them). It
throws a TypeError if the user explicitly specifies an argument that is not supported by the
pandas version available.
For example usage, look at DataFrame.to_html().
:param pobj: the pandas DataFrame or Series to operate on
:param koalas_func: Koalas function, used to get default parameter values
:param pandas_func: pandas function, used to check whether pandas supports all the arguments
:param input_args: arguments to pass to the pandas function, often created by using locals().
Make sure locals() call is at the top of the function so it captures only
input parameters, rather than local variables.
:return: whatever pandas_func returns
"""
import inspect
# Makes a copy since whatever passed in is likely created by locals(), and we can't delete
# 'self' key from that.
args = input_args.copy()
del args["self"]
if "kwargs" in args:
# explode kwargs
kwargs = args["kwargs"]
del args["kwargs"]
args = {**args, **kwargs}
koalas_params = inspect.signature(koalas_func).parameters
pandas_params = inspect.signature(pandas_func).parameters
for param in koalas_params.values():
if param.name not in pandas_params:
if args[param.name] == param.default:
del args[param.name]
else:
raise TypeError(
(
"The pandas version [%s] available does not support parameter '%s' "
+ "for function '%s'."
)
% (pd.__version__, param.name, pandas_func.__name__)
)
args["self"] = pobj
return pandas_func(**args)
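
# --- Illustrative sketch (not part of the original file) ---
# How a pandas-on-Spark method typically delegates to pandas through the helper
# above: capture `locals()` first (so only input parameters are included),
# convert to pandas, then validate and invoke. `SimpleFrame` is hypothetical
# and greatly simplified compared to the real DataFrame.to_html.
class SimpleFrame(object):
    def __init__(self, pdf: pd.DataFrame):
        self._pdf = pdf

    def to_html(self, buf=None, max_rows=None):
        args = locals()  # captured before any other local variables are defined
        pdf = self._pdf
        return validate_arguments_and_invoke_function(
            pdf, SimpleFrame.to_html, pd.DataFrame.to_html, args
        )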
def lazy_property(fn):
"""
Decorator that makes a property lazy-evaluated.
Copied from https://stevenloria.com/lazy-properties/
"""
attr_name = "_lazy_" + fn.__name__
@property
@functools.wraps(fn)
def wrapped_lazy_property(self):
if not hasattr(self, attr_name):
setattr(self, attr_name, fn(self))
return getattr(self, attr_name)
def deleter(self):
if hasattr(self, attr_name):
delattr(self, attr_name)
return wrapped_lazy_property.deleter(deleter)
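
# --- Illustrative sketch (not part of the original file) ---
# `lazy_property` computes the value on first access, caches it on the instance,
# and recomputes it after the attribute is deleted. `Expensive` is hypothetical.
class Expensive(object):
    @lazy_property
    def answer(self):
        print("computing...")
        return 42

# >>> e = Expensive()
# >>> e.answer        # prints "computing..." and returns 42
# >>> e.answer        # cached; returns 42 without recomputing
# >>> del e.answer    # drops the cache so the next access recomputes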
def scol_for(sdf: spark.DataFrame, column_name: str) -> spark.Column:
""" Return Spark Column for the given column name. """
return sdf["`{}`".format(column_name)]
def column_labels_level(column_labels: List[Tuple]) -> int:
""" Return the level of the column index. """
if len(column_labels) == 0:
return 1
else:
levels = set(1 if label is None else len(label) for label in column_labels)
assert len(levels) == 1, levels
return list(levels)[0]
def name_like_string(name: Optional[Union[str, Tuple]]) -> str:
"""
Return a name-like string from a str or a tuple of str.
Examples
--------
>>> name = 'abc'
>>> name_like_string(name)
'abc'
>>> name = ('abc',)
>>> name_like_string(name)
'abc'
>>> name = ('a', 'b', 'c')
>>> name_like_string(name)
'(a, b, c)'
"""
if name is None:
name = ("__none__",)
elif is_list_like(name):
name = tuple([str(n) for n in name])
else:
name = (str(name),)
return ("(%s)" % ", ".join(name)) if len(name) > 1 else name[0]
def is_name_like_tuple(value: Any, allow_none: bool = True, check_type: bool = False) -> bool:
"""
Check whether the given tuple can be used as a name.
Examples
--------
>>> is_name_like_tuple(('abc',))
True
>>> is_name_like_tuple((1,))
True
>>> is_name_like_tuple(('abc', 1, None))
True
>>> is_name_like_tuple(('abc', 1, None), check_type=True)
True
>>> is_name_like_tuple((1.0j,))
True
>>> is_name_like_tuple(tuple())
False
>>> is_name_like_tuple((list('abc'),))
False
>>> is_name_like_tuple(('abc', 1, None), allow_none=False)
False
>>> is_name_like_tuple((1.0j,), check_type=True)
False
"""
if value is None:
return allow_none
elif not isinstance(value, tuple):
return False
elif len(value) == 0:
return False
elif not allow_none and any(v is None for v in value):
return False
elif any(is_list_like(v) or isinstance(v, slice) for v in value):
return False
elif check_type:
return all(
v is None or as_spark_type(type(v), raise_error=False) is not None for v in value
)
else:
return True
def is_name_like_value(
value: Any, allow_none: bool = True, allow_tuple: bool = True, check_type: bool = False
) -> bool:
"""
Check whether the given value can be used as a name.
Examples
--------
>>> is_name_like_value('abc')
True
>>> is_name_like_value(1)
True
>>> is_name_like_value(None)
True
>>> is_name_like_value(('abc',))
True
>>> is_name_like_value(1.0j)
True
>>> is_name_like_value(list('abc'))
False
>>> is_name_like_value(None, allow_none=False)
False
>>> is_name_like_value(('abc',), allow_tuple=False)
False
>>> is_name_like_value(1.0j, check_type=True)
False
"""
if value is None:
return allow_none
elif isinstance(value, tuple):
return allow_tuple and is_name_like_tuple(
value, allow_none=allow_none, check_type=check_type
)
elif is_list_like(value) or isinstance(value, slice):
return False
elif check_type:
return as_spark_type(type(value), raise_error=False) is not None
else:
return True
def validate_axis(axis=0, none_axis=0):
""" Check the given axis is valid. """
# convert to numeric axis
axis = {None: none_axis, "index": 0, "columns": 1}.get(axis, axis)
if axis not in (none_axis, 0, 1):
raise ValueError("No axis named {0}".format(axis))
return axis
def validate_bool_kwarg(value, arg_name):
""" Ensures that argument passed in arg_name is of type bool. """
if not (isinstance(value, bool) or value is None):
raise ValueError(
'For argument "{}" expected type bool, received '
"type {}.".format(arg_name, type(value).__name__)
)
return value
def validate_how(how: str) -> str:
""" Check the given how for join is valid. """
if how == "full":
warnings.warn(
"Warning: While Koalas will accept 'full', you should use 'outer' "
+ "instead to be compatible with the pandas merge API",
UserWarning,
)
if how == "outer":
# 'outer' in pandas equals 'full' in Spark
how = "full"
if how not in ("inner", "left", "right", "full"):
raise ValueError(
"The 'how' parameter has to be amongst the following values: ",
"['inner', 'left', 'right', 'outer']",
)
return how
def verify_temp_column_name(
df: Union["DataFrame", spark.DataFrame], column_name_or_label: Union[Any, Tuple]
) -> Union[Any, Tuple]:
"""
Verify that the given column name does not exist in the given Koalas or Spark DataFrame.
The temporary column names should start and end with `__`. In addition, `column_name_or_label`
expects a single string, or column labels when `df` is a Koalas DataFrame.
>>> kdf = pp.DataFrame({("x", "a"): ['a', 'b', 'c']})
>>> kdf["__dummy__"] = 0
>>> kdf[("", "__dummy__")] = 1
>>> kdf # doctest: +NORMALIZE_WHITESPACE
x __dummy__
a __dummy__
0 a 0 1
1 b 0 1
2 c 0 1
>>> verify_temp_column_name(kdf, '__tmp__')
('__tmp__', '')
>>> verify_temp_column_name(kdf, ('', '__tmp__'))
('', '__tmp__')
>>> verify_temp_column_name(kdf, '__dummy__')
Traceback (most recent call last):
...
AssertionError: ... `(__dummy__, )` ...
>>> verify_temp_column_name(kdf, ('', '__dummy__'))
Traceback (most recent call last):
...
AssertionError: ... `(, __dummy__)` ...
>>> verify_temp_column_name(kdf, 'dummy')
Traceback (most recent call last):
...
AssertionError: ... should be empty or start and end with `__`: ('dummy', '')
>>> verify_temp_column_name(kdf, ('', 'dummy'))
Traceback (most recent call last):
...
AssertionError: ... should be empty or start and end with `__`: ('', 'dummy')
>>> internal = kdf._internal.resolved_copy
>>> sdf = internal.spark_frame
>>> sdf.select(internal.data_spark_columns).show() # doctest: +NORMALIZE_WHITESPACE
+------+---------+-------------+
|(x, a)|__dummy__|(, __dummy__)|
+------+---------+-------------+
| a| 0| 1|
| b| 0| 1|
| c| 0| 1|
+------+---------+-------------+
>>> verify_temp_column_name(sdf, '__tmp__')
'__tmp__'
>>> verify_temp_column_name(sdf, '__dummy__')
Traceback (most recent call last):
...
AssertionError: ... `__dummy__` ... '(x, a)', '__dummy__', '(, __dummy__)', ...
>>> verify_temp_column_name(sdf, ('', '__dummy__'))
Traceback (most recent call last):
...
AssertionError: <class 'tuple'>
>>> verify_temp_column_name(sdf, 'dummy')
Traceback (most recent call last):
...
AssertionError: ... should start and end with `__`: dummy
"""
from pyspark.pandas.frame import DataFrame
if isinstance(df, DataFrame):
if isinstance(column_name_or_label, str):
column_name = column_name_or_label
level = df._internal.column_labels_level
column_name_or_label = tuple([column_name_or_label] + ([""] * (level - 1)))
else:
column_name = name_like_string(column_name_or_label)
assert any(len(label) > 0 for label in column_name_or_label) and all(
label == "" or (label.startswith("__") and label.endswith("__"))
for label in column_name_or_label
), "The temporary column name should be empty or start and end with `__`: {}".format(
column_name_or_label
)
assert all(
column_name_or_label != label for label in df._internal.column_labels
), "The given column name `{}` already exists in the Koalas DataFrame: {}".format(
name_like_string(column_name_or_label), df.columns
)
df = df._internal.resolved_copy.spark_frame
else:
assert isinstance(column_name_or_label, str), type(column_name_or_label)
assert column_name_or_label.startswith("__") and column_name_or_label.endswith(
"__"
), "The temporary column name should start and end with `__`: {}".format(
column_name_or_label
)
column_name = column_name_or_label
assert isinstance(df, spark.DataFrame), type(df)
assert (
column_name not in df.columns
), "The given column name `{}` already exists in the Spark DataFrame: {}".format(
column_name, df.columns
)
return column_name_or_label
def compare_null_first(left, right, comp):
return (left.isNotNull() & right.isNotNull() & comp(left, right)) | (
left.isNull() & right.isNotNull()
)
def compare_null_last(left, right, comp):
return (left.isNotNull() & right.isNotNull() & comp(left, right)) | (
left.isNotNull() & right.isNull()
)
def compare_disallow_null(left, right, comp):
return left.isNotNull() & right.isNotNull() & comp(left, right)
def compare_allow_null(left, right, comp):
return left.isNull() | right.isNull() | comp(left, right)
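
# --- Illustrative sketch (not part of the original file) ---
# The four comparators above build Spark Column predicates with explicit null
# handling, e.g. for null-first/null-last orderings. The data below is
# hypothetical; `F` and `spark` (the pyspark.sql module alias) are imported at
# the top of this file.
if __name__ == "__main__":
    sdf = default_session().createDataFrame([(1,), (None,)], "v int")
    left, right = F.col("v"), F.lit(0)
    sdf.select(
        compare_null_first(left, right, spark.Column.__lt__).alias("null_first_lt"),
        compare_disallow_null(left, right, spark.Column.__lt__).alias("strict_lt"),
    ).show()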


@@ -0,0 +1,18 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__version__ = "1.7.0"

File diff suppressed because it is too large.