[SPARK-14712][ML] LogisticRegressionModel.toString should summarize model

## What changes were proposed in this pull request?

[SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712)
spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should do the same in spark.ml and override repr in pyspark.

## How was this patch tested?

LogisticRegressionSuite.scala
Python doctest in pyspark.ml.classification.py

Author: bravo-zhang <mzhang1230@gmail.com>

Closes #18826 from bravo-zhang/spark-14712.
This commit is contained in:
bravo-zhang 2018-06-28 12:40:39 -07:00 committed by Holden Karau
parent 5b05966488
commit 524827f062
4 changed files with 19 additions and 0 deletions

View file

@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] (
*/
@Since("1.6.0")
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
override def toString: String = {
s"LogisticRegressionModel: " +
s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
}
}

View file

@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(model.getFamily === family)
}
}
test("toString") {
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
assert(model.toString === expected)
}
}
object LogisticRegressionSuite {

View file

@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
True
>>> blorModel.intercept == model2.intercept
True
>>> model2
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
.. versionadded:: 1.3.0
"""
@ -562,6 +564,9 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)
def __repr__(self):
return self._call_java("toString")
class LogisticRegressionSummary(JavaWrapper):
"""

View file

@ -258,6 +258,9 @@ class LogisticRegressionModel(LinearClassificationModel):
model.setThreshold(threshold)
return model
def __repr__(self):
return self._call_java("toString")
class LogisticRegressionWithSGD(object):
"""