[SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features
## What changes were proposed in this pull request? - Multiple possible input types is added in validateAndTransformSchema() and computeCost() while checking column type - Add if statement in transform() to support array type as featuresCol - Add the case statement in fit() while selecting columns from dataset These changes will be applied to KMeans first, then to other clustering method ## How was this patch tested? unit test is added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG <lu.wang@databricks.com> Closes #21081 from ludatabricks/SPARK-23975.
This commit is contained in:
parent
55c4ca88a3
commit
2a24c481da
|
@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
|
|||
import org.apache.spark.mllib.linalg.VectorImplicits._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||
import org.apache.spark.sql.functions.udf
|
||||
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.VersionUtils.majorVersion
|
||||
|
||||
|
@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
|
|||
@Since("1.5.0")
|
||||
def getInitSteps: Int = $(initSteps)
|
||||
|
||||
/**
|
||||
* Validates the input schema.
|
||||
* @param schema input schema
|
||||
*/
|
||||
private[clustering] def validateSchema(schema: StructType): Unit = {
|
||||
val typeCandidates = List( new VectorUDT,
|
||||
new ArrayType(DoubleType, false),
|
||||
new ArrayType(FloatType, false))
|
||||
|
||||
SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
|
||||
}
|
||||
/**
|
||||
* Validates and transforms the input schema.
|
||||
* @param schema input schema
|
||||
* @return output schema
|
||||
*/
|
||||
protected def validateAndTransformSchema(schema: StructType): StructType = {
|
||||
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
|
||||
validateSchema(schema)
|
||||
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
|
||||
}
|
||||
}
|
||||
|
@ -125,8 +136,11 @@ class KMeansModel private[ml] (
|
|||
@Since("2.0.0")
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
transformSchema(dataset.schema, logging = true)
|
||||
|
||||
val predictUDF = udf((vector: Vector) => predict(vector))
|
||||
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||
|
||||
dataset.withColumn($(predictionCol),
|
||||
predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
|
||||
}
|
||||
|
||||
@Since("1.5.0")
|
||||
|
@ -146,8 +160,10 @@ class KMeansModel private[ml] (
|
|||
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
|
||||
@Since("2.0.0")
|
||||
def computeCost(dataset: Dataset[_]): Double = {
|
||||
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
|
||||
val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
|
||||
validateSchema(dataset.schema)
|
||||
|
||||
val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol))
|
||||
.rdd.map {
|
||||
case Row(point: Vector) => OldVectors.fromML(point)
|
||||
}
|
||||
parentModel.computeCost(data)
|
||||
|
@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") (
|
|||
transformSchema(dataset.schema, logging = true)
|
||||
|
||||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
|
||||
val instances: RDD[OldVector] = dataset.select(
|
||||
DatasetUtils.columnToVector(dataset, getFeaturesCol))
|
||||
.rdd.map {
|
||||
case Row(point: Vector) => OldVectors.fromML(point)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
|
||||
import org.apache.spark.sql.{Column, Dataset}
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
|
||||
|
||||
|
||||
private[spark] object DatasetUtils {
|
||||
|
||||
/**
|
||||
* Cast a column in a Dataset to Vector type.
|
||||
*
|
||||
* The supported data types of the input column are
|
||||
* - Vector
|
||||
* - float/double type Array.
|
||||
*
|
||||
* Note: The returned column does not have Metadata.
|
||||
*
|
||||
* @param dataset input DataFrame
|
||||
* @param colName column name.
|
||||
* @return Vector column
|
||||
*/
|
||||
def columnToVector(dataset: Dataset[_], colName: String): Column = {
|
||||
val columnDataType = dataset.schema(colName).dataType
|
||||
columnDataType match {
|
||||
case _: VectorUDT => col(colName)
|
||||
case fdt: ArrayType =>
|
||||
val transferUDF = fdt.elementType match {
|
||||
case _: FloatType => udf(f = (vector: Seq[Float]) => {
|
||||
val inputArray = Array.fill[Double](vector.size)(0.0)
|
||||
vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble)
|
||||
Vectors.dense(inputArray)
|
||||
})
|
||||
case _: DoubleType => udf((vector: Seq[Double]) => {
|
||||
Vectors.dense(vector.toArray)
|
||||
})
|
||||
case other =>
|
||||
throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
|
||||
}
|
||||
transferUDF(col(colName))
|
||||
case other =>
|
||||
throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
|
|||
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
|
||||
|
||||
private[clustering] case class TestRow(features: Vector)
|
||||
|
||||
|
@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
|
|||
assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
|
||||
}
|
||||
|
||||
test("KMean with Array input") {
|
||||
val featuresColNameD = "array_double_features"
|
||||
val featuresColNameF = "array_float_features"
|
||||
|
||||
val doubleUDF = udf { (features: Vector) =>
|
||||
val featureArray = Array.fill[Double](features.size)(0.0)
|
||||
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
|
||||
featureArray
|
||||
}
|
||||
val floatUDF = udf { (features: Vector) =>
|
||||
val featureArray = Array.fill[Float](features.size)(0.0f)
|
||||
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
|
||||
featureArray
|
||||
}
|
||||
|
||||
val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features")))
|
||||
.drop("features")
|
||||
val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
|
||||
.drop("features")
|
||||
assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
|
||||
assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))
|
||||
|
||||
val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1)
|
||||
val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1)
|
||||
val modelD = kmeansD.fit(newdatasetD)
|
||||
val modelF = kmeansF.fit(newdatasetF)
|
||||
val transformedD = modelD.transform(newdatasetD)
|
||||
val transformedF = modelF.transform(newdatasetF)
|
||||
|
||||
val predictDifference = transformedD.select("prediction")
|
||||
.except(transformedF.select("prediction"))
|
||||
assert(predictDifference.count() == 0)
|
||||
assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) )
|
||||
}
|
||||
|
||||
|
||||
test("read/write") {
|
||||
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
|
||||
assert(model.clusterCenters === model2.clusterCenters)
|
||||
|
|
Loading…
Reference in a new issue