[SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide
## What changes were proposed in this pull request?

User guide and examples are updated to reflect the multiclass logistic regression summary, which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139). I did not make a separate summary example, but added the summary code to the multiclass example that already existed; I don't see the need for a separate example for the summary.

## How was this patch tested?

Docs and examples only. Ran all examples locally using spark-submit.

Author: sethah <shendrickson@cloudera.com>

Closes #20332 from sethah/multiclass_summary_example.
parent 8b983243e4
commit 5056877e8b
docs/ml-classification-regression.md

```diff
@@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark
 The `spark.ml` implementation of logistic regression also supports
 extracting a summary of the model over the training set. Note that the
 predictions and metrics which are stored as `DataFrame` in
-`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
+`LogisticRegressionSummary` are annotated `@transient` and hence
 only available on the driver.
 
 <div class="codetabs">
```
```diff
@@ -97,10 +97,9 @@ only available on the driver.
 [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
 provides a summary for a
 [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
-This will likely change when multiclass classification is supported.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
 
 Continuing the earlier example:
 
```
```diff
@@ -111,10 +110,9 @@ Continuing the earlier example:
 [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
 provides a summary for a
 [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
-Currently, only binary classification is supported and the
-summary must be explicitly cast to
-[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
-Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. The binary summary can be accessed via the
+`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
 
 Continuing the earlier example:
 
```
```diff
@@ -125,7 +123,8 @@ Continuing the earlier example:
 [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary)
 provides a summary for a
 [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel).
-Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future.
+In the case of binary classification, certain additional metrics are
+available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary).
 
 Continuing the earlier example:
 
```
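The ROC curve and `areaUnderROC` mentioned in the guide text above are standard binary-classification metrics. As an illustration only (plain Python, no Spark — this is not Spark's implementation), a ROC curve can be traced by sweeping the decision threshold down through the sorted scores, and the area computed by the trapezoidal rule; distinct scores are assumed for simplicity:

```python
def roc_points(scores, labels):
    """Return (FPR, TPR) points obtained by sweeping the decision
    threshold down through the scores, highest first."""
    pos = sum(labels)
    neg = len(labels) - pos
    pts = [(0.0, 0.0)]
    tp = fp = 0
    for score, label in sorted(zip(scores, labels), reverse=True):
        if label == 1:
            tp += 1   # this example is now predicted positive and is positive
        else:
            fp += 1   # predicted positive but actually negative
        pts.append((fp / neg, tp / pos))
    return pts

def area_under_roc(pts):
    """Trapezoidal rule over consecutive ROC points."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area
```

A perfectly separating scorer yields an area of 1.0; an interleaved one yields less.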
```diff
@@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin
 **Examples**
 
 The following example shows how to train a multiclass logistic regression
-model with elastic net regularization.
+model with elastic net regularization, as well as extract the multiclass
+training summary for evaluating the model.
 
 <div class="codetabs">
 
```
JavaLogisticRegressionSummaryExample.java

```diff
@@ -18,10 +18,9 @@
 package org.apache.spark.examples.ml;
 
 // $example on$
-import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
+import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
```
```diff
@@ -50,7 +49,7 @@ public class JavaLogisticRegressionSummaryExample {
     // $example on$
     // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
     // example
-    LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+    BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary();
 
     // Obtain the loss per iteration.
     double[] objectiveHistory = trainingSummary.objectiveHistory();
```
```diff
@@ -58,21 +57,15 @@ public class JavaLogisticRegressionSummaryExample {
       System.out.println(lossPerIteration);
     }
 
-    // Obtain the metrics useful to judge performance on test data.
-    // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
-    // classification problem.
-    BinaryLogisticRegressionSummary binarySummary =
-      (BinaryLogisticRegressionSummary) trainingSummary;
-
     // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-    Dataset<Row> roc = binarySummary.roc();
+    Dataset<Row> roc = trainingSummary.roc();
     roc.show();
     roc.select("FPR").show();
-    System.out.println(binarySummary.areaUnderROC());
+    System.out.println(trainingSummary.areaUnderROC());
 
     // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
     // this selected threshold.
-    Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
+    Dataset<Row> fMeasure = trainingSummary.fMeasureByThreshold();
     double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
     double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
       .select("threshold").head().getDouble(0);
```
JavaMulticlassLogisticRegressionWithElasticNetExample.java

```diff
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
 // $example on$
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
```
```diff
@@ -48,6 +49,67 @@ public class JavaMulticlassLogisticRegressionWithElasticNetExample {
     // Print the coefficients and intercept for multinomial logistic regression
     System.out.println("Coefficients: \n"
       + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector());
+    LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
+
+    // Obtain the loss per iteration.
+    double[] objectiveHistory = trainingSummary.objectiveHistory();
+    for (double lossPerIteration : objectiveHistory) {
+      System.out.println(lossPerIteration);
+    }
+
+    // for multiclass, we can inspect metrics on a per-label basis
+    System.out.println("False positive rate by label:");
+    int i = 0;
+    double[] fprLabel = trainingSummary.falsePositiveRateByLabel();
+    for (double fpr : fprLabel) {
+      System.out.println("label " + i + ": " + fpr);
+      i++;
+    }
+
+    System.out.println("True positive rate by label:");
+    i = 0;
+    double[] tprLabel = trainingSummary.truePositiveRateByLabel();
+    for (double tpr : tprLabel) {
+      System.out.println("label " + i + ": " + tpr);
+      i++;
+    }
+
+    System.out.println("Precision by label:");
+    i = 0;
+    double[] precLabel = trainingSummary.precisionByLabel();
+    for (double prec : precLabel) {
+      System.out.println("label " + i + ": " + prec);
+      i++;
+    }
+
+    System.out.println("Recall by label:");
+    i = 0;
+    double[] recLabel = trainingSummary.recallByLabel();
+    for (double rec : recLabel) {
+      System.out.println("label " + i + ": " + rec);
+      i++;
+    }
+
+    System.out.println("F-measure by label:");
+    i = 0;
+    double[] fLabel = trainingSummary.fMeasureByLabel();
+    for (double f : fLabel) {
+      System.out.println("label " + i + ": " + f);
+      i++;
+    }
+
+    double accuracy = trainingSummary.accuracy();
+    double falsePositiveRate = trainingSummary.weightedFalsePositiveRate();
+    double truePositiveRate = trainingSummary.weightedTruePositiveRate();
+    double fMeasure = trainingSummary.weightedFMeasure();
+    double precision = trainingSummary.weightedPrecision();
+    double recall = trainingSummary.weightedRecall();
+    System.out.println("Accuracy: " + accuracy);
+    System.out.println("FPR: " + falsePositiveRate);
+    System.out.println("TPR: " + truePositiveRate);
+    System.out.println("F-measure: " + fMeasure);
+    System.out.println("Precision: " + precision);
+    System.out.println("Recall: " + recall);
     // $example off$
 
     spark.stop();
```
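The per-label metrics printed in the hunk above (`falsePositiveRateByLabel`, `truePositiveRateByLabel`, `precisionByLabel`, `recallByLabel`, `fMeasureByLabel`) are ordinary one-vs-rest confusion-matrix statistics. The following plain-Python sketch shows what each one computes; it is illustrative only, not Spark's code:

```python
def per_label_metrics(labels, preds, num_labels):
    """One-vs-rest TP/FP/FN/TN per label, and the derived rates.
    Returns a dict of metric name -> list indexed by label."""
    n = len(labels)
    out = {"tpr": [], "fpr": [], "precision": [], "recall": [], "f1": []}
    for c in range(num_labels):
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        tn = n - tp - fp - fn
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0   # recall == TPR
        out["tpr"].append(rec)
        out["fpr"].append(fp / (fp + tn) if fp + tn else 0.0)
        out["precision"].append(prec)
        out["recall"].append(rec)
        out["f1"].append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return out
```

For example, with true labels `[0, 0, 1, 1, 2, 2]` and predictions `[0, 1, 1, 1, 2, 0]`, label 1 gets precision 2/3 (two of three predicted-1 examples are truly 1) and recall 1.0.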
Python example (multiclass logistic regression with elastic net)

```diff
@@ -43,6 +43,44 @@ if __name__ == "__main__":
     # Print the coefficients and intercept for multinomial logistic regression
     print("Coefficients: \n" + str(lrModel.coefficientMatrix))
     print("Intercept: " + str(lrModel.interceptVector))
+
+    trainingSummary = lrModel.summary
+
+    # Obtain the objective per iteration
+    objectiveHistory = trainingSummary.objectiveHistory
+    print("objectiveHistory:")
+    for objective in objectiveHistory:
+        print(objective)
+
+    # for multiclass, we can inspect metrics on a per-label basis
+    print("False positive rate by label:")
+    for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):
+        print("label %d: %s" % (i, rate))
+
+    print("True positive rate by label:")
+    for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):
+        print("label %d: %s" % (i, rate))
+
+    print("Precision by label:")
+    for i, prec in enumerate(trainingSummary.precisionByLabel):
+        print("label %d: %s" % (i, prec))
+
+    print("Recall by label:")
+    for i, rec in enumerate(trainingSummary.recallByLabel):
+        print("label %d: %s" % (i, rec))
+
+    print("F-measure by label:")
+    for i, f in enumerate(trainingSummary.fMeasureByLabel()):
+        print("label %d: %s" % (i, f))
+
+    accuracy = trainingSummary.accuracy
+    falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+    truePositiveRate = trainingSummary.weightedTruePositiveRate
+    fMeasure = trainingSummary.weightedFMeasure()
+    precision = trainingSummary.weightedPrecision
+    recall = trainingSummary.weightedRecall
+    print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s"
+          % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))
     # $example off$
 
     spark.stop()
```
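The weighted aggregates used in these examples (`weightedPrecision`, `weightedRecall`, `weightedFMeasure`, `weightedFalsePositiveRate`, `weightedTruePositiveRate`) average the per-label values weighted by each label's share of the true labels, while `accuracy` is simply the fraction of correct predictions. A minimal sketch of that weighting, illustrative only and not Spark's implementation:

```python
from collections import Counter

def weighted_metric(per_label_values, labels):
    """Average per-label metric values, weighting each label by its
    frequency among the true labels."""
    counts = Counter(labels)
    n = len(labels)
    return sum(per_label_values[c] * counts[c] / n for c in counts)

def accuracy(labels, preds):
    """Fraction of examples whose prediction matches the true label."""
    return sum(y == p for y, p in zip(labels, preds)) / len(labels)
```

With balanced labels the weighted metric reduces to the plain mean of the per-label values.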
LogisticRegressionSummaryExample.scala

```diff
@@ -19,7 +19,7 @@
 package org.apache.spark.examples.ml
 
 // $example on$
-import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
+import org.apache.spark.ml.classification.LogisticRegression
 // $example off$
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.max
```
```diff
@@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample {
     // $example on$
     // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
    // example
-    val trainingSummary = lrModel.summary
+    val trainingSummary = lrModel.binarySummary
 
     // Obtain the objective per iteration.
     val objectiveHistory = trainingSummary.objectiveHistory
     println("objectiveHistory:")
     objectiveHistory.foreach(loss => println(loss))
 
-    // Obtain the metrics useful to judge performance on test data.
-    // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
-    // binary classification problem.
-    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
-
     // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
-    val roc = binarySummary.roc
+    val roc = trainingSummary.roc
     roc.show()
-    println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
+    println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
 
     // Set the model threshold to maximize F-Measure
-    val fMeasure = binarySummary.fMeasureByThreshold
+    val fMeasure = trainingSummary.fMeasureByThreshold
     val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
     val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
       .select("threshold").head().getDouble(0)
```
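The threshold-selection step in the Scala example above (find the row of `fMeasureByThreshold` with the maximum F-measure, then read off its threshold) is just an arg-max over (threshold, F-measure) pairs. A plain-Python sketch with hypothetical data, not Spark's API:

```python
def best_threshold(f_by_threshold):
    """f_by_threshold: iterable of (threshold, f_measure) pairs.
    Return the threshold whose F-measure is maximal."""
    return max(f_by_threshold, key=lambda tf: tf[1])[0]

# Hypothetical (threshold, F-measure) pairs, as the summary's
# fMeasureByThreshold DataFrame would provide.
pairs = [(0.2, 0.61), (0.5, 0.78), (0.8, 0.70)]
```

Calling `best_threshold(pairs)` on the data above selects 0.5, the same value the Scala snippet would then pass to `setThreshold` on the model.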
MulticlassLogisticRegressionWithElasticNetExample.scala

```diff
@@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample {
     // Print the coefficients and intercept for multinomial logistic regression
     println(s"Coefficients: \n${lrModel.coefficientMatrix}")
     println(s"Intercepts: \n${lrModel.interceptVector}")
+
+    val trainingSummary = lrModel.summary
+
+    // Obtain the objective per iteration
+    val objectiveHistory = trainingSummary.objectiveHistory
+    println("objectiveHistory:")
+    objectiveHistory.foreach(println)
+
+    // for multiclass, we can inspect metrics on a per-label basis
+    println("False positive rate by label:")
+    trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+      println(s"label $label: $rate")
+    }
+
+    println("True positive rate by label:")
+    trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
+      println(s"label $label: $rate")
+    }
+
+    println("Precision by label:")
+    trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
+      println(s"label $label: $prec")
+    }
+
+    println("Recall by label:")
+    trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
+      println(s"label $label: $rec")
+    }
+
+    println("F-measure by label:")
+    trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
+      println(s"label $label: $f")
+    }
+
+    val accuracy = trainingSummary.accuracy
+    val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
+    val truePositiveRate = trainingSummary.weightedTruePositiveRate
+    val fMeasure = trainingSummary.weightedFMeasure
+    val precision = trainingSummary.weightedPrecision
+    val recall = trainingSummary.weightedRecall
+    println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
+      s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
     // $example off$
 
     spark.stop()
```