[SPARK-31925][ML] Summary.totalIterations greater than maxIters

### What changes were proposed in this pull request?
In LogisticRegression and LinearRegression, if set maxIter=n, the model.summary.totalIterations returns  n+1 if the training procedure does not drop out. This is because we use ```objectiveHistory.length``` as totalIterations, but ```objectiveHistory``` contains init sate, thus ```objectiveHistory.length``` is 1 larger than number of training iterations.

### Why are the changes needed?
correctness

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
add new tests and also modify existing tests

Closes #28786 from huaxingao/summary_iter.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
Huaxin Gao 2020-06-15 08:49:03 -05:00 committed by Sean Owen
parent 9d95f1b010
commit f83cb3cbb3
7 changed files with 43 additions and 12 deletions

View file

@ -594,7 +594,7 @@ class LogisticRegression @Since("1.2.0") (
Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
}
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty)
return createModel(dataset, numClasses, coefMatrix, interceptVec, Array(0.0))
}
if (!$(fitIntercept) && isConstantLabel) {
@ -1545,13 +1545,19 @@ sealed trait LogisticRegressionSummary extends Serializable {
*/
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
/** objective function (scaled loss + regularization) at each iteration. */
/**
* objective function (scaled loss + regularization) at each iteration.
* It contains one more element, the initial state, than number of iterations.
*/
@Since("1.5.0")
def objectiveHistory: Array[Double]
/** Number of training iterations. */
@Since("1.5.0")
def totalIterations: Int = objectiveHistory.length
def totalIterations: Int = {
assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 1.")
objectiveHistory.length - 1
}
}

View file

@ -433,7 +433,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
Vectors.dense(Array.fill(dim)(1.0))
}
val (parameters, objectiveHistory) = if ($(blockSize) == 1) {
val (parameters, objectiveHistory) = if ($(blockSize) == 1) {
trainOnRows(instances, yMean, yStd, featuresMean, featuresStd,
initialValues, regularization, optimizer)
} else {
@ -939,8 +939,10 @@ class LinearRegressionTrainingSummary private[regression] (
* @see `LinearRegression.solver`
*/
@Since("1.5.0")
val totalIterations = objectiveHistory.length
val totalIterations = {
assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 1.")
objectiveHistory.length - 1
}
}
/**

View file

@ -143,6 +143,6 @@ public class JavaLogisticRegressionSuite extends SharedSparkSession {
LogisticRegressionModel model = lr.fit(dataset);
LogisticRegressionTrainingSummary summary = model.summary();
Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length - 1);
}
}

View file

@ -266,6 +266,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
assert(blorModel.summary.totalIterations == 1)
assert(blorModel.binarySummary.totalIterations == 1)
val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
@ -279,6 +281,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
mlorModel.summary.asBinary
}
}
assert(mlorModel.summary.totalIterations == 1)
val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
@ -2570,7 +2573,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
rows.map(_.getDouble(0)).toArray === binaryExpected
}
}
assert(model2.summary.totalIterations === 1)
assert(model2.summary.totalIterations === 0)
val lr3 = new LogisticRegression().setFamily("multinomial")
val model3 = lr3.fit(smallMultinomialDataset)
@ -2585,7 +2588,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
rows.map(_.getDouble(0)).toArray === multinomialExpected
}
}
assert(model4.summary.totalIterations === 1)
assert(model4.summary.totalIterations === 0)
}
test("binary logistic regression with all labels the same") {
@ -2605,6 +2608,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
assert(allZeroInterceptModel.summary.totalIterations === 0)
assert(allZeroInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
val allOneInterceptModel = lrIntercept
.setLabelCol("oneLabel")
@ -2612,6 +2616,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
assert(allOneInterceptModel.summary.totalIterations === 0)
assert(allOneInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
// fitIntercept=false
val lrNoIntercept = new LogisticRegression()
@ -2647,6 +2652,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(pred === 4.0)
}
assert(model.summary.totalIterations === 0)
assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
// force the model to be trained with only one class
val constantZeroData = Seq(
@ -2660,7 +2666,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(prob === Vectors.dense(Array(1.0)))
assert(pred === 0.0)
}
assert(modelZeroLabel.summary.totalIterations > 0)
assert(modelZeroLabel.summary.totalIterations === 0)
// ensure that the correct value is predicted when numClasses passed through metadata
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
@ -2675,6 +2681,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(pred === 4.0)
}
require(modelWithMetadata.summary.totalIterations === 0)
assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
}
test("compressed storage for constant label") {

View file

@ -761,6 +761,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
.fit(datasetWithWeightConstantLabel)
if (fitIntercept) {
assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
assert(model1.summary.totalIterations === 0)
}
val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
@ -768,6 +769,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
.setSolver("l-bfgs")
.fit(datasetWithWeightZeroLabel)
assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
assert(model2.summary.totalIterations === 0)
}
}
}
@ -940,6 +942,19 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
}
}
test("linear regression training summary totalIterations") {
Seq(1, 5, 10, 20).foreach { maxIter =>
val trainer = new LinearRegression().setSolver("l-bfgs").setMaxIter(maxIter)
val model = trainer.fit(datasetWithDenseFeature)
assert(model.summary.totalIterations <= maxIter)
}
Seq("auto", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)
val model = trainer.fit(datasetWithDenseFeature)
assert(model.summary.totalIterations === 0)
}
}
test("linear regression with weighted samples") {
val sqlContext = spark.sqlContext
import sqlContext.implicits._

View file

@ -1119,7 +1119,8 @@ class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
def objectiveHistory(self):
"""
Objective function (scaled loss + regularization) at each
iteration.
iteration. It contains one more element, the initial state,
than number of iterations.
"""
return self._call_java("objectiveHistory")

View file

@ -42,7 +42,7 @@ class TrainingSummaryTest(SparkSessionTestCase):
self.assertTrue(model.hasSummary)
s = model.summary
# test that api is callable and returns expected types
self.assertGreater(s.totalIterations, 0)
self.assertEqual(s.totalIterations, 0)
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.predictionCol, "prediction")
self.assertEqual(s.labelCol, "label")