[SPARK-5895] [ML] Add VectorSlicer - updated
Add VectorSlicer transformer to spark.ml, with features specified as either indices or names. Transfers feature attributes for selected features. Updated version of [https://github.com/apache/spark/pull/5731] CC: yinxusen This updates your PR. You'll still be the primary author of this PR. CC: mengxr Author: Xusen Yin <yinxusen@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits: b16e86e [Joseph K. Bradley] fixed scala style 71c65d2 [Joseph K. Bradley] fix import order 86e9739 [Joseph K. Bradley] cleanups per code review 9d8d6f1 [Joseph K. Bradley] style fix 83bc2e9 [Joseph K. Bradley] Updated VectorSlicer 98c6939 [Xusen Yin] fix style error ecbf2d3 [Xusen Yin] change interfaces and params f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895 e4781f2 [Xusen Yin] fix commit error fd154d7 [Xusen Yin] add test suite of vector slicer 17171f8 [Xusen Yin] fix slicer 9ab9747 [Xusen Yin] add vector slicer aa5a0bf [Xusen Yin] add vector slicer
This commit is contained in:
parent
9c878923db
commit
a018b85716
|
@ -0,0 +1,170 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.ml.Transformer
|
||||
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
|
||||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||
import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
|
||||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
|
||||
import org.apache.spark.mllib.linalg._
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* This class takes a feature vector and outputs a new feature vector with a subarray of the
|
||||
* original features.
|
||||
*
|
||||
* The subset of features can be specified with either indices ([[setIndices()]])
|
||||
* or names ([[setNames()]]). At least one feature must be selected. Duplicate features
|
||||
* are not allowed, so there can be no overlap between selected indices and names.
|
||||
*
|
||||
* The output vector will order features with the selected indices first (in the order given),
|
||||
* followed by the selected names (in the order given).
|
||||
*/
|
||||
@Experimental
|
||||
final class VectorSlicer(override val uid: String)
|
||||
extends Transformer with HasInputCol with HasOutputCol {
|
||||
|
||||
def this() = this(Identifiable.randomUID("vectorSlicer"))
|
||||
|
||||
/**
|
||||
* An array of indices to select features from a vector column.
|
||||
* There can be no overlap with [[names]].
|
||||
* @group param
|
||||
*/
|
||||
val indices = new IntArrayParam(this, "indices",
|
||||
"An array of indices to select features from a vector column." +
|
||||
" There can be no overlap with names.", VectorSlicer.validIndices)
|
||||
|
||||
setDefault(indices -> Array.empty[Int])
|
||||
|
||||
/** @group getParam */
|
||||
def getIndices: Array[Int] = $(indices)
|
||||
|
||||
/** @group setParam */
|
||||
def setIndices(value: Array[Int]): this.type = set(indices, value)
|
||||
|
||||
/**
|
||||
* An array of feature names to select features from a vector column.
|
||||
* These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
|
||||
* There can be no overlap with [[indices]].
|
||||
* @group param
|
||||
*/
|
||||
val names = new StringArrayParam(this, "names",
|
||||
"An array of feature names to select features from a vector column." +
|
||||
" There can be no overlap with indices.", VectorSlicer.validNames)
|
||||
|
||||
setDefault(names -> Array.empty[String])
|
||||
|
||||
/** @group getParam */
|
||||
def getNames: Array[String] = $(names)
|
||||
|
||||
/** @group setParam */
|
||||
def setNames(value: Array[String]): this.type = set(names, value)
|
||||
|
||||
/** @group setParam */
|
||||
def setInputCol(value: String): this.type = set(inputCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||
|
||||
override def validateParams(): Unit = {
|
||||
require($(indices).length > 0 || $(names).length > 0,
|
||||
s"VectorSlicer requires that at least one feature be selected.")
|
||||
}
|
||||
|
||||
override def transform(dataset: DataFrame): DataFrame = {
|
||||
// Validity checks
|
||||
transformSchema(dataset.schema)
|
||||
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
|
||||
inputAttr.numAttributes.foreach { numFeatures =>
|
||||
val maxIndex = $(indices).max
|
||||
require(maxIndex < numFeatures,
|
||||
s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
|
||||
}
|
||||
|
||||
// Prepare output attributes
|
||||
val inds = getSelectedFeatureIndices(dataset.schema)
|
||||
val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
|
||||
inds.map(index => attrs(index))
|
||||
}
|
||||
val outputAttr = selectedAttrs match {
|
||||
case Some(attrs) => new AttributeGroup($(outputCol), attrs)
|
||||
case None => new AttributeGroup($(outputCol), inds.length)
|
||||
}
|
||||
|
||||
// Select features
|
||||
val slicer = udf { vec: Vector =>
|
||||
vec match {
|
||||
case features: DenseVector => Vectors.dense(inds.map(features.apply))
|
||||
case features: SparseVector => features.slice(inds)
|
||||
}
|
||||
}
|
||||
dataset.withColumn($(outputCol),
|
||||
slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
|
||||
}
|
||||
|
||||
/** Get the feature indices in order: indices, names */
|
||||
private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
|
||||
val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
|
||||
val indFeatures = $(indices)
|
||||
val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
|
||||
lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
|
||||
s" sets of features, but they overlap." +
|
||||
s" indices: ${indFeatures.mkString("[", ",", "]")}." +
|
||||
s" names: " +
|
||||
nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
|
||||
require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
|
||||
indFeatures ++ nameFeatures
|
||||
}
|
||||
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
|
||||
|
||||
if (schema.fieldNames.contains($(outputCol))) {
|
||||
throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
|
||||
}
|
||||
val numFeaturesSelected = $(indices).length + $(names).length
|
||||
val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
|
||||
val outputFields = schema.fields :+ outputAttr.toStructField()
|
||||
StructType(outputFields)
|
||||
}
|
||||
|
||||
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
|
||||
}
|
||||
|
||||
private[feature] object VectorSlicer {
|
||||
|
||||
/** Return true if given feature indices are valid */
|
||||
def validIndices(indices: Array[Int]): Boolean = {
|
||||
if (indices.isEmpty) {
|
||||
true
|
||||
} else {
|
||||
indices.length == indices.distinct.length && indices.forall(_ >= 0)
|
||||
}
|
||||
}
|
||||
|
||||
/** Return true if given feature names are valid */
|
||||
def validNames(names: Array[String]): Boolean = {
|
||||
names.forall(_.nonEmpty) && names.length == names.distinct.length
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ package org.apache.spark.ml.util
|
|||
import scala.collection.immutable.HashMap
|
||||
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.mllib.linalg.VectorUDT
|
||||
import org.apache.spark.sql.types.StructField
|
||||
|
||||
|
||||
|
@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes a Vector column and a list of feature names, and returns the corresponding list of
|
||||
* feature indices in the column, in order.
|
||||
* @param col Vector column which must have feature names specified via attributes
|
||||
* @param names List of feature names
|
||||
*/
|
||||
def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
|
||||
require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
|
||||
+ s" to be Vector type, but it was type ${col.dataType} instead.")
|
||||
val inputAttr = AttributeGroup.fromStructField(col)
|
||||
names.map { name =>
|
||||
require(inputAttr.hasAttr(name),
|
||||
s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
|
||||
inputAttr.getAttr(name).index.get
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -766,6 +766,30 @@ class SparseVector(
|
|||
maxIdx
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a slice of this vector based on the given indices.
|
||||
* @param selectedIndices Unsorted list of indices into the vector.
|
||||
* This does NOT do bound checking.
|
||||
* @return New SparseVector with values in the order specified by the given indices.
|
||||
*
|
||||
* NOTE: The API needs to be discussed before making this public.
|
||||
* Also, if we have a version assuming indices are sorted, we should optimize it.
|
||||
*/
|
||||
private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
|
||||
var currentIdx = 0
|
||||
val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
|
||||
val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
|
||||
val i_v = if (iIdx >= 0) {
|
||||
Iterator((currentIdx, this.values(iIdx)))
|
||||
} else {
|
||||
Iterator()
|
||||
}
|
||||
currentIdx += 1
|
||||
i_v
|
||||
}.unzip
|
||||
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
object SparseVector {
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
||||
|
||||
class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
test("params") {
|
||||
val slicer = new VectorSlicer
|
||||
ParamsSuite.checkParams(slicer)
|
||||
assert(slicer.getIndices.length === 0)
|
||||
assert(slicer.getNames.length === 0)
|
||||
withClue("VectorSlicer should not have any features selected by default") {
|
||||
intercept[IllegalArgumentException] {
|
||||
slicer.validateParams()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("feature validity checks") {
|
||||
import VectorSlicer._
|
||||
assert(validIndices(Array(0, 1, 8, 2)))
|
||||
assert(validIndices(Array.empty[Int]))
|
||||
assert(!validIndices(Array(-1)))
|
||||
assert(!validIndices(Array(1, 2, 1)))
|
||||
|
||||
assert(validNames(Array("a", "b")))
|
||||
assert(validNames(Array.empty[String]))
|
||||
assert(!validNames(Array("", "b")))
|
||||
assert(!validNames(Array("a", "b", "a")))
|
||||
}
|
||||
|
||||
test("Test vector slicer") {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
|
||||
val data = Array(
|
||||
Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
|
||||
Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
|
||||
Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
|
||||
Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
|
||||
Vectors.sparse(5, Seq())
|
||||
)
|
||||
|
||||
// Expected after selecting indices 1, 4
|
||||
val expected = Array(
|
||||
Vectors.sparse(2, Seq((0, 2.3))),
|
||||
Vectors.dense(2.3, 1.0),
|
||||
Vectors.dense(0.0, 0.0),
|
||||
Vectors.dense(-1.1, 3.3),
|
||||
Vectors.sparse(2, Seq())
|
||||
)
|
||||
|
||||
val defaultAttr = NumericAttribute.defaultAttr
|
||||
val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
|
||||
val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
|
||||
|
||||
val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
|
||||
val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
|
||||
|
||||
val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
|
||||
val df = sqlContext.createDataFrame(rdd,
|
||||
StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
|
||||
|
||||
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
|
||||
|
||||
def validateResults(df: DataFrame): Unit = {
|
||||
df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
|
||||
assert(vec1 === vec2)
|
||||
}
|
||||
val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
|
||||
val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
|
||||
assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
|
||||
resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
|
||||
assert(a === b)
|
||||
}
|
||||
}
|
||||
|
||||
vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
|
||||
validateResults(vectorSlicer.transform(df))
|
||||
|
||||
vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
|
||||
validateResults(vectorSlicer.transform(df))
|
||||
|
||||
vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
|
||||
validateResults(vectorSlicer.transform(df))
|
||||
}
|
||||
}
|
|
@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
|
|||
val sv1c = sv1.compressed.asInstanceOf[DenseVector]
|
||||
assert(sv1 === sv1c)
|
||||
}
|
||||
|
||||
test("SparseVector.slice") {
|
||||
val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
|
||||
assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
|
||||
assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
|
||||
assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue