[SPARK-1434] [MLLIB] change labelParser from anonymous function to trait
This is a patch to address @mateiz 's comment in https://github.com/apache/spark/pull/245 MLUtils#loadLibSVMData uses an anonymous function for the label parser. Java users won't like it. So I make a trait for LabelParser and provide two implementations: binary and multiclass. Author: Xiangrui Meng <meng@databricks.com> Closes #345 from mengxr/label-parser and squashes the following commits: ac44409 [Xiangrui Meng] use singleton objects for label parsers 3b1a7c6 [Xiangrui Meng] add tests for label parsers c2e571c [Xiangrui Meng] rename LabelParser.apply to LabelParser.parse use extends for singleton 11c94e0 [Xiangrui Meng] add return types 7f8eb36 [Xiangrui Meng] change labelParser from annoymous function to trait
This commit is contained in:
parent
ce8ec54561
commit
b9e0c937df
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.mllib.util
|
||||||
|
|
||||||
|
/** Trait for label parsers. */
|
||||||
|
trait LabelParser extends Serializable {
|
||||||
|
/** Parses a string label into a double label. */
|
||||||
|
def parse(labelString: String): Double
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5,
|
||||||
|
* or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling.
|
||||||
|
*/
|
||||||
|
object BinaryLabelParser extends LabelParser {
|
||||||
|
/** Gets the default instance of BinaryLabelParser. */
|
||||||
|
def getInstance(): LabelParser = this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses the input label into positive (1.0) if the value is greater than 0.5,
|
||||||
|
* or negative (0.0) otherwise.
|
||||||
|
*/
|
||||||
|
override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Label parser for multiclass labels, which converts the input label to double.
|
||||||
|
*/
|
||||||
|
object MulticlassLabelParser extends LabelParser {
|
||||||
|
/** Gets the default instance of MulticlassLabelParser. */
|
||||||
|
def getInstance(): LabelParser = this
|
||||||
|
|
||||||
|
override def parse(labelString: String): Double = labelString.toDouble
|
||||||
|
}
|
|
@ -38,17 +38,6 @@ object MLUtils {
|
||||||
eps
|
eps
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Multiclass label parser, which parses a string into double.
|
|
||||||
*/
|
|
||||||
val multiclassLabelParser: String => Double = _.toDouble
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
|
|
||||||
* or 0.0 (negative) otherwise.
|
|
||||||
*/
|
|
||||||
val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
|
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
|
||||||
* The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
|
* The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
|
||||||
|
@ -69,7 +58,7 @@ object MLUtils {
|
||||||
def loadLibSVMData(
|
def loadLibSVMData(
|
||||||
sc: SparkContext,
|
sc: SparkContext,
|
||||||
path: String,
|
path: String,
|
||||||
labelParser: String => Double,
|
labelParser: LabelParser,
|
||||||
numFeatures: Int,
|
numFeatures: Int,
|
||||||
minSplits: Int): RDD[LabeledPoint] = {
|
minSplits: Int): RDD[LabeledPoint] = {
|
||||||
val parsed = sc.textFile(path, minSplits)
|
val parsed = sc.textFile(path, minSplits)
|
||||||
|
@ -89,7 +78,7 @@ object MLUtils {
|
||||||
}.reduce(math.max)
|
}.reduce(math.max)
|
||||||
}
|
}
|
||||||
parsed.map { items =>
|
parsed.map { items =>
|
||||||
val label = labelParser(items.head)
|
val label = labelParser.parse(items.head)
|
||||||
val (indices, values) = items.tail.map { item =>
|
val (indices, values) = items.tail.map { item =>
|
||||||
val indexAndValue = item.split(':')
|
val indexAndValue = item.split(':')
|
||||||
val index = indexAndValue(0).toInt - 1
|
val index = indexAndValue(0).toInt - 1
|
||||||
|
@ -107,14 +96,7 @@ object MLUtils {
|
||||||
* with number of features determined automatically and the default number of partitions.
|
* with number of features determined automatically and the default number of partitions.
|
||||||
*/
|
*/
|
||||||
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
|
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
|
||||||
loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)
|
loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits)
|
||||||
|
|
||||||
/**
|
|
||||||
* Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
|
|
||||||
* with number of features specified explicitly and the default number of partitions.
|
|
||||||
*/
|
|
||||||
def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
|
|
||||||
loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
|
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
|
||||||
|
@ -124,7 +106,7 @@ object MLUtils {
|
||||||
def loadLibSVMData(
|
def loadLibSVMData(
|
||||||
sc: SparkContext,
|
sc: SparkContext,
|
||||||
path: String,
|
path: String,
|
||||||
labelParser: String => Double): RDD[LabeledPoint] =
|
labelParser: LabelParser): RDD[LabeledPoint] =
|
||||||
loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
|
loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,7 +117,7 @@ object MLUtils {
|
||||||
def loadLibSVMData(
|
def loadLibSVMData(
|
||||||
sc: SparkContext,
|
sc: SparkContext,
|
||||||
path: String,
|
path: String,
|
||||||
labelParser: String => Double,
|
labelParser: LabelParser,
|
||||||
numFeatures: Int): RDD[LabeledPoint] =
|
numFeatures: Int): RDD[LabeledPoint] =
|
||||||
loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
|
loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.mllib.util
|
||||||
|
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
class LabelParsersSuite extends FunSuite {
|
||||||
|
test("binary label parser") {
|
||||||
|
for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) {
|
||||||
|
assert(parser.parse("+1") === 1.0)
|
||||||
|
assert(parser.parse("1") === 1.0)
|
||||||
|
assert(parser.parse("0") === 0.0)
|
||||||
|
assert(parser.parse("-1") === 0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("multiclass label parser") {
|
||||||
|
for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) {
|
||||||
|
assert(parser.parse("0") == 0.0)
|
||||||
|
assert(parser.parse("+1") === 1.0)
|
||||||
|
assert(parser.parse("1") === 1.0)
|
||||||
|
assert(parser.parse("2") === 2.0)
|
||||||
|
assert(parser.parse("3") === 3.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
|
||||||
Files.write(lines, file, Charsets.US_ASCII)
|
Files.write(lines, file, Charsets.US_ASCII)
|
||||||
val path = tempDir.toURI.toString
|
val path = tempDir.toURI.toString
|
||||||
|
|
||||||
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
|
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
|
||||||
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
|
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
|
||||||
|
|
||||||
for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
|
for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
|
||||||
|
@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
|
||||||
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
|
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
|
||||||
}
|
}
|
||||||
|
|
||||||
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
|
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
|
||||||
assert(multiclassPoints.length === 3)
|
assert(multiclassPoints.length === 3)
|
||||||
assert(multiclassPoints(0).label === 1.0)
|
assert(multiclassPoints(0).label === 1.0)
|
||||||
assert(multiclassPoints(1).label === -1.0)
|
assert(multiclassPoints(1).label === -1.0)
|
||||||
|
|
Loading…
Reference in a new issue