diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d1350640c..1f5976c592 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -594,7 +594,7 @@ class LogisticRegression @Since("1.2.0") ( Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity) } if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() - return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty) + return createModel(dataset, numClasses, coefMatrix, interceptVec, Array(0.0)) } if (!$(fitIntercept) && isConstantLabel) { @@ -1545,13 +1545,19 @@ sealed trait LogisticRegressionSummary extends Serializable { */ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { - /** objective function (scaled loss + regularization) at each iteration. */ + /** + * objective function (scaled loss + regularization) at each iteration. + * It contains one more element, the initial state, than number of iterations. + */ @Since("1.5.0") def objectiveHistory: Array[Double] /** Number of training iterations. 
*/ @Since("1.5.0") - def totalIterations: Int = objectiveHistory.length + def totalIterations: Int = { + assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 0.") + objectiveHistory.length - 1 + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8b6ede3bb3..d9f09c0972 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -433,7 +433,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String Vectors.dense(Array.fill(dim)(1.0)) } - val (parameters, objectiveHistory) = if ($(blockSize) == 1) { + val (parameters, objectiveHistory) = if ($(blockSize) == 1) { trainOnRows(instances, yMean, yStd, featuresMean, featuresStd, initialValues, regularization, optimizer) } else { @@ -939,8 +939,10 @@ class LinearRegressionTrainingSummary private[regression] ( * @see `LinearRegression.solver` */ @Since("1.5.0") - val totalIterations = objectiveHistory.length - + val totalIterations = { + assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 0.") + objectiveHistory.length - 1 + } } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 49ac493394..7c63a8755b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -143,6 +143,6 @@ public class JavaLogisticRegressionSuite extends SharedSparkSession { LogisticRegressionModel model = lr.fit(dataset); LogisticRegressionTrainingSummary summary = model.summary(); - Assert.assertEquals(summary.totalIterations(), 
summary.objectiveHistory().length); + Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length - 1); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 25e9697d64..30c21d8b06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -266,6 +266,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary]) assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(blorModel.summary.totalIterations == 1) + assert(blorModel.binarySummary.totalIterations == 1) val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset) assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary]) @@ -279,6 +281,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { mlorModel.summary.asBinary } } + assert(mlorModel.summary.totalIterations == 1) val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset) assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) @@ -2570,7 +2573,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { rows.map(_.getDouble(0)).toArray === binaryExpected } } - assert(model2.summary.totalIterations === 1) + assert(model2.summary.totalIterations === 0) val lr3 = new LogisticRegression().setFamily("multinomial") val model3 = lr3.fit(smallMultinomialDataset) @@ -2585,7 +2588,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { rows.map(_.getDouble(0)).toArray === multinomialExpected } } - 
assert(model4.summary.totalIterations === 1) + assert(model4.summary.totalIterations === 0) } test("binary logistic regression with all labels the same") { @@ -2605,6 +2608,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) assert(allZeroInterceptModel.intercept === Double.NegativeInfinity) assert(allZeroInterceptModel.summary.totalIterations === 0) + assert(allZeroInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) val allOneInterceptModel = lrIntercept .setLabelCol("oneLabel") @@ -2612,6 +2616,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) assert(allOneInterceptModel.intercept === Double.PositiveInfinity) assert(allOneInterceptModel.summary.totalIterations === 0) + assert(allOneInterceptModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) // fitIntercept=false val lrNoIntercept = new LogisticRegression() @@ -2647,6 +2652,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(pred === 4.0) } assert(model.summary.totalIterations === 0) + assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) // force the model to be trained with only one class val constantZeroData = Seq( @@ -2660,7 +2666,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) } - assert(modelZeroLabel.summary.totalIterations > 0) + assert(modelZeroLabel.summary.totalIterations === 0) // ensure that the correct value is predicted when numClasses passed through metadata val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata() @@ -2675,6 +2681,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(pred === 4.0) } require(modelWithMetadata.summary.totalIterations === 0) + 
assert(model.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) } test("compressed storage for constant label") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c4a94ff2d6..fb70883bff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -761,6 +761,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe .fit(datasetWithWeightConstantLabel) if (fitIntercept) { assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + assert(model1.summary.totalIterations === 0) } val model2 = new LinearRegression() .setFitIntercept(fitIntercept) @@ -768,6 +769,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe .setSolver("l-bfgs") .fit(datasetWithWeightZeroLabel) assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + assert(model2.summary.totalIterations === 0) } } } @@ -940,6 +942,19 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe } } + test("linear regression training summary totalIterations") { + Seq(1, 5, 10, 20).foreach { maxIter => + val trainer = new LinearRegression().setSolver("l-bfgs").setMaxIter(maxIter) + val model = trainer.fit(datasetWithDenseFeature) + assert(model.summary.totalIterations <= maxIter) + } + Seq("auto", "normal").foreach { solver => + val trainer = new LinearRegression().setSolver(solver) + val model = trainer.fit(datasetWithDenseFeature) + assert(model.summary.totalIterations === 0) + } + } + test("linear regression with weighted samples") { val sqlContext = spark.sqlContext import sqlContext.implicits._ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 734c393db2..3f3699ce53 100644 --- a/python/pyspark/ml/classification.py +++ 
b/python/pyspark/ml/classification.py @@ -1119,7 +1119,8 @@ class LogisticRegressionTrainingSummary(LogisticRegressionSummary): def objectiveHistory(self): """ Objective function (scaled loss + regularization) at each - iteration. + iteration. It contains one more element, the initial state, + than number of iterations. """ return self._call_java("objectiveHistory") diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index b5054095d1..ac944d8397 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -42,7 +42,7 @@ class TrainingSummaryTest(SparkSessionTestCase): self.assertTrue(model.hasSummary) s = model.summary # test that api is callable and returns expected types - self.assertGreater(s.totalIterations, 0) + self.assertEqual(s.totalIterations, 0) self.assertTrue(isinstance(s.predictions, DataFrame)) self.assertEqual(s.predictionCol, "prediction") self.assertEqual(s.labelCol, "label")