[SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in DataFrame
## What changes were proposed in this pull request? This PR proposes to support unicode strings in Param methods in ML, as well as in other functions in DataFrame that previously lacked it. For example, this causes a `TypeError` in Python 2.x when the param is a unicode string: ```python >>> from pyspark.ml.classification import LogisticRegression >>> lr = LogisticRegression() >>> lr.hasParam("threshold") True >>> lr.hasParam(u"threshold") Traceback (most recent call last): ... raise TypeError("hasParam(): paramName must be a string") TypeError: hasParam(): paramName must be a string ``` This PR is based on https://github.com/apache/spark/pull/13036 ## How was this patch tested? Unit tests in `python/pyspark/ml/tests.py` and `python/pyspark/sql/tests.py`. Author: hyukjinkwon <gurwls223@gmail.com> Author: sethah <seth.hendrickson16@gmail.com> Closes #17096 from HyukjinKwon/SPARK-15243.
This commit is contained in:
parent
8a4f228dc0
commit
8598d03a00
|
@ -330,7 +330,7 @@ class Params(Identifiable):
|
|||
Tests whether this instance contains a param with a given
|
||||
(string) name.
|
||||
"""
|
||||
if isinstance(paramName, str):
|
||||
if isinstance(paramName, basestring):
|
||||
p = getattr(self, paramName, None)
|
||||
return isinstance(p, Param)
|
||||
else:
|
||||
|
@ -413,7 +413,7 @@ class Params(Identifiable):
|
|||
if isinstance(param, Param):
|
||||
self._shouldOwn(param)
|
||||
return param
|
||||
elif isinstance(param, str):
|
||||
elif isinstance(param, basestring):
|
||||
return self.getParam(param)
|
||||
else:
|
||||
raise ValueError("Cannot resolve %r as a param." % param)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
|
@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase):
|
|||
testParams = TestParams()
|
||||
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
|
||||
self.assertFalse(testParams.hasParam("notAParameter"))
|
||||
self.assertTrue(testParams.hasParam(u"maxIter"))
|
||||
|
||||
def test_resolveparam(self):
|
||||
testParams = TestParams()
|
||||
self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
|
||||
self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
|
||||
|
||||
self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
|
||||
if sys.version_info[0] >= 3:
|
||||
# In Python 3, it is allowed to get/set attributes with non-ascii characters.
|
||||
e_cls = AttributeError
|
||||
else:
|
||||
e_cls = UnicodeEncodeError
|
||||
self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
|
||||
|
||||
def test_params(self):
|
||||
testParams = TestParams()
|
||||
|
|
|
@ -748,7 +748,7 @@ class DataFrame(object):
|
|||
+---+-----+
|
||||
|
||||
"""
|
||||
if not isinstance(col, str):
|
||||
if not isinstance(col, basestring):
|
||||
raise ValueError("col must be a string, but got %r" % type(col))
|
||||
if not isinstance(fractions, dict):
|
||||
raise ValueError("fractions must be a dict but got %r" % type(fractions))
|
||||
|
@ -1664,18 +1664,18 @@ class DataFrame(object):
|
|||
Added support for multiple columns.
|
||||
"""
|
||||
|
||||
if not isinstance(col, (str, list, tuple)):
|
||||
if not isinstance(col, (basestring, list, tuple)):
|
||||
raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
|
||||
|
||||
isStr = isinstance(col, str)
|
||||
isStr = isinstance(col, basestring)
|
||||
|
||||
if isinstance(col, tuple):
|
||||
col = list(col)
|
||||
elif isinstance(col, str):
|
||||
elif isStr:
|
||||
col = [col]
|
||||
|
||||
for c in col:
|
||||
if not isinstance(c, str):
|
||||
if not isinstance(c, basestring):
|
||||
raise ValueError("columns should be strings, but got %r" % type(c))
|
||||
col = _to_list(self._sc, col)
|
||||
|
||||
|
@ -1707,9 +1707,9 @@ class DataFrame(object):
|
|||
:param col2: The name of the second column
|
||||
:param method: The correlation method. Currently only supports "pearson"
|
||||
"""
|
||||
if not isinstance(col1, str):
|
||||
if not isinstance(col1, basestring):
|
||||
raise ValueError("col1 should be a string.")
|
||||
if not isinstance(col2, str):
|
||||
if not isinstance(col2, basestring):
|
||||
raise ValueError("col2 should be a string.")
|
||||
if not method:
|
||||
method = "pearson"
|
||||
|
@ -1727,9 +1727,9 @@ class DataFrame(object):
|
|||
:param col1: The name of the first column
|
||||
:param col2: The name of the second column
|
||||
"""
|
||||
if not isinstance(col1, str):
|
||||
if not isinstance(col1, basestring):
|
||||
raise ValueError("col1 should be a string.")
|
||||
if not isinstance(col2, str):
|
||||
if not isinstance(col2, basestring):
|
||||
raise ValueError("col2 should be a string.")
|
||||
return self._jdf.stat().cov(col1, col2)
|
||||
|
||||
|
@ -1749,9 +1749,9 @@ class DataFrame(object):
|
|||
:param col2: The name of the second column. Distinct items will make the column names
|
||||
of the DataFrame.
|
||||
"""
|
||||
if not isinstance(col1, str):
|
||||
if not isinstance(col1, basestring):
|
||||
raise ValueError("col1 should be a string.")
|
||||
if not isinstance(col2, str):
|
||||
if not isinstance(col2, basestring):
|
||||
raise ValueError("col2 should be a string.")
|
||||
return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
|
||||
|
||||
|
|
|
@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
|
||||
def test_approxQuantile(self):
|
||||
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
|
||||
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
|
||||
self.assertTrue(isinstance(aq, list))
|
||||
self.assertEqual(len(aq), 3)
|
||||
for f in ["a", u"a"]:
|
||||
aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
|
||||
self.assertTrue(isinstance(aq, list))
|
||||
self.assertEqual(len(aq), 3)
|
||||
self.assertTrue(all(isinstance(q, float) for q in aq))
|
||||
aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
|
||||
aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
|
||||
self.assertTrue(isinstance(aqs, list))
|
||||
self.assertEqual(len(aqs), 2)
|
||||
self.assertTrue(isinstance(aqs[0], list))
|
||||
|
@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
self.assertTrue(isinstance(aqs[1], list))
|
||||
self.assertEqual(len(aqs[1]), 3)
|
||||
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
|
||||
aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
|
||||
aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
|
||||
self.assertTrue(isinstance(aqt, list))
|
||||
self.assertEqual(len(aqt), 2)
|
||||
self.assertTrue(isinstance(aqt[0], list))
|
||||
|
@ -1169,17 +1170,22 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
def test_corr(self):
|
||||
import math
|
||||
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
|
||||
corr = df.stat.corr("a", "b")
|
||||
corr = df.stat.corr(u"a", "b")
|
||||
self.assertTrue(abs(corr - 0.95734012) < 1e-6)
|
||||
|
||||
def test_sampleby(self):
|
||||
df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
|
||||
sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
|
||||
self.assertTrue(sampled.count() == 3)
|
||||
|
||||
def test_cov(self):
|
||||
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
|
||||
cov = df.stat.cov("a", "b")
|
||||
cov = df.stat.cov(u"a", "b")
|
||||
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
|
||||
|
||||
def test_crosstab(self):
|
||||
df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
|
||||
ct = df.stat.crosstab("a", "b").collect()
|
||||
ct = df.stat.crosstab(u"a", "b").collect()
|
||||
ct = sorted(ct, key=lambda x: x[0])
|
||||
for i, row in enumerate(ct):
|
||||
self.assertEqual(row[0], str(i))
|
||||
|
|
Loading…
Reference in a new issue