[SPARK-14712][ML] LogisticRegressionModel.toString should summarize model
## What changes were proposed in this pull request? [SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712) spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should do the same in spark.ml and override repr in pyspark. ## How was this patch tested? LogisticRegressionSuite.scala Python doctest in pyspark.ml.classification.py Author: bravo-zhang <mzhang1230@gmail.com> Closes #18826 from bravo-zhang/spark-14712.
This commit is contained in:
parent
5b05966488
commit
524827f062
|
@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] (
|
|||
*/
|
||||
@Since("1.6.0")
|
||||
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
|
||||
|
||||
override def toString: String = {
|
||||
s"LogisticRegressionModel: " +
|
||||
s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
assert(model.getFamily === family)
|
||||
}
|
||||
}
|
||||
|
||||
test("toString") {
|
||||
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
|
||||
val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
|
||||
assert(model.toString === expected)
|
||||
}
|
||||
}
|
||||
|
||||
object LogisticRegressionSuite {
|
||||
|
|
|
@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
True
|
||||
>>> blorModel.intercept == model2.intercept
|
||||
True
|
||||
>>> model2
|
||||
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
"""
|
||||
|
@ -562,6 +564,9 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
|
|||
java_blr_summary = self._call_java("evaluate", dataset)
|
||||
return BinaryLogisticRegressionSummary(java_blr_summary)
|
||||
|
||||
def __repr__(self):
|
||||
return self._call_java("toString")
|
||||
|
||||
|
||||
class LogisticRegressionSummary(JavaWrapper):
|
||||
"""
|
||||
|
|
|
@ -258,6 +258,9 @@ class LogisticRegressionModel(LinearClassificationModel):
|
|||
model.setThreshold(threshold)
|
||||
return model
|
||||
|
||||
def __repr__(self):
|
||||
return self._call_java("toString")
|
||||
|
||||
|
||||
class LogisticRegressionWithSGD(object):
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue