[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column

Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits:

b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column.  Removed ml.util.TestingUtils since VectorIndexer was the only use.
This commit is contained in:
Joseph K. Bradley 2015-04-29 16:35:17 -07:00
parent f8cbb0a4b3
commit b1ef6a60ff
3 changed files with 37 additions and 99 deletions

View file

@ -233,6 +233,7 @@ private object VectorIndexer {
* - Continuous features (columns) are left unchanged.
* This also appends metadata to the output column, marking features as Numeric (continuous),
* Nominal (categorical), or Binary (either continuous or categorical).
* Non-ML metadata is not carried over from the input to the output column.
*
* This maintains vector sparsity.
*
@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
// TODO: Check more carefully about whether this whole class will be included in a closure.
/** Per-vector transform function */
private val transformFunc: Vector => Vector = {
val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
val localVectorMap = categoryMaps
val f: Vector => Vector = {
case dv: DenseVector =>
val tmpv = dv.copy
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
}
tmpv
case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
val localNumFeatures = numFeatures
val f: Vector => Vector = { (v: Vector) =>
assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length ${v.size}")
v match {
case dv: DenseVector =>
val tmpv = dv.copy
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
}
}
tmpv
tmpv
case sv: SparseVector =>
// We use the fact that categorical value 0 is always mapped to index 0.
val tmpv = sv.copy
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
var k = 0 // index into non-zero elements of sparse vector
while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
if (featureIndex < tmpv.indices(k)) {
catFeatureIdx += 1
} else if (featureIndex > tmpv.indices(k)) {
k += 1
} else {
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
catFeatureIdx += 1
k += 1
}
}
tmpv
}
}
f
}
@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
val map = extractParamMap(paramMap)
val newField = prepOutputField(dataset.schema, map)
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
// For now, just check the first row of inputCol for vector length.
val firstRow = dataset.select(map(inputCol)).take(1)
if (firstRow.length != 0) {
val actualNumFeatures = firstRow(0).getAs[Vector](0).size
require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
s" $numFeatures but found length $actualNumFeatures")
}
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
}
@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
s"VectorIndexerModel requires output column parameter: $outputCol")
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
// If the input metadata specifies numFeatures, compare with expected numFeatures.
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
* Prepare the output column field, including per-feature metadata.
* @param schema Input schema
* @param map Parameter map (with this class' embedded parameter map folded in)
* @return Output column field
* @return Output column field. This field does not contain non-ML metadata.
*/
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
partialFeatureAttributes
}
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
newAttributeGroup.toStructField()
}
}

View file

@ -23,7 +23,6 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.util.TestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
intercept[IllegalArgumentException] {
model.transform(densePoints2)
intercept[SparkException] {
model.transform(densePoints2).collect()
println("Did not throw error when fit, transform were called on vectors of different lengths")
}
intercept[SparkException] {
@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
// TODO: Once input features marked as categorical are handled correctly, check that here.
}
}
// Check that non-ML metadata are preserved.
TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
}
}

View file

@ -1,60 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.util
import org.apache.spark.ml.Transformer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.MetadataBuilder
import org.scalatest.FunSuite
private[ml] object TestingUtils extends FunSuite {
/**
* Test whether unrelated metadata are preserved for this transformer.
* This attaches extra metadata to a column, transforms the column, and check to ensure the
* extra metadata have not changed.
* @param data Input dataset
* @param transformer Transformer to test
* @param inputCol Unique input column for Transformer. This must be the ONLY input column.
* @param outputCol Output column to test for metadata presence.
*/
def testPreserveMetadata(
data: DataFrame,
transformer: Transformer,
inputCol: String,
outputCol: String): Unit = {
// Create some fake metadata
val origMetadata = data.schema(inputCol).metadata
val metaKey = "__testPreserveMetadata__fake_key"
val metaValue = 12345
assert(!origMetadata.contains(metaKey),
s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
val newMetadata =
new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
// Add metadata to the inputCol
val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
// Transform, and ensure extra metadata was not affected
val transformed = transformer.transform(withMetadata)
val transMetadata = transformed.schema(outputCol).metadata
assert(transMetadata.contains(metaKey),
"Unit test with testPreserveMetadata failed; extra metadata key was not present.")
assert(transMetadata.getLong(metaKey) === metaValue,
"Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
}
}