[SPARK-30154][ML] PySpark UDF to convert MLlib vectors to dense arrays
### What changes were proposed in this pull request? PySpark UDF to convert MLlib vectors to dense arrays. Example: ``` from pyspark.ml.functions import vector_to_array df.select(vector_to_array(col("features"))) ``` ### Why are the changes needed? If a PySpark user wants to convert MLlib sparse/dense vectors in a DataFrame into dense arrays, an efficient approach is to do that in the JVM. However, this requires the PySpark user to write Scala code and register it as a UDF, which is often infeasible for a pure Python project. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? UT. Closes #26910 from WeichenXu123/vector_to_array. Authored-by: WeichenXu <weichen.xu@databricks.com> Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
895e572b73
commit
88542bc3d9
|
@ -460,6 +460,7 @@ pyspark_ml = Module(
|
|||
"pyspark.ml.evaluation",
|
||||
"pyspark.ml.feature",
|
||||
"pyspark.ml.fpm",
|
||||
"pyspark.ml.functions",
|
||||
"pyspark.ml.image",
|
||||
"pyspark.ml.linalg.__init__",
|
||||
"pyspark.ml.recommendation",
|
||||
|
|
48
mllib/src/main/scala/org/apache/spark/ml/functions.scala
Normal file
48
mllib/src/main/scala/org/apache/spark/ml/functions.scala
Normal file
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.mllib.linalg.{Vector => OldVector}
|
||||
import org.apache.spark.sql.Column
|
||||
import org.apache.spark.sql.functions.udf
|
||||
|
||||
// scalastyle:off
@Since("3.0.0")
object functions {
// scalastyle:on

  // UDF backing `vector_to_array`. The input is typed as `Any` so that a single UDF can accept
  // both `org.apache.spark.ml.linalg.Vector` and `org.apache.spark.mllib.linalg.Vector` values;
  // anything else (including null) is rejected with an IllegalArgumentException.
  private val vectorToArrayUdf = udf { input: Any =>
    input match {
      case mlVector: Vector => mlVector.toArray
      case mllibVector: OldVector => mllibVector.toArray
      case other =>
        // A null input has no class to report, so it is described literally as "null".
        val actualType = if (other == null) "null" else other.getClass.getName
        throw new IllegalArgumentException(
          "function vector_to_array requires a non-null input argument and input type must be " +
          "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
          s"but got $actualType.")
    }
  }.asNonNullable()

  /**
   * Converts a column of MLlib sparse/dense vectors into a column of dense arrays.
   *
   * @param v column holding `org.apache.spark.ml.linalg.Vector` or
   *          `org.apache.spark.mllib.linalg.Vector` values; null or differently typed values
   *          make the underlying UDF throw an `IllegalArgumentException`
   * @return a column produced by applying the vector-to-array UDF to `v`
   * @since 3.0.0
   */
  def vector_to_array(v: Column): Column = vectorToArrayUdf(v)
}
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.functions.vector_to_array
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.util.MLTest
|
||||
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
|
||||
import org.apache.spark.sql.functions.col
|
||||
|
||||
/**
 * Tests for the column functions exposed by `org.apache.spark.ml.functions`.
 */
class FunctionsSuite extends MLTest {

  import testImplicits._

  test("test vector_to_array") {
    // Valid input: one dense and one sparse row, for both the ml and the mllib vector types.
    val vectorsDf = Seq(
      (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),
      (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0))))
    ).toDF("vec", "oldVec")

    val actualArrays = vectorsDf
      .select(vector_to_array(col("vec")), vector_to_array(col("oldVec")))
      .as[(Seq[Double], Seq[Double])]
      .collect()
      .toSeq

    // The sparse rows must come back with their missing entry filled in as 0.0.
    val expectedArrays = Seq(
      (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)),
      (Seq(2.0, 0.0, 3.0), Seq(20.0, 0.0, 30.0))
    )
    assert(actualArrays === expectedArrays)

    // Invalid input: null vectors in "vec"/"oldVec" and a non-vector (Int) column "label".
    val invalidDf = Seq(
      (Vectors.dense(1.0, 2.0, 3.0),
        OldVectors.dense(10.0, 20.0, 30.0), 1),
      (null, null, 0)
    ).toDF("vec", "oldVec", "label")

    val messagePrefix =
      "function vector_to_array requires a non-null input argument and input type must be " +
      "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
      "but got "

    // Each pair is (column to convert, type name expected in the error message).
    Seq("vec" -> "null", "oldVec" -> "null", "label" -> "java.lang.Integer").foreach {
      case (colName, reportedType) =>
        val failure = intercept[SparkException] {
          invalidDf.select(vector_to_array(col(colName))).count()
        }
        // The UDF's exception is checked as the cause of the intercepted SparkException.
        assert(failure.getCause.getMessage.contains(messagePrefix + reportedType))
    }
  }
}
|
|
@ -41,6 +41,14 @@ pyspark.ml.clustering module
|
|||
:undoc-members:
|
||||
:inherited-members:
|
||||
|
||||
pyspark.ml.functions module
|
||||
----------------------------
|
||||
|
||||
.. automodule:: pyspark.ml.functions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:inherited-members:
|
||||
|
||||
pyspark.ml.linalg module
|
||||
----------------------------
|
||||
|
||||
|
|
68
python/pyspark/ml/functions.py
Normal file
68
python/pyspark/ml/functions.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pyspark import since, SparkContext
|
||||
from pyspark.sql.column import Column, _to_java_column
|
||||
|
||||
|
||||
@since(3.0)
def vector_to_array(col):
    """
    Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

    The conversion is delegated to the JVM function
    ``org.apache.spark.ml.functions.vector_to_array`` through the active SparkContext.

    :param col: the vector column to convert, given as a column or a column name
    :return: a :class:`Column` wrapping the result of the JVM call

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.functions import vector_to_array
    >>> from pyspark.mllib.linalg import Vectors as OldVectors
    >>> df = spark.createDataFrame([
    ...     (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)),
    ...     (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]),
    ...      OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))],
    ...     ["vec", "oldVec"])
    >>> df.select(vector_to_array("vec").alias("vec"),
    ...           vector_to_array("oldVec").alias("oldVec")).collect()
    [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
     Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]
    """
    active_context = SparkContext._active_spark_context
    # Resolve the JVM-side function first, then hand it the Java representation of `col`.
    jvm_vector_to_array = active_context._jvm.org.apache.spark.ml.functions.vector_to_array
    return Column(jvm_vector_to_array(_to_java_column(col)))
|
||||
|
||||
|
||||
def _test():
    """
    Run the doctests of ``pyspark.ml.functions`` against a local SparkSession.

    The session is stopped after the doctests finish; if any doctest failed,
    the process exits with status -1.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.ml.functions
    import sys

    # Doctests run with a copy of the module namespace plus the `sc` / `spark` handles.
    doctest_globals = dict(pyspark.ml.functions.__dict__)
    session = (
        SparkSession.builder
        .master("local[2]")
        .appName("ml.functions tests")
        .getOrCreate()
    )
    doctest_globals['sc'] = session.sparkContext
    doctest_globals['spark'] = session

    outcome = doctest.testmod(
        pyspark.ml.functions,
        globs=doctest_globals,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    session.stop()
    if outcome.failed:
        sys.exit(-1)


# Running this file directly executes the module's doctests.
if __name__ == "__main__":
    _test()
|
Loading…
Reference in a new issue