[SPARK-23042][ML] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier

## What changes were proposed in this pull request?

In MultilayerPerceptronClassifier, labels are currently encoded with an RDD operation. I think we should use ML's OneHotEncoderEstimator/Model to do the encoding instead.
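
As a rough illustration of the encoding being discussed, the following plain-Scala sketch shows what a one-hot label vector with no dropped last category looks like, and how a prediction is decoded with argmax. The object and method names (`OneHotSketch`, `encode`, `decode`) are hypothetical and exist only for this example; this is not Spark's implementation and it does not use the Spark API.

```scala
// Minimal sketch, assuming labels are zero-based integer class indices.
// All names here are hypothetical and are not part of Spark.
object OneHotSketch {

  /** Encodes a label as a dense one-hot array of length numClasses. */
  def encode(label: Int, numClasses: Int): Array[Double] = {
    require(label >= 0 && label < numClasses,
      s"label $label is outside [0, $numClasses)")
    val out = Array.fill(numClasses)(0.0)
    out(label) = 1.0
    out
  }

  /** Decodes an output vector to a label: the index of its maximal element. */
  def decode(output: Array[Double]): Double = {
    require(output.nonEmpty, "output must not be empty")
    output.indexOf(output.max).toDouble
  }

  def main(args: Array[String]): Unit = {
    // Label 1 out of 3 classes keeps all three positions (no last category dropped).
    println(encode(1, 3).mkString("[", ", ", "]"))
    // The position of the largest score is taken as the predicted label.
    println(decode(Array(0.1, 0.7, 0.2)))
  }
}
```

Keeping every category position, rather than dropping the last one, means the encoded vector length equals the number of classes, which is what a softmax output layer of the same size expects.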

## How was this patch tested?

Existing tests.

Closes #20232 from viirya/SPARK-23042.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
Liang-Chi Hsieh 2018-08-17 18:40:29 +00:00 committed by DB Tsai
parent 162326c0ee
commit 8b0e94d896
5 changed files with 26 additions and 39 deletions

@@ -382,10 +382,10 @@ test_that("spark.mlp", {
   trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
   traindf <- as.DataFrame(data[trainidxs, ])
   testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
-  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3))
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2))
   predictions <- predict(model, testdf)
   expect_error(collect(predictions))
-  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip")
+  model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 2), handleInvalid = "skip")
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "list")

@@ -654,7 +654,7 @@ We use Titanic data set to show how to use `spark.mlp` in classification.
 t <- as.data.frame(Titanic)
 training <- createDataFrame(t)
 # fit a Multilayer Perceptron Classification Model
-model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9))
+model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 2), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 5, 5, 9, 9))
``` ```
 To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell.

@@ -667,3 +667,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
 ## Upgrading to SparkR 2.3.1 and above
 - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4)` would result in `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1.
+## Upgrading to SparkR 2.4.0
+- Previously, `spark.mlp` did not check the validity of the size of the last layer. For example, if the training data has only two labels, a `layers` param like `c(1, 3)` did not cause an error before; now it does.
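
The behaviour change described in the 2.4.0 note can be pictured with a small plain-Scala sketch. It assumes the check amounts to requiring that the size of the last layer equals the number of distinct labels; the commit itself does not spell out the exact condition, and `LayerSizeCheckSketch` and `validate` are hypothetical names that are not part of Spark or SparkR.

```scala
// Minimal sketch, assuming the check is "last layer size == number of labels".
// The names are hypothetical; the real validation happens inside Spark ML.
object LayerSizeCheckSketch {

  /** Returns Left(message) if the output layer cannot match the labels, else Right(()). */
  def validate(layers: Seq[Int], numLabels: Int): Either[String, Unit] = {
    if (layers.isEmpty) {
      Left("layers must not be empty")
    } else if (layers.last != numLabels) {
      Left(s"output layer size ${layers.last} does not match number of labels $numLabels")
    } else {
      Right(())
    }
  }

  def main(args: Array[String]): Unit = {
    // Two labels with an output layer of size 3: rejected under the assumed check.
    println(validate(Seq(1, 3), numLabels = 2))
    // Two labels with an output layer of size 2: accepted.
    println(validate(Seq(1, 2), numLabels = 2))
  }
}
```

This mirrors why the test and vignette in this commit change `layers` from `c(1, 3)` and `c(2, 3)` to `c(1, 2)` and `c(2, 2)` for two-label data.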

@@ -23,13 +23,13 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.OneHotEncoderModel
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
 /** Params for Multilayer Perceptron. */
 private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams
@@ -103,36 +103,6 @@ private[classification] trait MultilayerPerceptronParams extends ProbabilisticCl
     solver -> LBFGS, stepSize -> 0.03)
 }
-/** Label to vector converter. */
-private object LabelConverter {
-  // TODO: Use OneHotEncoder instead
-  /**
-   * Encodes a label as a vector.
-   * Returns a vector of given length with zeroes at all positions
-   * and value 1.0 at the position that corresponds to the label.
-   *
-   * @param labeledPoint labeled point
-   * @param labelCount total number of labels
-   * @return pair of features and vector encoding of a label
-   */
-  def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = {
-    val output = Array.fill(labelCount)(0.0)
-    output(labeledPoint.label.toInt) = 1.0
-    (labeledPoint.features, Vectors.dense(output))
-  }
-  /**
-   * Converts a vector to a label.
-   * Returns the position of the maximal element of a vector.
-   *
-   * @param output label encoded with a vector
-   * @return label
-   */
-  def decodeLabel(output: Vector): Double = {
-    output.argmax.toDouble
-  }
-}
 /**
  * Classifier trainer based on the Multilayer Perceptron.
  * Each layer has sigmoid activation function, output layer has softmax.
@@ -243,8 +213,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
     instr.logNumClasses(labels)
     instr.logNumFeatures(myLayers.head)
-    val lpData = extractLabeledPoints(dataset)
-    val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
+    // One-hot encoding for labels using OneHotEncoderModel.
+    // As we already know the length of encoding, we skip fitting and directly create
+    // the model.
+    val encodedLabelCol = "_encoded" + $(labelCol)
+    val encodeModel = new OneHotEncoderModel(uid, Array(labels))
+      .setInputCols(Array($(labelCol)))
+      .setOutputCols(Array(encodedLabelCol))
+      .setDropLast(false)
+    val encodedDataset = encodeModel.transform(dataset)
+    val data = encodedDataset.select($(featuresCol), encodedLabelCol).rdd.map {
+      case Row(features: Vector, encodedLabel: Vector) => (features, encodedLabel)
+    }
     val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, softmaxOnTop = true)
     val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
     if (isDefined(initialWeights)) {
@@ -323,7 +303,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
    * This internal method is used to implement `transform()` and output [[predictionCol]].
    */
   override def predict(features: Vector): Double = {
-    LabelConverter.decodeLabel(mlpModel.predict(features))
+    mlpModel.predict(features).argmax.toDouble
   }
   @Since("1.5.0")

@@ -97,7 +97,10 @@ object MimaExcludes {
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
-    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol")
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+    // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter")
   )
   // Exclude rules for 2.3.x