[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column

Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.
This commit is contained in:
Joseph K. Bradley 2015-04-29 16:35:17 -07:00
parent f8cbb0a4b3
commit b1ef6a60ff
3 changed files with 37 additions and 99 deletions

View file

@ -233,6 +233,7 @@ private object VectorIndexer {
* - Continuous features (columns) are left unchanged. * - Continuous features (columns) are left unchanged.
* This also appends metadata to the output column, marking features as Numeric (continuous), * This also appends metadata to the output column, marking features as Numeric (continuous),
* Nominal (categorical), or Binary (either continuous or categorical). * Nominal (categorical), or Binary (either continuous or categorical).
* Non-ML metadata is not carried over from the input to the output column.
* *
* This maintains vector sparsity. * This maintains vector sparsity.
* *
@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
// TODO: Check more carefully about whether this whole class will be included in a closure. // TODO: Check more carefully about whether this whole class will be included in a closure.
/** Per-vector transform function */
private val transformFunc: Vector => Vector = { private val transformFunc: Vector => Vector = {
val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
val localVectorMap = categoryMaps val localVectorMap = categoryMaps
val f: Vector => Vector = { val localNumFeatures = numFeatures
case dv: DenseVector => val f: Vector => Vector = { (v: Vector) =>
val tmpv = dv.copy assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => s" $numFeatures but found length ${v.size}")
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) v match {
} case dv: DenseVector =>
tmpv val tmpv = dv.copy
case sv: SparseVector => localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
// We use the fact that categorical value 0 is always mapped to index 0. tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
} }
} tmpv
tmpv case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
}
}
tmpv
}
} }
f f
} }
@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
val map = extractParamMap(paramMap) val map = extractParamMap(paramMap)
val newField = prepOutputField(dataset.schema, map) val newField = prepOutputField(dataset.schema, map)
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
// For now, just check the first row of inputCol for vector length.
val firstRow = dataset.select(map(inputCol)).take(1)
if (firstRow.length != 0) {
val actualNumFeatures = firstRow(0).getAs[Vector](0).size
require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length $actualNumFeatures")
}
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
} }
@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
s"VectorIndexerModel requires output column parameter: $outputCol") s"VectorIndexerModel requires output column parameter: $outputCol")
SchemaUtils.checkColumnType(schema, map(inputCol), dataType) SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
// If the input metadata specifies numFeatures, compare with expected numFeatures.
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length) Some(origAttrGroup.attributes.get.length)
@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
* Prepare the output column field, including per-feature metadata. * Prepare the output column field, including per-feature metadata.
* @param schema Input schema * @param schema Input schema
* @param map Parameter map (with this class' embedded parameter map folded in) * @param map Parameter map (with this class' embedded parameter map folded in)
* @return Output column field * @return Output column field. This field does not contain non-ML metadata.
*/ */
private def prepOutputField(schema: StructType, map: ParamMap): StructField = { private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
partialFeatureAttributes partialFeatureAttributes
} }
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
newAttributeGroup.toStructField(schema(map(inputCol)).metadata) newAttributeGroup.toStructField()
} }
} }

View file

@ -23,7 +23,6 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkException import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._ import org.apache.spark.ml.attribute._
import org.apache.spark.ml.util.TestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
val model = vectorIndexer.fit(densePoints1) // vectors of length 3 val model = vectorIndexer.fit(densePoints1) // vectors of length 3
model.transform(densePoints1) // should work model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work model.transform(sparsePoints1) // should work
intercept[IllegalArgumentException] { intercept[SparkException] {
model.transform(densePoints2) model.transform(densePoints2).collect()
println("Did not throw error when fit, transform were called on vectors of different lengths") println("Did not throw error when fit, transform were called on vectors of different lengths")
} }
intercept[SparkException] { intercept[SparkException] {
@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
// TODO: Once input features marked as categorical are handled correctly, check that here. // TODO: Once input features marked as categorical are handled correctly, check that here.
} }
} }
// Check that non-ML metadata are preserved.
TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
} }
} }

View file

@ -1,60 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import org.apache.spark.ml.Transformer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.MetadataBuilder
import org.scalatest.FunSuite
private[ml] object TestingUtils extends FunSuite {
/**
* Test whether unrelated metadata are preserved for this transformer.
* This attaches extra metadata to a column, transforms the column, and check to ensure the
* extra metadata have not changed.
* @param data Input dataset
* @param transformer Transformer to test
* @param inputCol Unique input column for Transformer. This must be the ONLY input column.
* @param outputCol Output column to test for metadata presence.
*/
def testPreserveMetadata(
data: DataFrame,
transformer: Transformer,
inputCol: String,
outputCol: String): Unit = {
// Create some fake metadata
val origMetadata = data.schema(inputCol).metadata
val metaKey = "__testPreserveMetadata__fake_key"
val metaValue = 12345
assert(!origMetadata.contains(metaKey),
s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
val newMetadata =
new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
// Add metadata to the inputCol
val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
// Transform, and ensure extra metadata was not affected
val transformed = transformer.transform(withMetadata)
val transMetadata = transformed.schema(outputCol).metadata
assert(transMetadata.contains(metaKey),
"Unit test with testPreserveMetadata failed; extra metadata key was not present.")
assert(transMetadata.getLong(metaKey) === metaValue,
"Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
}
}