[SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide

## What changes were proposed in this pull request?

User guide and examples are updated to reflect multiclass logistic regression summary which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139).

I did not make a separate summary example, but added the summary code to the multiclass example that already existed. I don't see the need for a separate example for the summary.

## How was this patch tested?

Docs and examples only. Ran all examples locally using spark-submit.

Author: sethah <shendrickson@cloudera.com>

Closes #20332 from sethah/multiclass_summary_example.
This commit is contained in:
sethah 2018-01-30 09:02:16 +02:00 committed by Nick Pentreath
parent 8b983243e4
commit 5056877e8b
6 changed files with 164 additions and 33 deletions

View file

@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `DataFrame` in
`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
`LogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.
<div class="codetabs">
@ -97,10 +97,9 @@ only available on the driver.
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
This will likely change when multiclass classification is supported.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. The binary summary can be accessed via the
`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
Continuing the earlier example:
@ -111,10 +110,9 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
Support for multiclass model summaries will be added in the future.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. The binary summary can be accessed via the
`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
Continuing the earlier example:
@ -125,7 +123,8 @@ Continuing the earlier example:
[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
provides a summary for a
[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
In the case of binary classification, certain additional metrics are
available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).
Continuing the earlier example:
@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
**Examples**
The following example shows how to train a multiclass logistic regression
model with elastic net regularization.
model with elastic net regularization, as well as extract the multiclass
training summary for evaluating the model.
<div class="codetabs">

View file

@ -18,10 +18,9 @@
package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
@ -50,7 +49,7 @@ public class JavaLogisticRegressionSummaryExample {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();
// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
@ -58,21 +57,15 @@ public class JavaLogisticRegressionSummaryExample {
System.out.println(lossPerIteration);
}
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
// classification problem.
BinaryLogisticRegressionSummary binarySummary =
(BinaryLogisticRegressionSummary) trainingSummary;
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
Dataset<Row> roc = binarySummary.roc();
Dataset<Row> roc = trainingSummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());
System.out.println(trainingSummary.areaUnderROC());
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble(0);

View file

@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
@ -48,6 +49,67 @@ public class JavaMulticlassLogisticRegressionWithElasticNetExample {
// Print the coefficients and intercept for multinomial logistic regression
System.out.println("Coefficients: \n"
+ lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
}
// for multiclass, we can inspect metrics on a per-label basis
System.out.println("False positive rate by label:");
int i = 0;
double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
for (double fpr : fprLabel) {
System.out.println("label " + i + ": " + fpr);
i++;
}
System.out.println("True positive rate by label:");
i = 0;
double[] tprLabel = trainingSummary.truePositiveRateByLabel();
for (double tpr : tprLabel) {
System.out.println("label " + i + ": " + tpr);
i++;
}
System.out.println("Precision by label:");
i = 0;
double[] precLabel = trainingSummary.precisionByLabel();
for (double prec : precLabel) {
System.out.println("label " + i + ": " + prec);
i++;
}
System.out.println("Recall by label:");
i = 0;
double[] recLabel = trainingSummary.recallByLabel();
for (double rec : recLabel) {
System.out.println("label " + i + ": " + rec);
i++;
}
System.out.println("F-measure by label:");
i = 0;
double[] fLabel = trainingSummary.fMeasureByLabel();
for (double f : fLabel) {
System.out.println("label " + i + ": " + f);
i++;
}
double accuracy = trainingSummary.accuracy();
double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
double truePositiveRate = trainingSummary.weightedTruePositiveRate();
double fMeasure = trainingSummary.weightedFMeasure();
double precision = trainingSummary.weightedPrecision();
double recall = trainingSummary.weightedRecall();
System.out.println("Accuracy: " + accuracy);
System.out.println("FPR: " + falsePositiveRate);
System.out.println("TPR: " + truePositiveRate);
System.out.println("F-measure: " + fMeasure);
System.out.println("Precision: " + precision);
System.out.println("Recall: " + recall);
// $example off$
spark.stop();

View file

@ -43,6 +43,44 @@ if __name__ == "__main__":
# Print the coefficients and intercept for multinomial logistic regression
print("Coefficients: \n" + str(lrModel.coefficientMatrix))
print("Intercept: " + str(lrModel.interceptVector))
trainingSummary = lrModel.summary
# Obtain the objective per iteration
objectiveHistory = trainingSummary.objectiveHistory
print("objectiveHistory:")
for objective in objectiveHistory:
print(objective)
# for multiclass, we can inspect metrics on a per-label basis
print("False positive rate by label:")
for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
print("label %d: %s" % (i, rate))
print("True positive rate by label:")
for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
print("label %d: %s" % (i, rate))
print("Precision by label:")
for i, prec in enumerate(trainingSummary.precisionByLabel):
print("label %d: %s" % (i, prec))
print("Recall by label:")
for i, rec in enumerate(trainingSummary.recallByLabel):
print("label %d: %s" % (i, rec))
print("F-measure by label:")
for i, f in enumerate(trainingSummary.fMeasureByLabel()):
print("label %d: %s" % (i, f))
accuracy = trainingSummary.accuracy
falsePositiveRate = trainingSummary.weightedFalsePositiveRate
truePositiveRate = trainingSummary.weightedTruePositiveRate
fMeasure = trainingSummary.weightedFMeasure()
precision = trainingSummary.weightedPrecision
recall = trainingSummary.weightedRecall
print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
% (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
# $example off$
spark.stop()

View file

@ -19,7 +19,7 @@
package org.apache.spark.examples.ml
// $example on$
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.ml.classification.LogisticRegression
// $example off$
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample {
// $example on$
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
// example
val trainingSummary = lrModel.summary
val trainingSummary = lrModel.binarySummary
// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(loss => println(loss))
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = binarySummary.roc
val roc = trainingSummary.roc
roc.show()
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
// Set the model threshold to maximize F-Measure
val fMeasure = binarySummary.fMeasureByThreshold
val fMeasure = trainingSummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
.select("threshold").head().getDouble(0)

View file

@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample {
// Print the coefficients and intercept for multinomial logistic regression
println(s"Coefficients: \n${lrModel.coefficientMatrix}")
println(s"Intercepts: \n${lrModel.interceptVector}")
val trainingSummary = lrModel.summary
// Obtain the objective per iteration
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(println)
// for multiclass, we can inspect metrics on a per-label basis
println("False positive rate by label:")
trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
println(s"label $label: $rate")
}
println("True positive rate by label:")
trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
println(s"label $label: $rate")
}
println("Precision by label:")
trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
println(s"label $label: $prec")
}
println("Recall by label:")
trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
println(s"label $label: $rec")
}
println("F-measure by label:")
trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
println(s"label $label: $f")
}
val accuracy = trainingSummary.accuracy
val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
val truePositiveRate = trainingSummary.weightedTruePositiveRate
val fMeasure = trainingSummary.weightedFMeasure
val precision = trainingSummary.weightedPrecision
val recall = trainingSummary.weightedRecall
println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
// $example off$
spark.stop()