[SPARK-35142][PYTHON][ML] Fix incorrect return type for rawPredictionUDF in OneVsRestModel

### What changes were proposed in this pull request?

Fixes incorrect return type for `rawPredictionUDF` in `OneVsRestModel`.

### Why are the changes needed?
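Bugfix: `rawPredictionUDF` was created with `udf(func)` and no explicit return type, so PySpark fell back to the default string return type and `OneVsRestModel` emitted the raw prediction column as a string instead of a vector. Declaring the return type as `VectorUDT()` fixes this.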

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
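Added a unit test (`test_raw_prediction_column_is_of_vector_type` in `OneVsRestTests`) that fits a `OneVsRest` model and asserts that the `rawPrediction` value of a transformed row is a `DenseVector`.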

Closes #32245 from harupy/SPARK-35142.

Authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
harupy 2021-04-21 16:29:10 +08:00 committed by Weichen Xu
parent 43ad939a7e
commit b6350f5bb0
2 changed files with 15 additions and 3 deletions

View file

@ -40,7 +40,7 @@ from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
from pyspark.ml.wrapper import JavaParams, \
JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import Vectors
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import DataFrame
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
@ -3151,7 +3151,7 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
predArray.append(x)
return Vectors.dense(predArray)
rawPredictionUDF = udf(func)
rawPredictionUDF = udf(func, VectorUDT())
aggregatedDataset = aggregatedDataset.withColumn(
self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))

View file

@ -25,7 +25,7 @@ from pyspark.ml.classification import FMClassifier, LogisticRegression, \
MultilayerPerceptronClassifier, OneVsRest
from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel
from pyspark.ml.fpm import FPGrowth
from pyspark.ml.linalg import Matrices, Vectors
from pyspark.ml.linalg import Matrices, Vectors, DenseVector
from pyspark.ml.recommendation import ALS
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
from pyspark.sql import Row
@ -116,6 +116,18 @@ class OneVsRestTests(SparkSessionTestCase):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"])
def test_raw_prediction_column_is_of_vector_type(self):
    # Regression test for SPARK-35142: `OneVsRestModel` used to output the raw
    # prediction as a string column; the value must be a vector instead.
    rows = [
        (0.0, Vectors.dense(1.0, 0.8)),
        (1.0, Vectors.sparse(2, [], [])),
        (2.0, Vectors.dense(0.5, 0.5)),
    ]
    train_df = self.spark.createDataFrame(rows, ["label", "features"])
    classifier = LogisticRegression(maxIter=5, regParam=0.01)
    ovr_model = OneVsRest(classifier=classifier, parallelism=1).fit(train_df)
    # Only the first transformed row is inspected.
    first_row = ovr_model.transform(train_df).head()
    self.assertIsInstance(first_row["rawPrediction"], DenseVector)
def test_parallelism_does_not_change_output(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),