[SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in dataframe

## What changes were proposed in this pull request?

This PR proposes to support unicode strings in `Param` methods in ML and in the `DataFrame` functions that were previously missed.

For example, in Python 2.x this raises a `TypeError` when the param name is a unicode string:

```python
>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> lr.hasParam("threshold")
True
>>> lr.hasParam(u"threshold")
Traceback (most recent call last):
 ...
    raise TypeError("hasParam(): paramName must be a string")
TypeError: hasParam(): paramName must be a string
```
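The root cause is that these checks use `isinstance(..., str)`, which rejects `unicode` values on Python 2; the patch switches them to `basestring`. Below is a minimal standalone sketch of that pattern (the `is_string_like` helper is illustrative, not part of the patch; it assumes the usual `basestring = str` alias for Python 3, where the builtin no longer exists):

```python
import sys

# Python 3 has no `basestring`; alias it so a single isinstance check covers
# str on Python 3, and both str and unicode on Python 2.
if sys.version_info[0] >= 3:
    basestring = str


def is_string_like(name):
    # Hypothetical helper mirroring the checks changed in this PR.
    return isinstance(name, basestring)


print(is_string_like("threshold"))    # True
print(is_string_like(u"threshold"))   # True (this was the failing case with a str-only check)
print(is_string_like(42))             # False
```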

This PR is based on https://github.com/apache/spark/pull/13036

## How was this patch tested?

Unit tests in `python/pyspark/ml/tests.py` and `python/pyspark/sql/tests.py`.
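Individual cases can also be exercised with the standard `unittest` runner; a sketch, assuming a working PySpark environment with `$SPARK_HOME/python` on `PYTHONPATH`:

```python
import unittest

# ParamTests is the test class touched by this patch; test_resolveparam is the
# new case it adds (a local SparkContext is set up by the test base class).
from pyspark.ml.tests import ParamTests

suite = unittest.TestSuite()
suite.addTest(ParamTests("test_resolveparam"))
unittest.TextTestRunner(verbosity=2).run(suite)
```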

Author: hyukjinkwon <gurwls223@gmail.com>
Author: sethah <seth.hendrickson16@gmail.com>

Closes #17096 from HyukjinKwon/SPARK-15243.
Authored by hyukjinkwon on 2017-09-08 11:57:33 -07:00, committed by Holden Karau.
Commit 8598d03a00 (parent 8a4f228dc0). 4 changed files with 42 additions and 21 deletions.

python/pyspark/ml/param/__init__.py:

```diff
@@ -330,7 +330,7 @@ class Params(Identifiable):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        if isinstance(paramName, str):
+        if isinstance(paramName, basestring):
             p = getattr(self, paramName, None)
             return isinstance(p, Param)
         else:
@@ -413,7 +413,7 @@ class Params(Identifiable):
         if isinstance(param, Param):
             self._shouldOwn(param)
             return param
-        elif isinstance(param, str):
+        elif isinstance(param, basestring):
             return self.getParam(param)
         else:
             raise ValueError("Cannot resolve %r as a param." % param)
```
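With `basestring` accepted in `hasParam` and `_resolveParam`, the failing example from the description is expected to behave like this on Python 2 as well (a doctest-style sketch):

```python
>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> lr.hasParam(u"threshold")   # previously raised TypeError on Python 2
True
>>> lr.hasParam(u"notAParameter")
False
```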

python/pyspark/ml/tests.py:

```diff
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements. See the NOTICE file distributed with
@@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase):
         testParams = TestParams()
         self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
         self.assertFalse(testParams.hasParam("notAParameter"))
+        self.assertTrue(testParams.hasParam(u"maxIter"))
+
+    def test_resolveparam(self):
+        testParams = TestParams()
+        self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
+        self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
+
+        self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
+        if sys.version_info[0] >= 3:
+            # In Python 3, it is allowed to get/set attributes with non-ascii characters.
+            e_cls = AttributeError
+        else:
+            e_cls = UnicodeEncodeError
+        self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
 
     def test_params(self):
         testParams = TestParams()
```
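The version check in `test_resolveparam` exists because a non-ASCII unicode param name fails differently on the two interpreters, and the new `# -*- coding: utf-8 -*-` header is what allows the non-ASCII literal in the test file at all. A standalone sketch of the underlying `getattr` behavior (the `Dummy` class is illustrative only):

```python
# -*- coding: utf-8 -*-
import sys


class Dummy(object):
    pass


# Python 2: getattr() encodes a unicode attribute name with the ASCII codec
# first, so a non-ASCII name raises UnicodeEncodeError before any lookup.
# Python 3: attribute names are already str, so the lookup itself fails
# with AttributeError.
expected = AttributeError if sys.version_info[0] >= 3 else UnicodeEncodeError
try:
    getattr(Dummy(), u"아")
except expected:
    print("raised %s, as the test expects" % expected.__name__)
```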

python/pyspark/sql/dataframe.py:

```diff
@@ -748,7 +748,7 @@ class DataFrame(object):
         +---+-----+
 
         """
-        if not isinstance(col, str):
+        if not isinstance(col, basestring):
             raise ValueError("col must be a string, but got %r" % type(col))
         if not isinstance(fractions, dict):
             raise ValueError("fractions must be a dict but got %r" % type(fractions))
@@ -1664,18 +1664,18 @@ class DataFrame(object):
            Added support for multiple columns.
         """
-        if not isinstance(col, (str, list, tuple)):
+        if not isinstance(col, (basestring, list, tuple)):
             raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
 
-        isStr = isinstance(col, str)
+        isStr = isinstance(col, basestring)
 
         if isinstance(col, tuple):
             col = list(col)
-        elif isinstance(col, str):
+        elif isStr:
             col = [col]
 
         for c in col:
-            if not isinstance(c, str):
+            if not isinstance(c, basestring):
                 raise ValueError("columns should be strings, but got %r" % type(c))
         col = _to_list(self._sc, col)
@@ -1707,9 +1707,9 @@ class DataFrame(object):
         :param col2: The name of the second column
         :param method: The correlation method. Currently only supports "pearson"
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         if not method:
             method = "pearson"
@@ -1727,9 +1727,9 @@ class DataFrame(object):
         :param col1: The name of the first column
         :param col2: The name of the second column
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
@@ -1749,9 +1749,9 @@ class DataFrame(object):
         :param col2: The name of the second column. Distinct items will make the column names
             of the DataFrame.
         """
-        if not isinstance(col1, str):
+        if not isinstance(col1, basestring):
             raise ValueError("col1 should be a string.")
-        if not isinstance(col2, str):
+        if not isinstance(col2, basestring):
             raise ValueError("col2 should be a string.")
         return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
```
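After the `DataFrame.stat` changes above, unicode column names should be accepted anywhere a plain `str` name was; a small usage sketch, assuming an active `SparkSession` bound to `spark` (the data and column names are illustrative):

```python
# Build a tiny DataFrame with two numeric columns; `spark` is assumed to exist.
df = spark.createDataFrame([(i, 2 * i) for i in range(10)], ["a", "b"])

print(df.stat.corr(u"a", u"b"))                  # 1.0 for an exactly linear pair
print(df.stat.cov(u"a", "b"))                    # sample covariance as a float
print(df.stat.crosstab(u"a", u"b").count())      # row count of the contingency table
print(df.stat.approxQuantile(u"a", [0.5], 0.1))  # approximate median, e.g. [4.0]
print(df.stat.sampleBy(u"b", {0: 0.5, 2: 0.5}, seed=0).count())  # stratified sample
```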

python/pyspark/sql/tests.py:

```diff
@@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase):
     def test_approxQuantile(self):
         df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
-        aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
-        self.assertTrue(isinstance(aq, list))
-        self.assertEqual(len(aq), 3)
-        self.assertTrue(all(isinstance(q, float) for q in aq))
-        aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
+        for f in ["a", u"a"]:
+            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
+            self.assertTrue(isinstance(aq, list))
+            self.assertEqual(len(aq), 3)
+            self.assertTrue(all(isinstance(q, float) for q in aq))
+        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aqs, list))
         self.assertEqual(len(aqs), 2)
         self.assertTrue(isinstance(aqs[0], list))
@@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(isinstance(aqs[1], list))
         self.assertEqual(len(aqs[1]), 3)
         self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
-        aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
+        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aqt, list))
         self.assertEqual(len(aqt), 2)
         self.assertTrue(isinstance(aqt[0], list))
@@ -1169,17 +1170,22 @@
     def test_corr(self):
         import math
         df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
-        corr = df.stat.corr("a", "b")
+        corr = df.stat.corr(u"a", "b")
         self.assertTrue(abs(corr - 0.95734012) < 1e-6)
 
+    def test_sampleby(self):
+        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
+        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
+        self.assertTrue(sampled.count() == 3)
+
     def test_cov(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
-        cov = df.stat.cov("a", "b")
+        cov = df.stat.cov(u"a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
     def test_crosstab(self):
         df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
-        ct = df.stat.crosstab("a", "b").collect()
+        ct = df.stat.crosstab(u"a", "b").collect()
         ct = sorted(ct, key=lambda x: x[0])
         for i, row in enumerate(ct):
             self.assertEqual(row[0], str(i))
```