[SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column
Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits: b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use.
This commit is contained in:
parent
f8cbb0a4b3
commit
b1ef6a60ff
|
@ -233,6 +233,7 @@ private object VectorIndexer {
|
|||
* - Continuous features (columns) are left unchanged.
|
||||
* This also appends metadata to the output column, marking features as Numeric (continuous),
|
||||
* Nominal (categorical), or Binary (either continuous or categorical).
|
||||
* Non-ML metadata is not carried over from the input to the output column.
|
||||
*
|
||||
* This maintains vector sparsity.
|
||||
*
|
||||
|
@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] (
|
|||
|
||||
// TODO: Check more carefully about whether this whole class will be included in a closure.
|
||||
|
||||
/** Per-vector transform function */
|
||||
private val transformFunc: Vector => Vector = {
|
||||
val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted
|
||||
val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
|
||||
val localVectorMap = categoryMaps
|
||||
val f: Vector => Vector = {
|
||||
case dv: DenseVector =>
|
||||
val tmpv = dv.copy
|
||||
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
|
||||
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
|
||||
}
|
||||
tmpv
|
||||
case sv: SparseVector =>
|
||||
// We use the fact that categorical value 0 is always mapped to index 0.
|
||||
val tmpv = sv.copy
|
||||
var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices
|
||||
var k = 0 // index into non-zero elements of sparse vector
|
||||
while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) {
|
||||
val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx)
|
||||
if (featureIndex < tmpv.indices(k)) {
|
||||
catFeatureIdx += 1
|
||||
} else if (featureIndex > tmpv.indices(k)) {
|
||||
k += 1
|
||||
} else {
|
||||
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
|
||||
catFeatureIdx += 1
|
||||
k += 1
|
||||
val localNumFeatures = numFeatures
|
||||
val f: Vector => Vector = { (v: Vector) =>
|
||||
assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
|
||||
s" $numFeatures but found length ${v.size}")
|
||||
v match {
|
||||
case dv: DenseVector =>
|
||||
val tmpv = dv.copy
|
||||
localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
|
||||
tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
|
||||
}
|
||||
}
|
||||
tmpv
|
||||
tmpv
|
||||
case sv: SparseVector =>
|
||||
// We use the fact that categorical value 0 is always mapped to index 0.
|
||||
val tmpv = sv.copy
|
||||
var catFeatureIdx = 0 // index into sortedCatFeatureIndices
|
||||
var k = 0 // index into non-zero elements of sparse vector
|
||||
while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) {
|
||||
val featureIndex = sortedCatFeatureIndices(catFeatureIdx)
|
||||
if (featureIndex < tmpv.indices(k)) {
|
||||
catFeatureIdx += 1
|
||||
} else if (featureIndex > tmpv.indices(k)) {
|
||||
k += 1
|
||||
} else {
|
||||
tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
|
||||
catFeatureIdx += 1
|
||||
k += 1
|
||||
}
|
||||
}
|
||||
tmpv
|
||||
}
|
||||
}
|
||||
f
|
||||
}
|
||||
|
@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] (
|
|||
val map = extractParamMap(paramMap)
|
||||
val newField = prepOutputField(dataset.schema, map)
|
||||
val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol)))
|
||||
// For now, just check the first row of inputCol for vector length.
|
||||
val firstRow = dataset.select(map(inputCol)).take(1)
|
||||
if (firstRow.length != 0) {
|
||||
val actualNumFeatures = firstRow(0).getAs[Vector](0).size
|
||||
require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" +
|
||||
s" $numFeatures but found length $actualNumFeatures")
|
||||
}
|
||||
dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata))
|
||||
}
|
||||
|
||||
|
@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] (
|
|||
s"VectorIndexerModel requires output column parameter: $outputCol")
|
||||
SchemaUtils.checkColumnType(schema, map(inputCol), dataType)
|
||||
|
||||
// If the input metadata specifies numFeatures, compare with expected numFeatures.
|
||||
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
||||
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
|
||||
Some(origAttrGroup.attributes.get.length)
|
||||
|
@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] (
|
|||
* Prepare the output column field, including per-feature metadata.
|
||||
* @param schema Input schema
|
||||
* @param map Parameter map (with this class' embedded parameter map folded in)
|
||||
* @return Output column field
|
||||
* @return Output column field. This field does not contain non-ML metadata.
|
||||
*/
|
||||
private def prepOutputField(schema: StructType, map: ParamMap): StructField = {
|
||||
val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol)))
|
||||
|
@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] (
|
|||
partialFeatureAttributes
|
||||
}
|
||||
val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes)
|
||||
newAttributeGroup.toStructField(schema(map(inputCol)).metadata)
|
||||
newAttributeGroup.toStructField()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,6 @@ import org.scalatest.FunSuite
|
|||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.util.TestingUtils
|
||||
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
|
|||
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
|
||||
model.transform(densePoints1) // should work
|
||||
model.transform(sparsePoints1) // should work
|
||||
intercept[IllegalArgumentException] {
|
||||
model.transform(densePoints2)
|
||||
intercept[SparkException] {
|
||||
model.transform(densePoints2).collect()
|
||||
println("Did not throw error when fit, transform were called on vectors of different lengths")
|
||||
}
|
||||
intercept[SparkException] {
|
||||
|
@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
|
|||
// TODO: Once input features marked as categorical are handled correctly, check that here.
|
||||
}
|
||||
}
|
||||
// Check that non-ML metadata are preserved.
|
||||
TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.util
|
||||
|
||||
import org.apache.spark.ml.Transformer
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.types.MetadataBuilder
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
private[ml] object TestingUtils extends FunSuite {
|
||||
|
||||
/**
|
||||
* Test whether unrelated metadata are preserved for this transformer.
|
||||
* This attaches extra metadata to a column, transforms the column, and check to ensure the
|
||||
* extra metadata have not changed.
|
||||
* @param data Input dataset
|
||||
* @param transformer Transformer to test
|
||||
* @param inputCol Unique input column for Transformer. This must be the ONLY input column.
|
||||
* @param outputCol Output column to test for metadata presence.
|
||||
*/
|
||||
def testPreserveMetadata(
|
||||
data: DataFrame,
|
||||
transformer: Transformer,
|
||||
inputCol: String,
|
||||
outputCol: String): Unit = {
|
||||
// Create some fake metadata
|
||||
val origMetadata = data.schema(inputCol).metadata
|
||||
val metaKey = "__testPreserveMetadata__fake_key"
|
||||
val metaValue = 12345
|
||||
assert(!origMetadata.contains(metaKey),
|
||||
s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
|
||||
val newMetadata =
|
||||
new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
|
||||
// Add metadata to the inputCol
|
||||
val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
|
||||
// Transform, and ensure extra metadata was not affected
|
||||
val transformed = transformer.transform(withMetadata)
|
||||
val transMetadata = transformed.schema(outputCol).metadata
|
||||
assert(transMetadata.contains(metaKey),
|
||||
"Unit test with testPreserveMetadata failed; extra metadata key was not present.")
|
||||
assert(transMetadata.getLong(metaKey) === metaValue,
|
||||
"Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
|
||||
s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue