[SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in dataframe
## What changes were proposed in this pull request? This PR proposes to support unicode strings in Param methods in ML, as well as in other missing functions in DataFrame. For example, this causes a `TypeError` in Python 2.x when the param is a unicode string: ```python >>> from pyspark.ml.classification import LogisticRegression >>> lr = LogisticRegression() >>> lr.hasParam("threshold") True >>> lr.hasParam(u"threshold") Traceback (most recent call last): ... raise TypeError("hasParam(): paramName must be a string") TypeError: hasParam(): paramName must be a string ``` This PR is based on https://github.com/apache/spark/pull/13036 ## How was this patch tested? Unit tests in `python/pyspark/ml/tests.py` and `python/pyspark/sql/tests.py`. Author: hyukjinkwon <gurwls223@gmail.com> Author: sethah <seth.hendrickson16@gmail.com> Closes #17096 from HyukjinKwon/SPARK-15243.
This commit is contained in:
parent
8a4f228dc0
commit
8598d03a00
|
@ -330,7 +330,7 @@ class Params(Identifiable):
|
||||||
Tests whether this instance contains a param with a given
|
Tests whether this instance contains a param with a given
|
||||||
(string) name.
|
(string) name.
|
||||||
"""
|
"""
|
||||||
if isinstance(paramName, str):
|
if isinstance(paramName, basestring):
|
||||||
p = getattr(self, paramName, None)
|
p = getattr(self, paramName, None)
|
||||||
return isinstance(p, Param)
|
return isinstance(p, Param)
|
||||||
else:
|
else:
|
||||||
|
@ -413,7 +413,7 @@ class Params(Identifiable):
|
||||||
if isinstance(param, Param):
|
if isinstance(param, Param):
|
||||||
self._shouldOwn(param)
|
self._shouldOwn(param)
|
||||||
return param
|
return param
|
||||||
elif isinstance(param, str):
|
elif isinstance(param, basestring):
|
||||||
return self.getParam(param)
|
return self.getParam(param)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Cannot resolve %r as a param." % param)
|
raise ValueError("Cannot resolve %r as a param." % param)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
#
|
#
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
# contributor license agreements. See the NOTICE file distributed with
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase):
|
||||||
testParams = TestParams()
|
testParams = TestParams()
|
||||||
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
|
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
|
||||||
self.assertFalse(testParams.hasParam("notAParameter"))
|
self.assertFalse(testParams.hasParam("notAParameter"))
|
||||||
|
self.assertTrue(testParams.hasParam(u"maxIter"))
|
||||||
|
|
||||||
|
def test_resolveparam(self):
|
||||||
|
testParams = TestParams()
|
||||||
|
self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
|
||||||
|
self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
|
||||||
|
|
||||||
|
self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
|
||||||
|
if sys.version_info[0] >= 3:
|
||||||
|
# In Python 3, it is allowed to get/set attributes with non-ascii characters.
|
||||||
|
e_cls = AttributeError
|
||||||
|
else:
|
||||||
|
e_cls = UnicodeEncodeError
|
||||||
|
self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
|
||||||
|
|
||||||
def test_params(self):
|
def test_params(self):
|
||||||
testParams = TestParams()
|
testParams = TestParams()
|
||||||
|
|
|
@ -748,7 +748,7 @@ class DataFrame(object):
|
||||||
+---+-----+
|
+---+-----+
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(col, str):
|
if not isinstance(col, basestring):
|
||||||
raise ValueError("col must be a string, but got %r" % type(col))
|
raise ValueError("col must be a string, but got %r" % type(col))
|
||||||
if not isinstance(fractions, dict):
|
if not isinstance(fractions, dict):
|
||||||
raise ValueError("fractions must be a dict but got %r" % type(fractions))
|
raise ValueError("fractions must be a dict but got %r" % type(fractions))
|
||||||
|
@ -1664,18 +1664,18 @@ class DataFrame(object):
|
||||||
Added support for multiple columns.
|
Added support for multiple columns.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(col, (str, list, tuple)):
|
if not isinstance(col, (basestring, list, tuple)):
|
||||||
raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
|
raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
|
||||||
|
|
||||||
isStr = isinstance(col, str)
|
isStr = isinstance(col, basestring)
|
||||||
|
|
||||||
if isinstance(col, tuple):
|
if isinstance(col, tuple):
|
||||||
col = list(col)
|
col = list(col)
|
||||||
elif isinstance(col, str):
|
elif isStr:
|
||||||
col = [col]
|
col = [col]
|
||||||
|
|
||||||
for c in col:
|
for c in col:
|
||||||
if not isinstance(c, str):
|
if not isinstance(c, basestring):
|
||||||
raise ValueError("columns should be strings, but got %r" % type(c))
|
raise ValueError("columns should be strings, but got %r" % type(c))
|
||||||
col = _to_list(self._sc, col)
|
col = _to_list(self._sc, col)
|
||||||
|
|
||||||
|
@ -1707,9 +1707,9 @@ class DataFrame(object):
|
||||||
:param col2: The name of the second column
|
:param col2: The name of the second column
|
||||||
:param method: The correlation method. Currently only supports "pearson"
|
:param method: The correlation method. Currently only supports "pearson"
|
||||||
"""
|
"""
|
||||||
if not isinstance(col1, str):
|
if not isinstance(col1, basestring):
|
||||||
raise ValueError("col1 should be a string.")
|
raise ValueError("col1 should be a string.")
|
||||||
if not isinstance(col2, str):
|
if not isinstance(col2, basestring):
|
||||||
raise ValueError("col2 should be a string.")
|
raise ValueError("col2 should be a string.")
|
||||||
if not method:
|
if not method:
|
||||||
method = "pearson"
|
method = "pearson"
|
||||||
|
@ -1727,9 +1727,9 @@ class DataFrame(object):
|
||||||
:param col1: The name of the first column
|
:param col1: The name of the first column
|
||||||
:param col2: The name of the second column
|
:param col2: The name of the second column
|
||||||
"""
|
"""
|
||||||
if not isinstance(col1, str):
|
if not isinstance(col1, basestring):
|
||||||
raise ValueError("col1 should be a string.")
|
raise ValueError("col1 should be a string.")
|
||||||
if not isinstance(col2, str):
|
if not isinstance(col2, basestring):
|
||||||
raise ValueError("col2 should be a string.")
|
raise ValueError("col2 should be a string.")
|
||||||
return self._jdf.stat().cov(col1, col2)
|
return self._jdf.stat().cov(col1, col2)
|
||||||
|
|
||||||
|
@ -1749,9 +1749,9 @@ class DataFrame(object):
|
||||||
:param col2: The name of the second column. Distinct items will make the column names
|
:param col2: The name of the second column. Distinct items will make the column names
|
||||||
of the DataFrame.
|
of the DataFrame.
|
||||||
"""
|
"""
|
||||||
if not isinstance(col1, str):
|
if not isinstance(col1, basestring):
|
||||||
raise ValueError("col1 should be a string.")
|
raise ValueError("col1 should be a string.")
|
||||||
if not isinstance(col2, str):
|
if not isinstance(col2, basestring):
|
||||||
raise ValueError("col2 should be a string.")
|
raise ValueError("col2 should be a string.")
|
||||||
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
|
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
|
||||||
|
|
||||||
|
|
|
@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
|
|
||||||
def test_approxQuantile(self):
|
def test_approxQuantile(self):
|
||||||
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
|
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
|
||||||
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
|
for f in ["a", u"a"]:
|
||||||
self.assertTrue(isinstance(aq, list))
|
aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
|
||||||
self.assertEqual(len(aq), 3)
|
self.assertTrue(isinstance(aq, list))
|
||||||
|
self.assertEqual(len(aq), 3)
|
||||||
self.assertTrue(all(isinstance(q, float) for q in aq))
|
self.assertTrue(all(isinstance(q, float) for q in aq))
|
||||||
aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
|
aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
|
||||||
self.assertTrue(isinstance(aqs, list))
|
self.assertTrue(isinstance(aqs, list))
|
||||||
self.assertEqual(len(aqs), 2)
|
self.assertEqual(len(aqs), 2)
|
||||||
self.assertTrue(isinstance(aqs[0], list))
|
self.assertTrue(isinstance(aqs[0], list))
|
||||||
|
@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
self.assertTrue(isinstance(aqs[1], list))
|
self.assertTrue(isinstance(aqs[1], list))
|
||||||
self.assertEqual(len(aqs[1]), 3)
|
self.assertEqual(len(aqs[1]), 3)
|
||||||
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
|
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
|
||||||
aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
|
aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
|
||||||
self.assertTrue(isinstance(aqt, list))
|
self.assertTrue(isinstance(aqt, list))
|
||||||
self.assertEqual(len(aqt), 2)
|
self.assertEqual(len(aqt), 2)
|
||||||
self.assertTrue(isinstance(aqt[0], list))
|
self.assertTrue(isinstance(aqt[0], list))
|
||||||
|
@ -1169,17 +1170,22 @@ class SQLTests(ReusedPySparkTestCase):
|
||||||
def test_corr(self):
|
def test_corr(self):
|
||||||
import math
|
import math
|
||||||
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
||||||
corr = df.stat.corr("a", "b")
|
corr = df.stat.corr(u"a", "b")
|
||||||
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
||||||
|
|
||||||
|
def test_sampleby(self):
|
||||||
|
df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
|
||||||
|
sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
|
||||||
|
self.assertTrue(sampled.count() == 3)
|
||||||
|
|
||||||
def test_cov(self):
|
def test_cov(self):
|
||||||
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
||||||
cov = df.stat.cov("a", "b")
|
cov = df.stat.cov(u"a", "b")
|
||||||
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
||||||
|
|
||||||
def test_crosstab(self):
|
def test_crosstab(self):
|
||||||
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
||||||
ct = df.stat.crosstab("a", "b").collect()
|
ct = df.stat.crosstab(u"a", "b").collect()
|
||||||
ct = sorted(ct, key=lambda x: x[0])
|
ct = sorted(ct, key=lambda x: x[0])
|
||||||
for i, row in enumerate(ct):
|
for i, row in enumerate(ct):
|
||||||
self.assertEqual(row[0], str(i))
|
self.assertEqual(row[0], str(i))
|
||||||
|
|
Loading…
Reference in a new issue