[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column
Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use.
This commit is contained in:
parent
f8cbb0a4b3
commit
b1ef6a60ff
|
@ -233,6 +233,7 @@ private object VectorIndexer {
|
||||||
* - Continuous features (columns) are left unchanged.
|
* - Continuous features (columns) are left unchanged.
|
||||||
* This also appends metadata to the output column, marking features as Numeric (continuous),
|
* This also appends metadata to the output column, marking features as Numeric (continuous),
|
||||||
* Nominal (categorical), or Binary (either continuous or categorical).
|
* Nominal (categorical), or Binary (either continuous or categorical).
|
||||||
|
* Non-ML metadata is not carried over from the input to the output column.
|
||||||
*
|
*
|
||||||
* This maintains vector sparsity.
|
* This maintains vector sparsity.
|
||||||
*
|
*
|
||||||
|
@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
|
||||||
|
|
||||||
// TODO: Check more carefully about whether this whole class will be included in a closure.
|
// TODO: Check more carefully about whether this whole class will be included in a closure.
|
||||||
|
|
||||||
|
/** Per-vector transform function */
|
||||||
private val transformFunc: Vector => Vector = {
|
private val transformFunc: Vector => Vector = {
|
||||||
val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
|
val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
|
||||||
val localVectorMap = categoryMaps
|
val localVectorMap = categoryMaps
|
||||||
val f: Vector => Vector = {
|
val localNumFeatures = numFeatures
|
||||||
case dv: DenseVector =>
|
val f: Vector => Vector = { (v: Vector) =>
|
||||||
val tmpv = dv.copy
|
assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
|
||||||
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
|
s" $numFeatures but found length ${v.size}")
|
||||||
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
|
v match {
|
||||||
}
|
case dv: DenseVector =>
|
||||||
tmpv
|
val tmpv = dv.copy
|
||||||
case sv: SparseVector =>
|
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
|
||||||
// We use the fact that categorical value 0 is always mapped to index 0.
|
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
|
||||||
val tmpv = sv.copy
|
|
||||||
var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
|
|
||||||
var k = 0 // index into non-zero elements of sparse vector
|
|
||||||
while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
|
|
||||||
val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
|
|
||||||
if (featureIndex < tmpv.indices(k)) {
|
|
||||||
catFeatureIdx += 1
|
|
||||||
} else if (featureIndex > tmpv.indices(k)) {
|
|
||||||
k += 1
|
|
||||||
} else {
|
|
||||||
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
|
|
||||||
catFeatureIdx += 1
|
|
||||||
k += 1
|
|
||||||
}
|
}
|
||||||
}
|
tmpv
|
||||||
tmpv
|
case sv: SparseVector =>
|
||||||
|
// We use the fact that categorical value 0 is always mapped to index 0.
|
||||||
|
val tmpv = sv.copy
|
||||||
|
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
|
||||||
|
var k = 0 // index into non-zero elements of sparse vector
|
||||||
|
while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
|
||||||
|
val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
|
||||||
|
if (featureIndex < tmpv.indices(k)) {
|
||||||
|
catFeatureIdx += 1
|
||||||
|
} else if (featureIndex > tmpv.indices(k)) {
|
||||||
|
k += 1
|
||||||
|
} else {
|
||||||
|
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
|
||||||
|
catFeatureIdx += 1
|
||||||
|
k += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tmpv
|
||||||
|
}
|
||||||
}
|
}
|
||||||
f
|
f
|
||||||
}
|
}
|
||||||
|
@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
|
||||||
val map = extractParamMap(paramMap)
|
val map = extractParamMap(paramMap)
|
||||||
val newField = prepOutputField(dataset.schema, map)
|
val newField = prepOutputField(dataset.schema, map)
|
||||||
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
|
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
|
||||||
// For now, just check the first row of inputCol for vector length.
|
|
||||||
val firstRow = dataset.select(map(inputCol)).take(1)
|
|
||||||
if (firstRow.length != 0) {
|
|
||||||
val actualNumFeatures = firstRow(0).getAs[Vector](0).size
|
|
||||||
require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
|
|
||||||
s" $numFeatures but found length $actualNumFeatures")
|
|
||||||
}
|
|
||||||
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
|
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
|
||||||
s"VectorIndexerModel requires output column parameter: $outputCol")
|
s"VectorIndexerModel requires output column parameter: $outputCol")
|
||||||
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
|
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
|
||||||
|
|
||||||
|
// If the input metadata specifies numFeatures, compare with expected numFeatures.
|
||||||
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
||||||
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
|
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
|
||||||
Some(origAttrGroup.attributes.get.length)
|
Some(origAttrGroup.attributes.get.length)
|
||||||
|
@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
|
||||||
* Prepare the output column field, including per-feature metadata.
|
* Prepare the output column field, including per-feature metadata.
|
||||||
* @param schema Input schema
|
* @param schema Input schema
|
||||||
* @param map Parameter map (with this class' embedded parameter map folded in)
|
* @param map Parameter map (with this class' embedded parameter map folded in)
|
||||||
* @return Output column field
|
* @return Output column field. This field does not contain non-ML metadata.
|
||||||
*/
|
*/
|
||||||
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
|
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
|
||||||
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
||||||
|
@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
|
||||||
partialFeatureAttributes
|
partialFeatureAttributes
|
||||||
}
|
}
|
||||||
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
|
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
|
||||||
newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
|
newAttributeGroup.toStructField()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ import org.scalatest.FunSuite
|
||||||
|
|
||||||
import org.apache.spark.SparkException
|
import org.apache.spark.SparkException
|
||||||
import org.apache.spark.ml.attribute._
|
import org.apache.spark.ml.attribute._
|
||||||
import org.apache.spark.ml.util.TestingUtils
|
|
||||||
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
|
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
|
||||||
model.transform(densePoints1) // should work
|
model.transform(densePoints1) // should work
|
||||||
model.transform(sparsePoints1) // should work
|
model.transform(sparsePoints1) // should work
|
||||||
intercept[IllegalArgumentException] {
|
intercept[SparkException] {
|
||||||
model.transform(densePoints2)
|
model.transform(densePoints2).collect()
|
||||||
println("Did not throw error when fit, transform were called on vectors of different lengths")
|
println("Did not throw error when fit, transform were called on vectors of different lengths")
|
||||||
}
|
}
|
||||||
intercept[SparkException] {
|
intercept[SparkException] {
|
||||||
|
@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
|
||||||
// TODO: Once input features marked as categorical are handled correctly, check that here.
|
// TODO: Once input features marked as categorical are handled correctly, check that here.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check that non-ML metadata are preserved.
|
|
||||||
TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.spark.ml.util
|
|
||||||
|
|
||||||
import org.apache.spark.ml.Transformer
|
|
||||||
import org.apache.spark.sql.DataFrame
|
|
||||||
import org.apache.spark.sql.types.MetadataBuilder
|
|
||||||
import org.scalatest.FunSuite
|
|
||||||
|
|
||||||
private[ml] object TestingUtils extends FunSuite {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Test whether unrelated metadata are preserved for this transformer.
|
|
||||||
* This attaches extra metadata to a column, transforms the column, and check to ensure the
|
|
||||||
* extra metadata have not changed.
|
|
||||||
* @param data Input dataset
|
|
||||||
* @param transformer Transformer to test
|
|
||||||
* @param inputCol Unique input column for Transformer. This must be the ONLY input column.
|
|
||||||
* @param outputCol Output column to test for metadata presence.
|
|
||||||
*/
|
|
||||||
def testPreserveMetadata(
|
|
||||||
data: DataFrame,
|
|
||||||
transformer: Transformer,
|
|
||||||
inputCol: String,
|
|
||||||
outputCol: String): Unit = {
|
|
||||||
// Create some fake metadata
|
|
||||||
val origMetadata = data.schema(inputCol).metadata
|
|
||||||
val metaKey = "__testPreserveMetadata__fake_key"
|
|
||||||
val metaValue = 12345
|
|
||||||
assert(!origMetadata.contains(metaKey),
|
|
||||||
s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
|
|
||||||
val newMetadata =
|
|
||||||
new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
|
|
||||||
// Add metadata to the inputCol
|
|
||||||
val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
|
|
||||||
// Transform, and ensure extra metadata was not affected
|
|
||||||
val transformed = transformer.transform(withMetadata)
|
|
||||||
val transMetadata = transformed.schema(outputCol).metadata
|
|
||||||
assert(transMetadata.contains(metaKey),
|
|
||||||
"Unit test with testPreserveMetadata failed; extra metadata key was not present.")
|
|
||||||
assert(transMetadata.getLong(metaKey) === metaValue,
|
|
||||||
"Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
|
|
||||||
s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in a new issue