[SPARK-5895] [ML] Add VectorSlicer - updated

Add VectorSlicer transformer to spark.ml, with features specified as either indices or names.  Transfers feature attributes for selected features.
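
For reference, a minimal usage sketch (the "features"/"selected" column names, the "f4" attribute name, and the input DataFrame df are illustrative, mirroring the new test suite):

  import org.apache.spark.ml.feature.VectorSlicer

  val slicer = new VectorSlicer()
    .setInputCol("features")          // must be a Vector column with ML attributes
    .setOutputCol("selected")
    .setIndices(Array(1))             // index-selected features come first, in order
    .setNames(Array("f4"))            // then name-selected ones, resolved via attributes
  val sliced = slicer.transform(df)   // output column "selected" carries the sliced attributes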

Updated version of [https://github.com/apache/spark/pull/5731]

CC: yinxusen. This updates your PR; you'll still be the primary author of this PR.

CC: mengxr

Author: Xusen Yin <yinxusen@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits:

b16e86e [Joseph K. Bradley] fixed scala style
71c65d2 [Joseph K. Bradley] fix import order
86e9739 [Joseph K. Bradley] cleanups per code review
9d8d6f1 [Joseph K. Bradley] style fix
83bc2e9 [Joseph K. Bradley] Updated VectorSlicer
98c6939 [Xusen Yin] fix style error
ecbf2d3 [Xusen Yin] change interfaces and params
f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895
e4781f2 [Xusen Yin] fix commit error
fd154d7 [Xusen Yin] add test suite of vector slicer
17171f8 [Xusen Yin] fix slicer
9ab9747 [Xusen Yin] add vector slicer
aa5a0bf [Xusen Yin] add vector slicer
Authored by Xusen Yin on 2015-08-05 17:07:55 -07:00; committed by Xiangrui Meng.
parent 9c878923db
commit a018b85716
5 changed files with 327 additions and 0 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -0,0 +1,170 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
 * :: Experimental ::
 * This class takes a feature vector and outputs a new feature vector with a subarray of the
 * original features.
 *
 * The subset of features can be specified with either indices ([[setIndices()]])
 * or names ([[setNames()]]). At least one feature must be selected. Duplicate features
 * are not allowed, so there can be no overlap between selected indices and names.
 *
 * The output vector will order features with the selected indices first (in the order given),
 * followed by the selected names (in the order given).
 */
@Experimental
final class VectorSlicer(override val uid: String)
  extends Transformer with HasInputCol with HasOutputCol {

  def this() = this(Identifiable.randomUID("vectorSlicer"))

  /**
   * An array of indices to select features from a vector column.
   * There can be no overlap with [[names]].
   * @group param
   */
  val indices = new IntArrayParam(this, "indices",
    "An array of indices to select features from a vector column." +
      " There can be no overlap with names.", VectorSlicer.validIndices)

  setDefault(indices -> Array.empty[Int])

  /** @group getParam */
  def getIndices: Array[Int] = $(indices)

  /** @group setParam */
  def setIndices(value: Array[Int]): this.type = set(indices, value)

  /**
   * An array of feature names to select features from a vector column.
   * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
   * There can be no overlap with [[indices]].
   * @group param
   */
  val names = new StringArrayParam(this, "names",
    "An array of feature names to select features from a vector column." +
      " There can be no overlap with indices.", VectorSlicer.validNames)

  setDefault(names -> Array.empty[String])

  /** @group getParam */
  def getNames: Array[String] = $(names)

  /** @group setParam */
  def setNames(value: Array[String]): this.type = set(names, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def validateParams(): Unit = {
    require($(indices).length > 0 || $(names).length > 0,
      s"VectorSlicer requires that at least one feature be selected.")
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // Validity checks
    transformSchema(dataset.schema)
    val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
    inputAttr.numAttributes.foreach { numFeatures =>
      val maxIndex = $(indices).max
      require(maxIndex < numFeatures,
        s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
    }

    // Prepare output attributes
    val inds = getSelectedFeatureIndices(dataset.schema)
    val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
      inds.map(index => attrs(index))
    }
    val outputAttr = selectedAttrs match {
      case Some(attrs) => new AttributeGroup($(outputCol), attrs)
      case None => new AttributeGroup($(outputCol), inds.length)
    }

    // Select features
    val slicer = udf { vec: Vector =>
      vec match {
        case features: DenseVector => Vectors.dense(inds.map(features.apply))
        case features: SparseVector => features.slice(inds)
      }
    }
    dataset.withColumn($(outputCol),
      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
  }

  /** Get the feature indices in order: indices, names */
  private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
    val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
    val indFeatures = $(indices)
    val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
    lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
      s" sets of features, but they overlap." +
      s" indices: ${indFeatures.mkString("[", ",", "]")}." +
      s" names: " +
      nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
    require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
    indFeatures ++ nameFeatures
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)

    if (schema.fieldNames.contains($(outputCol))) {
      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
    }
    val numFeaturesSelected = $(indices).length + $(names).length
    val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
    val outputFields = schema.fields :+ outputAttr.toStructField()
    StructType(outputFields)
  }

  override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
}

private[feature] object VectorSlicer {

  /** Return true if given feature indices are valid */
  def validIndices(indices: Array[Int]): Boolean = {
    if (indices.isEmpty) {
      true
    } else {
      indices.length == indices.distinct.length && indices.forall(_ >= 0)
    }
  }

  /** Return true if given feature names are valid */
  def validNames(names: Array[String]): Boolean = {
    names.forall(_.nonEmpty) && names.length == names.distinct.length
  }
}
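
To make the ordering and disjointness contracts above concrete (a sketch; the f0..f4 attribute names and the slicer value are hypothetical):

  // Input row [0.1, 0.2, 0.3, 0.4, 0.5] whose ML attributes are named f0..f4:
  slicer.setIndices(Array(4)).setNames(Array("f0"))
  // -> output [0.5, 0.1]: index-selected features first, then name-selected ones.
  slicer.setIndices(Array(0)).setNames(Array("f0"))
  // -> transform() throws: "f0" resolves to index 0, which overlaps the selected indices.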

mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.util
import scala.collection.immutable.HashMap

import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.StructField
@@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
    }
  }

  /**
   * Takes a Vector column and a list of feature names, and returns the corresponding list of
   * feature indices in the column, in order.
   * @param col  Vector column which must have feature names specified via attributes
   * @param names  List of feature names
   */
  def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
    require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
      + s" to be Vector type, but it was type ${col.dataType} instead.")
    val inputAttr = AttributeGroup.fromStructField(col)
    names.map { name =>
      require(inputAttr.hasAttr(name),
        s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
      inputAttr.getAttr(name).index.get
    }
  }
}
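
A sketch of the name-to-index resolution (the f0..f2 names are hypothetical, and the helper is private[spark], so this only compiles inside Spark):

  import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}

  val attrs = Array("f0", "f1", "f2").map(NumericAttribute.defaultAttr.withName)
  val field = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]]).toStructField()
  // Indices come back in the order the names were given:
  MetadataUtils.getFeatureIndicesFromNames(field, Array("f2", "f0"))  // Array(2, 0)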

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -766,6 +766,30 @@ class SparseVector(
      maxIdx
    }
  }

  /**
   * Create a slice of this vector based on the given indices.
   * @param selectedIndices Unsorted list of indices into the vector.
   *                        This does NOT do bound checking.
   * @return New SparseVector with values in the order specified by the given indices.
   *
   * NOTE: The API needs to be discussed before making this public.
   *       Also, if we have a version assuming indices are sorted, we should optimize it.
   */
  private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
    var currentIdx = 0
    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
      // Binary-search each requested index in the (sorted) stored indices array.
      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
      val i_v = if (iIdx >= 0) {
        // Stored (explicit) entry: keep its value at the new position currentIdx.
        Iterator((currentIdx, this.values(iIdx)))
      } else {
        // Implicit zero: contributes nothing to the sparse representation.
        Iterator()
      }
      currentIdx += 1
      i_v
    }.unzip
    new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
  }
}
object SparseVector {

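For intuition, a worked trace of the flatMap above (it mirrors the new VectorsSuite test at the end of this commit):

  val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
  v.slice(Array(2, 0, 3, 4))
  // origIdx = 2: binarySearch hits stored position 1 -> emit (0, 2.2)
  // origIdx = 0: not stored (implicit zero)          -> emit nothing
  // origIdx = 3: not stored                          -> emit nothing
  // origIdx = 4: binarySearch hits stored position 2 -> emit (3, 4.4)
  // result: SparseVector(4, Array(0, 3), Array(2.2, 4.4))
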
mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -0,0 +1,109 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
val slicer = new VectorSlicer
ParamsSuite.checkParams(slicer)
assert(slicer.getIndices.length === 0)
assert(slicer.getNames.length === 0)
withClue("VectorSlicer should not have any features selected by default") {
intercept[IllegalArgumentException] {
slicer.validateParams()
}
}
}
test("feature validity checks") {
import VectorSlicer._
assert(validIndices(Array(0, 1, 8, 2)))
assert(validIndices(Array.empty[Int]))
assert(!validIndices(Array(-1)))
assert(!validIndices(Array(1, 2, 1)))
assert(validNames(Array("a", "b")))
assert(validNames(Array.empty[String]))
assert(!validNames(Array("", "b")))
assert(!validNames(Array("a", "b", "a")))
}
test("Test vector slicer") {
val sqlContext = new SQLContext(sc)
val data = Array(
Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
Vectors.sparse(5, Seq())
)
// Expected after selecting indices 1, 4
val expected = Array(
Vectors.sparse(2, Seq((0, 2.3))),
Vectors.dense(2.3, 1.0),
Vectors.dense(0.0, 0.0),
Vectors.dense(-1.1, 3.3),
Vectors.sparse(2, Seq())
)
val defaultAttr = NumericAttribute.defaultAttr
val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
val df = sqlContext.createDataFrame(rdd,
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
def validateResults(df: DataFrame): Unit = {
df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 === vec2)
}
val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
assert(a === b)
}
}
vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
validateResults(vectorSlicer.transform(df))
vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
validateResults(vectorSlicer.transform(df))
vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
validateResults(vectorSlicer.transform(df))
}
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
    val sv1c = sv1.compressed.asInstanceOf[DenseVector]
    assert(sv1 === sv1c)
  }

  test("SparseVector.slice") {
    val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
    assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
    assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
    assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
  }
}