[SPARK-16187][ML] Implement util method for ML Matrix conversion in scala/java
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-16187 This provides conversion utils between old/new matrix columns in a DataFrame, so users can use them to migrate their datasets and pipelines manually. ## How was this patch tested? Java and Scala unit tests. Author: Yuhao Yang <yuhao.yang@intel.com> Closes #13888 from hhbyyh/matComp.
This commit is contained in:
parent
c48c8ebc0a
commit
c17b1abff8
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
|
|||
* User-defined type for [[Matrix]] in [[mllib-local]] which allows easy interaction with SQL
|
||||
* via [[org.apache.spark.sql.Dataset]].
|
||||
*/
|
||||
private[ml] class MatrixUDT extends UserDefinedType[Matrix] {
|
||||
private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
|
||||
|
||||
override def sqlType: StructType = {
|
||||
// type: 0 = sparse, 1 = dense
|
||||
|
|
|
@ -23,7 +23,7 @@ import scala.reflect.ClassTag
|
|||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.linalg.{VectorUDT => MLVectorUDT}
|
||||
import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT}
|
||||
import org.apache.spark.mllib.linalg._
|
||||
import org.apache.spark.mllib.linalg.BLAS.dot
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
@ -309,8 +309,8 @@ object MLUtils extends Logging {
|
|||
}
|
||||
|
||||
/**
|
||||
* Converts vector columns in an input Dataset to the [[org.apache.spark.ml.linalg.Vector]] type
|
||||
* from the new [[org.apache.spark.mllib.linalg.Vector]] type under the `spark.ml` package.
|
||||
* Converts vector columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Vector]]
|
||||
* type from the new [[org.apache.spark.ml.linalg.Vector]] type under the `spark.ml` package.
|
||||
* @param dataset input dataset
|
||||
* @param cols a list of vector columns to be converted. Old vector columns will be ignored. If
|
||||
* unspecified, all new vector columns will be converted except nested ones.
|
||||
|
@ -360,6 +360,107 @@ object MLUtils extends Logging {
|
|||
dataset.select(exprs: _*)
|
||||
}
|
||||
|
||||
/**
 * Converts matrix columns in an input Dataset from the [[org.apache.spark.mllib.linalg.Matrix]]
 * type to the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package.
 * @param dataset input dataset
 * @param cols a list of matrix columns to be converted. New matrix columns will be ignored. If
 *             unspecified, all old matrix columns will be converted except nested ones.
 * @return the input [[DataFrame]] with old matrix columns converted to the new matrix type
 */
@Since("2.0.0")
@varargs
def convertMatrixColumnsToML(dataset: Dataset[_], cols: String*): DataFrame = {
  val schema = dataset.schema
  // Columns to convert: either the caller-supplied list (validated below), or every
  // top-level column whose type is the old mllib MatrixUDT.
  val targetColumns: Set[String] =
    if (cols.isEmpty) {
      schema.fields.collect {
        case field if field.dataType.getClass == classOf[MatrixUDT] => field.name
      }.toSet
    } else {
      cols.flatMap { name =>
        schema(name).dataType match {
          case dt if dt.getClass == classOf[MatrixUDT] => Seq(name)
          case dt =>
            // ignore new matrix columns and raise an exception on other column types
            require(dt.getClass == classOf[MLMatrixUDT],
              s"Column $name must be old Matrix type to be converted to new type but got $dt.")
            Seq.empty[String]
        }
      }.toSet
    }

  if (targetColumns.isEmpty) {
    // Nothing to convert; hand back the dataset as a DataFrame unchanged.
    dataset.toDF()
  } else {
    logWarning("Matrix column conversion has serialization overhead. " +
      "Please migrate your datasets and workflows to use the spark.ml package.")

    val toML = udf { m: Matrix => m.asML }
    // Rebuild the projection, converting only the selected columns while keeping each
    // converted column's name and metadata intact.
    val projection = schema.fields.map { field =>
      val name = field.name
      if (targetColumns.contains(name)) {
        toML(col(name)).as(name, field.metadata)
      } else {
        col(name)
      }
    }
    dataset.select(projection: _*)
  }
}
|
||||
|
||||
/**
 * Converts matrix columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Matrix]]
 * type from the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package.
 * @param dataset input dataset
 * @param cols a list of matrix columns to be converted. Old matrix columns will be ignored. If
 *             unspecified, all new matrix columns will be converted except nested ones.
 * @return the input [[DataFrame]] with new matrix columns converted to the old matrix type
 */
@Since("2.0.0")
@varargs
def convertMatrixColumnsFromML(dataset: Dataset[_], cols: String*): DataFrame = {
  val schema = dataset.schema
  // Columns to convert: either the caller-supplied list (validated below), or every
  // top-level column whose type is the new spark.ml MatrixUDT.
  val targetColumns: Set[String] =
    if (cols.isEmpty) {
      schema.fields.collect {
        case field if field.dataType.getClass == classOf[MLMatrixUDT] => field.name
      }.toSet
    } else {
      cols.flatMap { name =>
        schema(name).dataType match {
          case dt if dt.getClass == classOf[MLMatrixUDT] => Seq(name)
          case dt =>
            // ignore old matrix columns and raise an exception on other column types
            require(dt.getClass == classOf[MatrixUDT],
              s"Column $name must be new Matrix type to be converted to old type but got $dt.")
            Seq.empty[String]
        }
      }.toSet
    }

  if (targetColumns.isEmpty) {
    // Nothing to convert; hand back the dataset as a DataFrame unchanged.
    dataset.toDF()
  } else {
    logWarning("Matrix column conversion has serialization overhead. " +
      "Please migrate your datasets and workflows to use the spark.ml package.")

    val fromML = udf(Matrices.fromML _)
    // Rebuild the projection, converting only the selected columns while keeping each
    // converted column's name and metadata intact.
    val projection = schema.fields.map { field =>
      val name = field.name
      if (targetColumns.contains(name)) {
        fromML(col(name)).as(name, field.metadata)
      } else {
        col(name)
      }
    }
    dataset.select(projection: _*)
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the squared Euclidean distance between two vectors. The following formula will be used
|
||||
* if it does not introduce too much numerical error:
|
||||
|
|
|
@ -17,18 +17,22 @@
|
|||
|
||||
package org.apache.spark.mllib.util;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.SharedSparkSession;
|
||||
import org.apache.spark.mllib.linalg.Vector;
|
||||
import org.apache.spark.mllib.linalg.Vectors;
|
||||
import org.apache.spark.mllib.linalg.*;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.RowFactory;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.Metadata;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
public class JavaMLUtilsSuite extends SharedSparkSession {
|
||||
|
||||
|
@ -46,4 +50,25 @@ public class JavaMLUtilsSuite extends SharedSparkSession {
|
|||
Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
|
||||
Assert.assertEquals(RowFactory.create(1.0, x), old1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConvertMatrixColumnsToAndFromML() {
|
||||
Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
|
||||
StructType schema = new StructType(new StructField[]{
|
||||
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
|
||||
new StructField("features", new MatrixUDT(), false, Metadata.empty())
|
||||
});
|
||||
Dataset<Row> dataset = spark.createDataFrame(
|
||||
Arrays.asList(
|
||||
RowFactory.create(1.0, x)),
|
||||
schema);
|
||||
|
||||
Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
|
||||
Row new1 = newDataset1.first();
|
||||
Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);
|
||||
Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
|
||||
Assert.assertEquals(new1, new2);
|
||||
Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
|
||||
Assert.assertEquals(RowFactory.create(1.0, x), old1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
|
|||
import com.google.common.io.Files
|
||||
|
||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
|
||||
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.util.MLUtils._
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
|
@ -301,4 +301,58 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
convertVectorColumnsFromML(df, "p._2")
|
||||
}
|
||||
}
|
||||
|
||||
test("convertMatrixColumnsToML") {
  // "x" and "y" hold old mllib matrices; "p" nests one inside a struct; "w" is already ml-typed.
  val sparseMat = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
  val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
  val denseMat = Matrices.dense(2, 1, Array(0.2, 1.3))
  val nested = Matrices.ones(1, 1)
  val pair = (5.0, nested)
  val mlMat = Matrices.dense(1, 1, Array(4.5)).asML
  val frame = spark.createDataFrame(Seq((0, sparseMat, denseMat, pair, mlMat)))
    .toDF("id", "x", "y", "p", "w")
    .withColumn("x", col("x"), metadata)

  val converted = convertMatrixColumnsToML(frame)
  assert(converted.schema("x").metadata === metadata, "Metadata should be preserved.")
  val allRow = converted.first()
  assert(allRow === Row(0, sparseMat.asML, denseMat.asML, Row(5.0, nested), mlMat))
  // Listing every old matrix column explicitly is equivalent to the default conversion.
  assert(convertMatrixColumnsToML(frame, "x", "y").first() === allRow)
  // Converting only "y" leaves "x" as an old matrix; the ml-typed "w" is ignored.
  assert(convertMatrixColumnsToML(frame, "y", "w").first() ===
    Row(0, sparseMat, denseMat.asML, Row(5.0, nested), mlMat))
  // Non-matrix columns and nested matrix columns are rejected.
  intercept[IllegalArgumentException] {
    convertMatrixColumnsToML(frame, "p")
  }
  intercept[IllegalArgumentException] {
    convertMatrixColumnsToML(frame, "p._2")
  }
}
|
||||
|
||||
test("convertMatrixColumnsFromML") {
  // "x" and "y" hold new ml matrices; "p" nests one inside a struct; "w" is an old mllib matrix.
  val sparseMat = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)).asML
  val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
  val denseMat = Matrices.dense(2, 1, Array(0.2, 1.3)).asML
  val nested = Matrices.ones(1, 1).asML
  val pair = (5.0, nested)
  val oldMat = Matrices.dense(1, 1, Array(4.5))
  val frame = spark.createDataFrame(Seq((0, sparseMat, denseMat, pair, oldMat)))
    .toDF("id", "x", "y", "p", "w")
    .withColumn("x", col("x"), metadata)

  val converted = convertMatrixColumnsFromML(frame)
  assert(converted.schema("x").metadata === metadata, "Metadata should be preserved.")
  val allRow = converted.first()
  assert(allRow ===
    Row(0, Matrices.fromML(sparseMat), Matrices.fromML(denseMat), Row(5.0, nested), oldMat))
  // Listing every new matrix column explicitly is equivalent to the default conversion.
  assert(convertMatrixColumnsFromML(frame, "x", "y").first() === allRow)
  // Converting only "y" leaves "x" as a new matrix; the old-typed "w" is ignored.
  assert(convertMatrixColumnsFromML(frame, "y", "w").first() ===
    Row(0, sparseMat, Matrices.fromML(denseMat), Row(5.0, nested), oldMat))
  // Non-matrix columns and nested matrix columns are rejected.
  intercept[IllegalArgumentException] {
    convertMatrixColumnsFromML(frame, "p")
  }
  intercept[IllegalArgumentException] {
    convertMatrixColumnsFromML(frame, "p._2")
  }
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue