[SPARK-31925][ML] Summary.totalIterations greater than maxIters
### What changes were proposed in this pull request? In LogisticRegression and LinearRegression, if maxIter is set to n, model.summary.totalIterations returns n+1 when the training procedure does not terminate early. This is because we use ```objectiveHistory.length``` as totalIterations, but ```objectiveHistory``` also contains the initial state, so ```objectiveHistory.length``` is one larger than the number of training iterations. ### Why are the changes needed? Correctness. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added new tests and modified existing tests. Closes #28786 from huaxingao/summary_iter. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
9d95f1b010
commit
f83cb3cbb3
|
@ -594,7 +594,7 @@ class LogisticRegression @Since("1.2.0") (
|
|||
Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
|
||||
}
|
||||
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
|
||||
return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty)
|
||||
return createModel(dataset, numClasses, coefMatrix, interceptVec, Array(0.0))
|
||||
}
|
||||
|
||||
if (!$(fitIntercept) && isConstantLabel) {
|
||||
|
@ -1545,13 +1545,19 @@ sealed trait LogisticRegressionSummary extends Serializable {
|
|||
*/
|
||||
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
|
||||
|
||||
/** objective function (scaled loss + regularization) at each iteration. */
|
||||
/**
|
||||
* objective function (scaled loss + regularization) at each iteration.
|
||||
* It contains one more element, the initial state, than number of iterations.
|
||||
*/
|
||||
@Since("1.5.0")
|
||||
def objectiveHistory: Array[Double]
|
||||
|
||||
/** Number of training iterations. */
|
||||
@Since("1.5.0")
|
||||
def totalIterations: Int = objectiveHistory.length
|
||||
def totalIterations: Int = {
|
||||
assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 1.")
|
||||
objectiveHistory.length - 1
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -433,7 +433,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
Vectors.dense(Array.fill(dim)(1.0))
|
||||
}
|
||||
|
||||
val (parameters, objectiveHistory) = if ($(blockSize) == 1) {
|
||||
val (parameters, objectiveHistory) = if ($(blockSize) == 1) {
|
||||
trainOnRows(instances, yMean, yStd, featuresMean, featuresStd,
|
||||
initialValues, regularization, optimizer)
|
||||
} else {
|
||||
|
@ -939,8 +939,10 @@ class LinearRegressionTrainingSummary private[regression] (
|
|||
* @see `LinearRegression.solver`
|
||||
*/
|
||||
@Since("1.5.0")
|
||||
val totalIterations = objectiveHistory.length
|
||||
|
||||
val totalIterations = {
|
||||
assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 1.")
|
||||
objectiveHistory.length - 1
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -143,6 +143,6 @@ public class JavaLogisticRegressionSuite extends SharedSparkSession {
|
|||
LogisticRegressionModel model = lr.fit(dataset);
|
||||
|
||||
LogisticRegressionTrainingSummary summary = model.summary();
|
||||
Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
|
||||
Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length - 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -266,6 +266,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
|
||||
assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
|
||||
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
|
||||
assert(blorModel.summary.totalIterations == 1)
|
||||
assert(blorModel.binarySummary.totalIterations == 1)
|
||||
|
||||
val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
|
||||
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
|
||||
|
@ -279,6 +281,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
mlorModel.summary.asBinary
|
||||
}
|
||||
}
|
||||
assert(mlorModel.summary.totalIterations == 1)
|
||||
|
||||
val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
|
||||
assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
|
||||
|
@ -2570,7 +2573,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
rows.map(_.getDouble(0)).toArray === binaryExpected
|
||||
}
|
||||
}
|
||||
assert(model2.summary.totalIterations === 1)
|
||||
assert(model2.summary.totalIterations === 0)
|
||||
|
||||
val lr3 = new LogisticRegression().setFamily("multinomial")
|
||||
val model3 = lr3.fit(smallMultinomialDataset)
|
||||
|
@ -2585,7 +2588,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
rows.map(_.getDouble(0)).toArray === multinomialExpected
|
||||
}
|
||||
}
|
||||
assert(model4.summary.totalIterations === 1)
|
||||
assert(model4.summary.totalIterations === 0)
|
||||
}
|
||||
|
||||
test("binary logistic regression with all labels the same") {
|
||||
|
@ -2605,6 +2608,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
|
||||
assert(allZeroInterceptModel.intercept === Double.NegativeInfinity)
|
||||
assert(allZeroInterceptModel.summary.totalIterations === 0)
|
||||
assert(allZeroInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
|
||||
val allOneInterceptModel = lrIntercept
|
||||
.setLabelCol("oneLabel")
|
||||
|
@ -2612,6 +2616,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3)
|
||||
assert(allOneInterceptModel.intercept === Double.PositiveInfinity)
|
||||
assert(allOneInterceptModel.summary.totalIterations === 0)
|
||||
assert(allOneInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
|
||||
// fitIntercept=false
|
||||
val lrNoIntercept = new LogisticRegression()
|
||||
|
@ -2647,6 +2652,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(pred === 4.0)
|
||||
}
|
||||
assert(model.summary.totalIterations === 0)
|
||||
assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
|
||||
// force the model to be trained with only one class
|
||||
val constantZeroData = Seq(
|
||||
|
@ -2660,7 +2666,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(prob === Vectors.dense(Array(1.0)))
|
||||
assert(pred === 0.0)
|
||||
}
|
||||
assert(modelZeroLabel.summary.totalIterations > 0)
|
||||
assert(modelZeroLabel.summary.totalIterations === 0)
|
||||
|
||||
// ensure that the correct value is predicted when numClasses passed through metadata
|
||||
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
|
||||
|
@ -2675,6 +2681,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(pred === 4.0)
|
||||
}
|
||||
require(modelWithMetadata.summary.totalIterations === 0)
|
||||
assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
}
|
||||
|
||||
test("compressed storage for constant label") {
|
||||
|
|
|
@ -761,6 +761,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
|
|||
.fit(datasetWithWeightConstantLabel)
|
||||
if (fitIntercept) {
|
||||
assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
assert(model1.summary.totalIterations === 0)
|
||||
}
|
||||
val model2 = new LinearRegression()
|
||||
.setFitIntercept(fitIntercept)
|
||||
|
@ -768,6 +769,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
|
|||
.setSolver("l-bfgs")
|
||||
.fit(datasetWithWeightZeroLabel)
|
||||
assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
|
||||
assert(model2.summary.totalIterations === 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -940,6 +942,19 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
|
|||
}
|
||||
}
|
||||
|
||||
test("linear regression training summary totalIterations") {
|
||||
Seq(1, 5, 10, 20).foreach { maxIter =>
|
||||
val trainer = new LinearRegression().setSolver("l-bfgs").setMaxIter(maxIter)
|
||||
val model = trainer.fit(datasetWithDenseFeature)
|
||||
assert(model.summary.totalIterations <= maxIter)
|
||||
}
|
||||
Seq("auto", "normal").foreach { solver =>
|
||||
val trainer = new LinearRegression().setSolver(solver)
|
||||
val model = trainer.fit(datasetWithDenseFeature)
|
||||
assert(model.summary.totalIterations === 0)
|
||||
}
|
||||
}
|
||||
|
||||
test("linear regression with weighted samples") {
|
||||
val sqlContext = spark.sqlContext
|
||||
import sqlContext.implicits._
|
||||
|
|
|
@ -1119,7 +1119,8 @@ class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
|
|||
def objectiveHistory(self):
|
||||
"""
|
||||
Objective function (scaled loss + regularization) at each
|
||||
iteration.
|
||||
iteration. It contains one more element, the initial state,
|
||||
than number of iterations.
|
||||
"""
|
||||
return self._call_java("objectiveHistory")
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ class TrainingSummaryTest(SparkSessionTestCase):
|
|||
self.assertTrue(model.hasSummary)
|
||||
s = model.summary
|
||||
# test that api is callable and returns expected types
|
||||
self.assertGreater(s.totalIterations, 0)
|
||||
self.assertEqual(s.totalIterations, 0)
|
||||
self.assertTrue(isinstance(s.predictions, DataFrame))
|
||||
self.assertEqual(s.predictionCol, "prediction")
|
||||
self.assertEqual(s.labelCol, "label")
|
||||
|
|
Loading…
Reference in a new issue