[SPARK-17139][ML][FOLLOW-UP] Add convenient method asBinary for casting to BinaryLogisticRegressionSummary

## What changes were proposed in this pull request?

add an "asBinary" method to LogisticRegressionSummary for convenient casting to BinaryLogisticRegressionSummary.

## How was this patch tested?

Testcase updated.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19072 from WeichenXu123/mlor_summary_as_binary.
This commit is contained in:
WeichenXu 2017-08-31 16:22:40 -07:00 committed by Joseph K. Bradley
parent cba69aeb45
commit 96028e36b4
3 changed files with 18 additions and 0 deletions

View file

@ -1473,6 +1473,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Returns weighted averaged f1-measure. */
@Since("2.3.0")
def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
/**
* Convenient method for casting to binary logistic regression summary.
* This method will throws an Exception if the summary is not a binary summary.
*/
@Since("2.3.0")
def asBinary: BinaryLogisticRegressionSummary = this match {
case b: BinaryLogisticRegressionSummary => b
case _ =>
throw new RuntimeException("Cannot cast to a binary summary.")
}
}
/**

View file

@ -256,6 +256,7 @@ class LogisticRegressionSuite
val blorModel = lr.fit(smallBinaryDataset)
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
@ -265,6 +266,11 @@ class LogisticRegressionSuite
mlorModel.binarySummary
}
}
withClue("cannot cast summary to binary summary multiclass model") {
intercept[RuntimeException] {
mlorModel.summary.asBinary
}
}
val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])

View file

@ -62,6 +62,7 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=")
)