[SPARK-35142][PYTHON][ML] Fix incorrect return type for rawPredictionUDF
in OneVsRestModel
### What changes were proposed in this pull request? Fixes incorrect return type for `rawPredictionUDF` in `OneVsRestModel`. ### Why are the changes needed? Bugfix ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. Closes #32245 from harupy/SPARK-35142. Authored-by: harupy <17039389+harupy@users.noreply.github.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent
43ad939a7e
commit
b6350f5bb0
|
@ -40,7 +40,7 @@ from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, \
|
|||
from pyspark.ml.wrapper import JavaParams, \
|
||||
JavaPredictor, JavaPredictionModel, JavaWrapper
|
||||
from pyspark.ml.common import inherit_doc
|
||||
from pyspark.ml.linalg import Vectors
|
||||
from pyspark.ml.linalg import Vectors, VectorUDT
|
||||
from pyspark.sql import DataFrame
|
||||
from pyspark.sql.functions import udf, when
|
||||
from pyspark.sql.types import ArrayType, DoubleType
|
||||
|
@ -3151,7 +3151,7 @@ class OneVsRestModel(Model, _OneVsRestParams, MLReadable, MLWritable):
|
|||
predArray.append(x)
|
||||
return Vectors.dense(predArray)
|
||||
|
||||
rawPredictionUDF = udf(func)
|
||||
rawPredictionUDF = udf(func, VectorUDT())
|
||||
aggregatedDataset = aggregatedDataset.withColumn(
|
||||
self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from pyspark.ml.classification import FMClassifier, LogisticRegression, \
|
|||
MultilayerPerceptronClassifier, OneVsRest
|
||||
from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel
|
||||
from pyspark.ml.fpm import FPGrowth
|
||||
from pyspark.ml.linalg import Matrices, Vectors
|
||||
from pyspark.ml.linalg import Matrices, Vectors, DenseVector
|
||||
from pyspark.ml.recommendation import ALS
|
||||
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
|
||||
from pyspark.sql import Row
|
||||
|
@ -116,6 +116,18 @@ class OneVsRestTests(SparkSessionTestCase):
|
|||
output = model.transform(df)
|
||||
self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"])
|
||||
|
||||
def test_raw_prediction_column_is_of_vector_type(self):
|
||||
# SPARK-35142: `OneVsRestModel` outputs raw prediction as a string column
|
||||
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
|
||||
(1.0, Vectors.sparse(2, [], [])),
|
||||
(2.0, Vectors.dense(0.5, 0.5))],
|
||||
["label", "features"])
|
||||
lr = LogisticRegression(maxIter=5, regParam=0.01)
|
||||
ovr = OneVsRest(classifier=lr, parallelism=1)
|
||||
model = ovr.fit(df)
|
||||
row = model.transform(df).head()
|
||||
self.assertIsInstance(row["rawPrediction"], DenseVector)
|
||||
|
||||
def test_parallelism_does_not_change_output(self):
|
||||
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
|
||||
(1.0, Vectors.sparse(2, [], [])),
|
||||
|
|
Loading…
Reference in a new issue