[SPARK-29801][ML] ML models unify toString method

### What changes were proposed in this pull request?
1. ML models should extend the toString method to expose basic information.
   Currently some algorithms (GBT/RF/LoR) have done this, while others have not yet.
2. Add `val numFeatures` in `BisectingKMeansModel`/`GaussianMixtureModel`/`KMeansModel`/`AFTSurvivalRegressionModel`/`IsotonicRegressionModel`.

### Why are the changes needed?
ML models should extend the toString method to expose basic information, so that users can inspect a fitted model at a glance.

### Does this PR introduce any user-facing change?
Yes — the `toString` output of many ML models and evaluators changes format (e.g. `uid=...` instead of `(uid=...)`).

### How was this patch tested?
existing testsuites

Closes #26439 from zhengruifeng/models_toString.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
zhengruifeng 2019-11-11 11:03:26 -08:00 committed by Dongjoon Hyun
parent cceb2d6f11
commit 76e5294bb6
63 changed files with 340 additions and 27 deletions

View file

@ -237,7 +237,8 @@ class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
s"DecisionTreeClassificationModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
s"numClasses=$numClasses, numFeatures=$numFeatures"
}
/**

View file

@ -340,7 +340,8 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def toString: String = {
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
s"GBTClassificationModel: uid = $uid, numTrees=$numTrees, numClasses=$numClasses, " +
s"numFeatures=$numFeatures"
}
/**

View file

@ -326,6 +326,10 @@ class LinearSVCModel private[classification] (
@Since("2.2.0")
override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
@Since("3.0.0")
override def toString: String = {
s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
}
}

View file

@ -1181,8 +1181,7 @@ class LogisticRegressionModel private[spark] (
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
override def toString: String = {
s"LogisticRegressionModel: " +
s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
}
}

View file

@ -323,6 +323,12 @@ class MultilayerPerceptronClassificationModel private[ml] (
override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features)
override def numClasses: Int = layers.last
@Since("3.0.0")
override def toString: String = {
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${layers.length}, " +
s"numClasses=$numClasses, numFeatures=$numFeatures"
}
}
@Since("2.0.0")

View file

@ -359,7 +359,8 @@ class NaiveBayesModel private[ml] (
@Since("1.5.0")
override def toString: String = {
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
s"NaiveBayesModel: uid=$uid, modelType=${$(modelType)}, numClasses=$numClasses, " +
s"numFeatures=$numFeatures"
}
@Since("1.6.0")

View file

@ -257,6 +257,12 @@ final class OneVsRestModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"OneVsRestModel: uid=$uid, classifier=${$(classifier)}, numClasses=$numClasses, " +
s"numFeatures=$numFeatures"
}
}
@Since("2.0.0")

View file

@ -260,7 +260,8 @@ class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
s"RandomForestClassificationModel: uid=$uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
s"numFeatures=$numFeatures"
}
/**

View file

@ -91,6 +91,9 @@ class BisectingKMeansModel private[ml] (
extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
with HasTrainingSummary[BisectingKMeansSummary] {
@Since("3.0.0")
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
@ -145,6 +148,12 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"BisectingKMeansModel: uid=$uid, k=${parentModel.k}, distanceMeasure=${$(distanceMeasure)}, " +
s"numFeatures=$numFeatures"
}
/**
* Gets summary of model on training set. An exception is
* thrown if `hasSummary` is false.

View file

@ -89,6 +89,9 @@ class GaussianMixtureModel private[ml] (
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
with HasTrainingSummary[GaussianMixtureSummary] {
@Since("3.0.0")
lazy val numFeatures: Int = gaussians.head.mean.size
/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@ -186,6 +189,11 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"GaussianMixtureModel: uid=$uid, k=${weights.length}, numFeatures=$numFeatures"
}
/**
* Gets summary of model on training set. An exception is
* thrown if `hasSummary` is false.

View file

@ -108,6 +108,9 @@ class KMeansModel private[ml] (
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
with HasTrainingSummary[KMeansSummary] {
@Since("3.0.0")
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
@ -153,6 +156,12 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: GeneralMLWriter = new GeneralMLWriter(this)
@Since("3.0.0")
override def toString: String = {
s"KMeansModel: uid=$uid, k=${parentModel.k}, distanceMeasure=${$(distanceMeasure)}, " +
s"numFeatures=$numFeatures"
}
/**
* Gets summary of model on training set. An exception is
* thrown if `hasSummary` is false.

View file

@ -620,6 +620,11 @@ class LocalLDAModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"LocalLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
}
}
@ -783,6 +788,11 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
@Since("3.0.0")
override def toString: String = {
s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
}
}

View file

@ -130,6 +130,12 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"BinaryClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
s"numBins=${$(numBins)}"
}
}
@Since("1.6.0")

View file

@ -120,6 +120,12 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
throw new IllegalArgumentException(s"No support for metric $mn, distance $dm")
}
}
@Since("3.0.0")
override def toString: String = {
s"ClusteringEvaluator: uid=$uid, metricName=${$(metricName)}, " +
s"distanceMeasure=${$(distanceMeasure)}"
}
}

View file

@ -184,6 +184,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"MulticlassClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
s"metricLabel=${$(metricLabel)}, beta=${$(beta)}, eps=${$(eps)}"
}
}
@Since("1.6.0")

View file

@ -121,6 +121,12 @@ class MultilabelClassificationEvaluator (override val uid: String)
}
override def copy(extra: ParamMap): MultilabelClassificationEvaluator = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"MultilabelClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
s"metricLabel=${$(metricLabel)}"
}
}

View file

@ -105,6 +105,11 @@ class RankingEvaluator (override val uid: String)
override def isLargerBetter: Boolean = true
override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"RankingEvaluator: uid=$uid, metricName=${$(metricName)}, k=${$(k)}"
}
}

View file

@ -124,6 +124,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.5.0")
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"RegressionEvaluator: uid=$uid, metricName=${$(metricName)}, " +
s"throughOrigin=${$(throughOrigin)}"
}
}
@Since("1.6.0")

View file

@ -204,6 +204,13 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.1")
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"Binarizer: uid=$uid" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -106,6 +106,11 @@ class BucketedRandomProjectionLSHModel private[ml](
override def write: MLWriter = {
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
}
@Since("3.0.0")
override def toString: String = {
s"BucketedRandomProjectionLSHModel: uid=$uid, numHashTables=${$(numHashTables)}"
}
}
/**

View file

@ -215,6 +215,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
@Since("3.0.0")
override def toString: String = {
s"Bucketizer: uid=$uid" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -316,6 +316,11 @@ final class ChiSqSelectorModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new ChiSqSelectorModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"ChiSqSelectorModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}"
}
}
@Since("1.6.0")

View file

@ -307,7 +307,7 @@ class CountVectorizerModel(
}
val dictBr = broadcastDict.get
val minTf = $(minTF)
val vectorizer = udf { (document: Seq[String]) =>
val vectorizer = udf { document: Seq[String] =>
val termCounts = new OpenHashMap[Int, Double]
var tokenCount = 0L
document.foreach { term =>
@ -344,6 +344,11 @@ class CountVectorizerModel(
@Since("1.6.0")
override def write: MLWriter = new CountVectorizerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"CountVectorizerModel: uid=$uid, vocabularySize=${vocabulary.length}"
}
}
@Since("1.6.0")

View file

@ -74,6 +74,11 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT
@Since("3.0.0")
override def toString: String = {
s"DCT: uid=$uid, inverse=$inverse"
}
}
@Since("1.6.0")

View file

@ -81,6 +81,12 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
}
override protected def outputDataType: DataType = new VectorUDT()
@Since("3.0.0")
override def toString: String = {
s"ElementwiseProduct: uid=$uid" +
get(scalingVec).map(v => s", vectorSize=${v.size}").getOrElse("")
}
}
@Since("2.0.0")

View file

@ -22,7 +22,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, StringArrayParam}
import org.apache.spark.ml.param.{ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCols, HasNumFeatures, HasOutputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
@ -199,6 +199,13 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}
@Since("3.0.0")
override def toString: String = {
s"FeatureHasher: uid=$uid, numFeatures=${$(numFeatures)}" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(categoricalCols).map(c => s", numCategoricalCols=${c.length}").getOrElse("")
}
}
@Since("2.3.0")

View file

@ -127,6 +127,11 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.1")
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}"
}
}
@Since("1.6.0")

View file

@ -175,9 +175,13 @@ class IDFModel private[ml] (
@Since("3.0.0")
def numDocs: Long = idfModel.numDocs
@Since("1.6.0")
override def write: MLWriter = new IDFModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"IDFModel: uid=$uid, numDocs=$numDocs"
}
}
@Since("1.6.0")

View file

@ -274,6 +274,13 @@ class ImputerModel private[ml] (
@Since("2.2.0")
override def write: MLWriter = new ImputerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"ImputerModel: uid=$uid, strategy=${$(strategy)}, missingValue=${$(missingValue)}" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
}
}

View file

@ -218,6 +218,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
@Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"Interaction: uid=$uid" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -140,6 +140,11 @@ class MaxAbsScalerModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new MaxAbsScalerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"MaxAbsScalerModel: uid=$uid, numFeatures=${maxAbs.size}"
}
}
@Since("2.0.0")

View file

@ -96,6 +96,11 @@ class MinHashLSHModel private[ml](
@Since("2.1.0")
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"MinHashLSHModel: uid=$uid, numHashTables=${$(numHashTables)}"
}
}
/**

View file

@ -226,6 +226,12 @@ class MinMaxScalerModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new MinMaxScalerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"MinMaxScalerModel: uid=$uid, numFeatures=${originalMin.size}, min=${$(min)}, " +
s"max=${$(max)}"
}
}
@Since("1.6.0")

View file

@ -70,6 +70,11 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
@Since("3.0.0")
override def toString: String = {
s"NGram: uid=$uid, n=${$(n)}"
}
}
@Since("1.6.0")

View file

@ -65,6 +65,11 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT()
@Since("3.0.0")
override def toString: String = {
s"Normalizer: uid=$uid, p=${$(p)}"
}
}
@Since("1.6.0")

View file

@ -376,6 +376,13 @@ class OneHotEncoderModel private[ml] (
@Since("3.0.0")
override def write: MLWriter = new OneHotEncoderModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"OneHotEncoderModel: uid=$uid, dropLast=${$(dropLast)}, handleInvalid=${$(handleInvalid)}" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
}
}
@Since("3.0.0")

View file

@ -179,6 +179,11 @@ class PCAModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new PCAModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"PCAModel: uid=$uid, k=${$(k)}"
}
}
@Since("1.6.0")

View file

@ -77,6 +77,11 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
@Since("1.4.1")
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"PolynomialExpansion: uid=$uid, degree=${$(degree)}"
}
}
/**

View file

@ -320,7 +320,10 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
@Since("2.0.0")
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
override def toString: String = {
s"RFormula: uid=$uid" +
get(formula).map(f => s", formula = $f").getOrElse("")
}
}
@Since("2.0.0")
@ -376,7 +379,9 @@ class RFormulaModel private[feature](
}
@Since("2.0.0")
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
override def toString: String = {
s"RFormulaModel: uid=$uid, resolvedFormula=$resolvedFormula"
}
private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label

View file

@ -251,6 +251,12 @@ class RobustScalerModel private[ml] (
}
override def write: MLWriter = new RobustScalerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"RobustScalerModel: uid=$uid, numFeatures=${median.size}, " +
s"withCentering=${$(withCentering)}, withScaling=${$(withScaling)}"
}
}
@Since("3.0.0")

View file

@ -90,6 +90,11 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@Since("1.6.0")
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"SQLTransformer: uid=$uid, statement=${$(statement)}"
}
}
@Since("1.6.0")

View file

@ -184,6 +184,12 @@ class StandardScalerModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new StandardScalerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"StandardScalerModel: uid=$uid, numFeatures=${mean.size}, withMean=${$(withMean)}, " +
s"withStd=${$(withStd)}"
}
}
@Since("1.6.0")

View file

@ -156,6 +156,12 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
@Since("1.5.0")
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"StopWordsRemover: uid=$uid, numStopWords=${$(stopWords).length}, locale=${$(locale)}, " +
s"caseSensitive=${$(caseSensitive)}"
}
}
@Since("1.6.0")

View file

@ -412,7 +412,7 @@ class StringIndexerModel (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
var (inputColNames, outputColNames) = getInOutCols()
val (inputColNames, outputColNames) = getInOutCols()
val outputColumns = new Array[Column](outputColNames.length)
// Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`.
@ -473,6 +473,14 @@ class StringIndexerModel (
@Since("1.6.0")
override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"StringIndexerModel: uid=$uid, handleInvalid=${$(handleInvalid)}" +
get(stringOrderType).map(t => s", stringOrderType=$t").getOrElse("") +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -175,6 +175,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.1")
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"VectorAssembler: uid=$uid, handleInvalid=${$(handleInvalid)}" +
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -428,7 +428,7 @@ class VectorIndexerModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
val transformUDF = udf { vector: Vector => transformFunc(vector) }
val newCol = transformUDF(dataset($(inputCol)))
val ds = dataset.withColumn($(outputCol), newCol, newField.metadata)
if (getHandleInvalid == VectorIndexer.SKIP_INVALID) {
@ -506,6 +506,11 @@ class VectorIndexerModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new VectorIndexerModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"VectorIndexerModel: uid=$uid, numFeatures=$numFeatures, handleInvalid=${$(handleInvalid)}"
}
}
@Since("1.6.0")

View file

@ -176,6 +176,11 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
@Since("2.3.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"VectorSizeHint: uid=$uid, size=${$(size)}, handleInvalid=${$(handleInvalid)}"
}
}
@Since("2.3.0")

View file

@ -159,6 +159,12 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
@Since("3.0.0")
override def toString: String = {
s"VectorSlicer: uid=$uid" +
get(indices).map(i => s", numSelectedFeatures=${i.length}").getOrElse("")
}
}
@Since("1.6.0")

View file

@ -324,6 +324,12 @@ class Word2VecModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new Word2VecModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"Word2VecModel: uid=$uid, numWords=${wordVectors.wordIndex.size}, " +
s"vectorSize=${$(vectorSize)}"
}
}
@Since("1.6.0")

View file

@ -310,6 +310,11 @@ class FPGrowthModel private[ml] (
@Since("2.2.0")
override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
}
}
@Since("2.2.0")

View file

@ -338,6 +338,11 @@ class ALSModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"ALSModel: uid=$uid, rank=$rank"
}
/**
* Returns top `numItems` items recommended for each user, for all users.
* @param numItems max number of recommendations for each user

View file

@ -311,6 +311,9 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0") val scale: Double)
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
@Since("3.0.0")
lazy val numFeatures: Int = coefficients.size
/** @group setParam */
@Since("1.6.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@ -386,6 +389,11 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0")
override def write: MLWriter =
new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
@Since("3.0.0")
override def toString: String = {
s"AFTSurvivalRegressionModel: uid=$uid, numFeatures=$numFeatures"
}
}
@Since("1.6.0")

View file

@ -243,7 +243,8 @@ class DecisionTreeRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes"
s"DecisionTreeRegressionModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
s"numFeatures=$numFeatures"
}
/**

View file

@ -302,7 +302,7 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def toString: String = {
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
s"GBTRegressionModel: uid=$uid, numTrees=$numTrees, numFeatures=$numFeatures"
}
/**

View file

@ -1106,6 +1106,12 @@ class GeneralizedLinearRegressionModel private[ml] (
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
override val numFeatures: Int = coefficients.size
@Since("3.0.0")
override def toString: String = {
s"GeneralizedLinearRegressionModel: uid=$uid, family=${$(family)}, link=${$(link)}, " +
s"numFeatures=$numFeatures"
}
}
@Since("2.0.0")

View file

@ -259,6 +259,14 @@ class IsotonicRegressionModel private[ml] (
@Since("1.6.0")
override def write: MLWriter =
new IsotonicRegressionModelWriter(this)
@Since("3.0.0")
val numFeatures: Int = 1
@Since("3.0.0")
override def toString: String = {
s"IsotonicRegressionModel: uid=$uid, numFeatures=$numFeatures"
}
}
@Since("1.6.0")

View file

@ -702,6 +702,11 @@ class LinearRegressionModel private[ml] (
*/
@Since("1.6.0")
override def write: GeneralMLWriter = new GeneralMLWriter(this)
@Since("3.0.0")
override def toString: String = {
s"LinearRegressionModel: uid=$uid, numFeatures=$numFeatures"
}
}
/** A writer for LinearRegression that handles the "internal" (or default) format */

View file

@ -235,7 +235,7 @@ class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
s"RandomForestRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
}
/**

View file

@ -323,6 +323,11 @@ class CrossValidatorModel private[ml] (
override def write: CrossValidatorModel.CrossValidatorModelWriter = {
new CrossValidatorModel.CrossValidatorModelWriter(this)
}
@Since("3.0.0")
override def toString: String = {
s"CrossValidatorModel: uid=$uid, bestModel=$bestModel, numFolds=${$(numFolds)}"
}
}
@Since("1.6.0")

View file

@ -140,7 +140,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val collectSubModelsParam = $(collectSubModels)
var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
val subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
Some(Array.fill[Model[_]](epm.length)(null))
} else None
@ -314,6 +314,11 @@ class TrainValidationSplitModel private[ml] (
override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = {
new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
}
@Since("3.0.0")
override def toString: String = {
s"TrainValidationSplitModel: uid=$uid, bestModel=$bestModel, trainRatio=${$(trainRatio)}"
}
}
@Since("2.0.0")

View file

@ -2767,7 +2767,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
test("toString") {
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
val expected = "LogisticRegressionModel: uid=logReg, numClasses=2, numFeatures=3"
assert(model.toString === expected)
}
}

View file

@ -594,7 +594,7 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
>>> blorModel.intercept == model2.intercept
True
>>> model2
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2
.. versionadded:: 1.3.0
"""
@ -1146,7 +1146,7 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
DecisionTreeClassificationModel...depth=1, numNodes=3...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
@ -1183,7 +1183,7 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
>>> model3 = dt3.fit(td3)
>>> print(model3.toDebugString)
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
DecisionTreeClassificationModel...depth=1, numNodes=3...
.. versionadded:: 1.4.0
"""
@ -1394,7 +1394,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
>>> model.transform(test1).head().prediction
1.0
>>> model.trees
[DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
[DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
>>> rfc_path = temp_path + "/rfc"
>>> rf.save(rfc_path)
>>> rf2 = RandomForestClassifier.load(rfc_path)
@ -1651,7 +1651,7 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel (uid=...)...with 5 trees...
GBTClassificationModel...numTrees=5...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
@ -1665,7 +1665,7 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
>>> model.treeWeights == model2.treeWeights
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
... ["indexed", "features"])
>>> model.evaluateEachIteration(validation)

View file

@ -800,7 +800,7 @@ class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLW
>>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
>>> model3 = dt3.fit(df3)
>>> print(model3.toDebugString)
DecisionTreeRegressionModel (uid=...) of depth 1 with 3 nodes...
DecisionTreeRegressionModel...depth=1, numNodes=3...
.. versionadded:: 1.4.0
"""
@ -1018,7 +1018,7 @@ class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLW
>>> model.numFeatures
1
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> model.getNumTrees
2
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
@ -1265,7 +1265,7 @@ class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLRea
>>> model.treeWeights == model2.treeWeights
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
... ["label", "features"])
>>> model.evaluateEachIteration(validation, "squared")