[SPARK-29801][ML] ML models unify toString method
### What changes were proposed in this pull request?
1. ML models should override the `toString` method to expose basic information. Currently some algorithms (GBT/RF/LoR) already do this, while others do not yet.
2. Add `val numFeatures` to `BisectingKMeansModel`/`GaussianMixtureModel`/`KMeansModel`/`AFTSurvivalRegressionModel`/`IsotonicRegressionModel`.

### Why are the changes needed?
ML models should override the `toString` method to expose basic information.

### Does this PR introduce any user-facing change?
Yes.

### How was this patch tested?
Existing test suites.

Closes #26439 from zhengruifeng/models_toString.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
cceb2d6f11
commit
76e5294bb6
|
@ -237,7 +237,8 @@ class DecisionTreeClassificationModel private[ml] (
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
|
s"DecisionTreeClassificationModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
|
||||||
|
s"numClasses=$numClasses, numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -340,7 +340,8 @@ class GBTClassificationModel private[ml](
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
|
s"GBTClassificationModel: uid = $uid, numTrees=$numTrees, numClasses=$numClasses, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -326,6 +326,10 @@ class LinearSVCModel private[classification] (
|
||||||
@Since("2.2.0")
|
@Since("2.2.0")
|
||||||
override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
|
override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1181,8 +1181,7 @@ class LogisticRegressionModel private[spark] (
|
||||||
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
|
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"LogisticRegressionModel: " +
|
s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
|
||||||
s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -323,6 +323,12 @@ class MultilayerPerceptronClassificationModel private[ml] (
|
||||||
override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features)
|
override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features)
|
||||||
|
|
||||||
override def numClasses: Int = layers.last
|
override def numClasses: Int = layers.last
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${layers.length}, " +
|
||||||
|
s"numClasses=$numClasses, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -359,7 +359,8 @@ class NaiveBayesModel private[ml] (
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
|
s"NaiveBayesModel: uid=$uid, modelType=${$(modelType)}, numClasses=$numClasses, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -257,6 +257,12 @@ final class OneVsRestModel private[ml] (
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
|
override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"OneVsRestModel: uid=$uid, classifier=${$(classifier)}, numClasses=$numClasses, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -260,7 +260,8 @@ class RandomForestClassificationModel private[ml] (
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
|
s"RandomForestClassificationModel: uid=$uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -91,6 +91,9 @@ class BisectingKMeansModel private[ml] (
|
||||||
extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
|
extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
|
||||||
with HasTrainingSummary[BisectingKMeansSummary] {
|
with HasTrainingSummary[BisectingKMeansSummary] {
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def copy(extra: ParamMap): BisectingKMeansModel = {
|
override def copy(extra: ParamMap): BisectingKMeansModel = {
|
||||||
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
|
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
|
||||||
|
@ -145,6 +148,12 @@ class BisectingKMeansModel private[ml] (
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
|
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"BisectingKMeansModel: uid=$uid, k=${parentModel.k}, distanceMeasure=${$(distanceMeasure)}, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets summary of model on training set. An exception is
|
* Gets summary of model on training set. An exception is
|
||||||
* thrown if `hasSummary` is false.
|
* thrown if `hasSummary` is false.
|
||||||
|
|
|
@ -89,6 +89,9 @@ class GaussianMixtureModel private[ml] (
|
||||||
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
|
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
|
||||||
with HasTrainingSummary[GaussianMixtureSummary] {
|
with HasTrainingSummary[GaussianMixtureSummary] {
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
lazy val numFeatures: Int = gaussians.head.mean.size
|
||||||
|
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
@Since("2.1.0")
|
@Since("2.1.0")
|
||||||
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
||||||
|
@ -186,6 +189,11 @@ class GaussianMixtureModel private[ml] (
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
|
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"GaussianMixtureModel: uid=$uid, k=${weights.length}, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets summary of model on training set. An exception is
|
* Gets summary of model on training set. An exception is
|
||||||
* thrown if `hasSummary` is false.
|
* thrown if `hasSummary` is false.
|
||||||
|
|
|
@ -108,6 +108,9 @@ class KMeansModel private[ml] (
|
||||||
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
|
extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
|
||||||
with HasTrainingSummary[KMeansSummary] {
|
with HasTrainingSummary[KMeansSummary] {
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
lazy val numFeatures: Int = parentModel.clusterCenters.head.size
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): KMeansModel = {
|
override def copy(extra: ParamMap): KMeansModel = {
|
||||||
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
|
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
|
||||||
|
@ -153,6 +156,12 @@ class KMeansModel private[ml] (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: GeneralMLWriter = new GeneralMLWriter(this)
|
override def write: GeneralMLWriter = new GeneralMLWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"KMeansModel: uid=$uid, k=${parentModel.k}, distanceMeasure=${$(distanceMeasure)}, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets summary of model on training set. An exception is
|
* Gets summary of model on training set. An exception is
|
||||||
* thrown if `hasSummary` is false.
|
* thrown if `hasSummary` is false.
|
||||||
|
|
|
@ -620,6 +620,11 @@ class LocalLDAModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
|
override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"LocalLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -783,6 +788,11 @@ class DistributedLDAModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
|
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -130,6 +130,12 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
|
||||||
|
|
||||||
@Since("1.4.1")
|
@Since("1.4.1")
|
||||||
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
|
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"BinaryClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
|
||||||
|
s"numBins=${$(numBins)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -120,6 +120,12 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
|
||||||
throw new IllegalArgumentException(s"No support for metric $mn, distance $dm")
|
throw new IllegalArgumentException(s"No support for metric $mn, distance $dm")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"ClusteringEvaluator: uid=$uid, metricName=${$(metricName)}, " +
|
||||||
|
s"distanceMeasure=${$(distanceMeasure)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -184,6 +184,12 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
|
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MulticlassClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
|
||||||
|
s"metricLabel=${$(metricLabel)}, beta=${$(beta)}, eps=${$(eps)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -121,6 +121,12 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): MultilabelClassificationEvaluator = defaultCopy(extra)
|
override def copy(extra: ParamMap): MultilabelClassificationEvaluator = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MultilabelClassificationEvaluator: uid=$uid, metricName=${$(metricName)}, " +
|
||||||
|
s"metricLabel=${$(metricLabel)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -105,6 +105,11 @@ class RankingEvaluator (override val uid: String)
|
||||||
override def isLargerBetter: Boolean = true
|
override def isLargerBetter: Boolean = true
|
||||||
|
|
||||||
override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
|
override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"RankingEvaluator: uid=$uid, metricName=${$(metricName)}, k=${$(k)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -124,6 +124,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
|
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"RegressionEvaluator: uid=$uid, metricName=${$(metricName)}, " +
|
||||||
|
s"throughOrigin=${$(throughOrigin)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -204,6 +204,13 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
|
|
||||||
@Since("1.4.1")
|
@Since("1.4.1")
|
||||||
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
|
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"Binarizer: uid=$uid" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -106,6 +106,11 @@ class BucketedRandomProjectionLSHModel private[ml](
|
||||||
override def write: MLWriter = {
|
override def write: MLWriter = {
|
||||||
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
|
new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"BucketedRandomProjectionLSHModel: uid=$uid, numHashTables=${$(numHashTables)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -215,6 +215,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
|
||||||
override def copy(extra: ParamMap): Bucketizer = {
|
override def copy(extra: ParamMap): Bucketizer = {
|
||||||
defaultCopy[Bucketizer](extra).setParent(parent)
|
defaultCopy[Bucketizer](extra).setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"Bucketizer: uid=$uid" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -316,6 +316,11 @@ final class ChiSqSelectorModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new ChiSqSelectorModelWriter(this)
|
override def write: MLWriter = new ChiSqSelectorModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"ChiSqSelectorModel: uid=$uid, numSelectedFeatures=${selectedFeatures.length}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -307,7 +307,7 @@ class CountVectorizerModel(
|
||||||
}
|
}
|
||||||
val dictBr = broadcastDict.get
|
val dictBr = broadcastDict.get
|
||||||
val minTf = $(minTF)
|
val minTf = $(minTF)
|
||||||
val vectorizer = udf { (document: Seq[String]) =>
|
val vectorizer = udf { document: Seq[String] =>
|
||||||
val termCounts = new OpenHashMap[Int, Double]
|
val termCounts = new OpenHashMap[Int, Double]
|
||||||
var tokenCount = 0L
|
var tokenCount = 0L
|
||||||
document.foreach { term =>
|
document.foreach { term =>
|
||||||
|
@ -344,6 +344,11 @@ class CountVectorizerModel(
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new CountVectorizerModelWriter(this)
|
override def write: MLWriter = new CountVectorizerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"CountVectorizerModel: uid=$uid, vocabularySize=${vocabulary.length}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -74,6 +74,11 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new VectorUDT
|
override protected def outputDataType: DataType = new VectorUDT
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"DCT: uid=$uid, inverse=$inverse"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -81,6 +81,12 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new VectorUDT()
|
override protected def outputDataType: DataType = new VectorUDT()
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"ElementwiseProduct: uid=$uid" +
|
||||||
|
get(scalingVec).map(v => s", vectorSize=${v.size}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.Transformer
|
import org.apache.spark.ml.Transformer
|
||||||
import org.apache.spark.ml.attribute.AttributeGroup
|
import org.apache.spark.ml.attribute.AttributeGroup
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, StringArrayParam}
|
import org.apache.spark.ml.param.{ParamMap, StringArrayParam}
|
||||||
import org.apache.spark.ml.param.shared.{HasInputCols, HasNumFeatures, HasOutputCol}
|
import org.apache.spark.ml.param.shared.{HasInputCols, HasNumFeatures, HasOutputCol}
|
||||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||||
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
|
import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
|
||||||
|
@ -199,6 +199,13 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
|
||||||
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
|
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
|
||||||
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
|
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"FeatureHasher: uid=$uid, numFeatures=${$(numFeatures)}" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(categoricalCols).map(c => s", numCategoricalCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.3.0")
|
@Since("2.3.0")
|
||||||
|
|
|
@ -127,6 +127,11 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
|
|
||||||
@Since("1.4.1")
|
@Since("1.4.1")
|
||||||
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
|
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -175,9 +175,13 @@ class IDFModel private[ml] (
|
||||||
@Since("3.0.0")
|
@Since("3.0.0")
|
||||||
def numDocs: Long = idfModel.numDocs
|
def numDocs: Long = idfModel.numDocs
|
||||||
|
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new IDFModelWriter(this)
|
override def write: MLWriter = new IDFModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"IDFModel: uid=$uid, numDocs=$numDocs"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -274,6 +274,13 @@ class ImputerModel private[ml] (
|
||||||
|
|
||||||
@Since("2.2.0")
|
@Since("2.2.0")
|
||||||
override def write: MLWriter = new ImputerModelWriter(this)
|
override def write: MLWriter = new ImputerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"ImputerModel: uid=$uid, strategy=${$(strategy)}, missingValue=${$(missingValue)}" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
|
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"Interaction: uid=$uid" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -140,6 +140,11 @@ class MaxAbsScalerModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new MaxAbsScalerModelWriter(this)
|
override def write: MLWriter = new MaxAbsScalerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MaxAbsScalerModel: uid=$uid, numFeatures=${maxAbs.size}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -96,6 +96,11 @@ class MinHashLSHModel private[ml](
|
||||||
|
|
||||||
@Since("2.1.0")
|
@Since("2.1.0")
|
||||||
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
|
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MinHashLSHModel: uid=$uid, numHashTables=${$(numHashTables)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -226,6 +226,12 @@ class MinMaxScalerModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new MinMaxScalerModelWriter(this)
|
override def write: MLWriter = new MinMaxScalerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"MinMaxScalerModel: uid=$uid, numFeatures=${originalMin.size}, min=${$(min)}, " +
|
||||||
|
s"max=${$(max)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -70,6 +70,11 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
override protected def outputDataType: DataType = new ArrayType(StringType, false)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"NGram: uid=$uid, n=${$(n)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -65,6 +65,11 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def outputDataType: DataType = new VectorUDT()
|
override protected def outputDataType: DataType = new VectorUDT()
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"Normalizer: uid=$uid, p=${$(p)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -376,6 +376,13 @@ class OneHotEncoderModel private[ml] (
|
||||||
|
|
||||||
@Since("3.0.0")
|
@Since("3.0.0")
|
||||||
override def write: MLWriter = new OneHotEncoderModelWriter(this)
|
override def write: MLWriter = new OneHotEncoderModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"OneHotEncoderModel: uid=$uid, dropLast=${$(dropLast)}, handleInvalid=${$(handleInvalid)}" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("3.0.0")
|
@Since("3.0.0")
|
||||||
|
|
|
@ -179,6 +179,11 @@ class PCAModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new PCAModelWriter(this)
|
override def write: MLWriter = new PCAModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"PCAModel: uid=$uid, k=${$(k)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -77,6 +77,11 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
|
||||||
|
|
||||||
@Since("1.4.1")
|
@Since("1.4.1")
|
||||||
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
|
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"PolynomialExpansion: uid=$uid, degree=${$(degree)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -320,7 +320,10 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
|
||||||
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
|
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
|
override def toString: String = {
|
||||||
|
s"RFormula: uid=$uid" +
|
||||||
|
get(formula).map(f => s", formula = $f").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
@ -376,7 +379,9 @@ class RFormulaModel private[feature](
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
|
override def toString: String = {
|
||||||
|
s"RFormulaModel: uid=$uid, resolvedFormula=$resolvedFormula"
|
||||||
|
}
|
||||||
|
|
||||||
private def transformLabel(dataset: Dataset[_]): DataFrame = {
|
private def transformLabel(dataset: Dataset[_]): DataFrame = {
|
||||||
val labelName = resolvedFormula.label
|
val labelName = resolvedFormula.label
|
||||||
|
|
|
@ -251,6 +251,12 @@ class RobustScalerModel private[ml] (
|
||||||
}
|
}
|
||||||
|
|
||||||
override def write: MLWriter = new RobustScalerModelWriter(this)
|
override def write: MLWriter = new RobustScalerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"RobustScalerModel: uid=$uid, numFeatures=${median.size}, " +
|
||||||
|
s"withCentering=${$(withCentering)}, withScaling=${$(withScaling)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("3.0.0")
|
@Since("3.0.0")
|
||||||
|
|
|
@ -90,6 +90,11 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
|
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"SQLTransformer: uid=$uid, statement=${$(statement)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -184,6 +184,12 @@ class StandardScalerModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new StandardScalerModelWriter(this)
|
override def write: MLWriter = new StandardScalerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"StandardScalerModel: uid=$uid, numFeatures=${mean.size}, withMean=${$(withMean)}, " +
|
||||||
|
s"withStd=${$(withStd)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -156,6 +156,12 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
|
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"StopWordsRemover: uid=$uid, numStopWords=${$(stopWords).length}, locale=${$(locale)}, " +
|
||||||
|
s"caseSensitive=${$(caseSensitive)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -412,7 +412,7 @@ class StringIndexerModel (
|
||||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
|
|
||||||
var (inputColNames, outputColNames) = getInOutCols()
|
val (inputColNames, outputColNames) = getInOutCols()
|
||||||
val outputColumns = new Array[Column](outputColNames.length)
|
val outputColumns = new Array[Column](outputColNames.length)
|
||||||
|
|
||||||
// Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`.
|
// Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`.
|
||||||
|
@ -473,6 +473,14 @@ class StringIndexerModel (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
|
override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"StringIndexerModel: uid=$uid, handleInvalid=${$(handleInvalid)}" +
|
||||||
|
get(stringOrderType).map(t => s", stringOrderType=$t").getOrElse("") +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("") +
|
||||||
|
get(outputCols).map(c => s", numOutputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -175,6 +175,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
||||||
|
|
||||||
@Since("1.4.1")
|
@Since("1.4.1")
|
||||||
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
|
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"VectorAssembler: uid=$uid, handleInvalid=${$(handleInvalid)}" +
|
||||||
|
get(inputCols).map(c => s", numInputCols=${c.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -428,7 +428,7 @@ class VectorIndexerModel private[ml] (
|
||||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||||
transformSchema(dataset.schema, logging = true)
|
transformSchema(dataset.schema, logging = true)
|
||||||
val newField = prepOutputField(dataset.schema)
|
val newField = prepOutputField(dataset.schema)
|
||||||
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
|
val transformUDF = udf { vector: Vector => transformFunc(vector) }
|
||||||
val newCol = transformUDF(dataset($(inputCol)))
|
val newCol = transformUDF(dataset($(inputCol)))
|
||||||
val ds = dataset.withColumn($(outputCol), newCol, newField.metadata)
|
val ds = dataset.withColumn($(outputCol), newCol, newField.metadata)
|
||||||
if (getHandleInvalid == VectorIndexer.SKIP_INVALID) {
|
if (getHandleInvalid == VectorIndexer.SKIP_INVALID) {
|
||||||
|
@ -506,6 +506,11 @@ class VectorIndexerModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new VectorIndexerModelWriter(this)
|
override def write: MLWriter = new VectorIndexerModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"VectorIndexerModel: uid=$uid, numFeatures=$numFeatures, handleInvalid=${$(handleInvalid)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -176,6 +176,11 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
|
||||||
|
|
||||||
@Since("2.3.0")
|
@Since("2.3.0")
|
||||||
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
|
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"VectorSizeHint: uid=$uid, size=${$(size)}, handleInvalid=${$(handleInvalid)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.3.0")
|
@Since("2.3.0")
|
||||||
|
|
|
@ -159,6 +159,12 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
|
||||||
|
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
|
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"VectorSlicer: uid=$uid" +
|
||||||
|
get(indices).map(i => s", numSelectedFeatures=${i.length}").getOrElse("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -324,6 +324,12 @@ class Word2VecModel private[ml] (
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new Word2VecModelWriter(this)
|
override def write: MLWriter = new Word2VecModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"Word2VecModel: uid=$uid, numWords=${wordVectors.wordIndex.size}, " +
|
||||||
|
s"vectorSize=${$(vectorSize)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -310,6 +310,11 @@ class FPGrowthModel private[ml] (
|
||||||
|
|
||||||
@Since("2.2.0")
|
@Since("2.2.0")
|
||||||
override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
|
override def write: MLWriter = new FPGrowthModel.FPGrowthModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.2.0")
|
@Since("2.2.0")
|
||||||
|
|
|
@ -338,6 +338,11 @@ class ALSModel private[ml] (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
|
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"ALSModel: uid=$uid, rank=$rank"
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns top `numItems` items recommended for each user, for all users.
|
* Returns top `numItems` items recommended for each user, for all users.
|
||||||
* @param numItems max number of recommendations for each user
|
* @param numItems max number of recommendations for each user
|
||||||
|
|
|
@ -311,6 +311,9 @@ class AFTSurvivalRegressionModel private[ml] (
|
||||||
@Since("1.6.0") val scale: Double)
|
@Since("1.6.0") val scale: Double)
|
||||||
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
|
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
lazy val numFeatures: Int = coefficients.size
|
||||||
|
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
||||||
|
@ -386,6 +389,11 @@ class AFTSurvivalRegressionModel private[ml] (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter =
|
override def write: MLWriter =
|
||||||
new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
|
new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"AFTSurvivalRegressionModel: uid=$uid, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -243,7 +243,8 @@ class DecisionTreeRegressionModel private[ml] (
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes"
|
s"DecisionTreeRegressionModel: uid=$uid, depth=$depth, numNodes=$numNodes, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -302,7 +302,7 @@ class GBTRegressionModel private[ml](
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
|
s"GBTRegressionModel: uid=$uid, numTrees=$numTrees, numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1106,6 +1106,12 @@ class GeneralizedLinearRegressionModel private[ml] (
|
||||||
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
|
new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
|
||||||
|
|
||||||
override val numFeatures: Int = coefficients.size
|
override val numFeatures: Int = coefficients.size
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"GeneralizedLinearRegressionModel: uid=$uid, family=${$(family)}, link=${$(link)}, " +
|
||||||
|
s"numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -259,6 +259,14 @@ class IsotonicRegressionModel private[ml] (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: MLWriter =
|
override def write: MLWriter =
|
||||||
new IsotonicRegressionModelWriter(this)
|
new IsotonicRegressionModelWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
val numFeatures: Int = 1
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"IsotonicRegressionModel: uid=$uid, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -702,6 +702,11 @@ class LinearRegressionModel private[ml] (
|
||||||
*/
|
*/
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
override def write: GeneralMLWriter = new GeneralMLWriter(this)
|
override def write: GeneralMLWriter = new GeneralMLWriter(this)
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"LinearRegressionModel: uid=$uid, numFeatures=$numFeatures"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A writer for LinearRegression that handles the "internal" (or default) format */
|
/** A writer for LinearRegression that handles the "internal" (or default) format */
|
||||||
|
|
|
@ -235,7 +235,7 @@ class RandomForestRegressionModel private[ml] (
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
|
s"RandomForestRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -323,6 +323,11 @@ class CrossValidatorModel private[ml] (
|
||||||
override def write: CrossValidatorModel.CrossValidatorModelWriter = {
|
override def write: CrossValidatorModel.CrossValidatorModelWriter = {
|
||||||
new CrossValidatorModel.CrossValidatorModelWriter(this)
|
new CrossValidatorModel.CrossValidatorModelWriter(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"CrossValidatorModel: uid=$uid, bestModel=$bestModel, numFolds=${$(numFolds)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
|
|
|
@ -140,7 +140,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
|
||||||
|
|
||||||
val collectSubModelsParam = $(collectSubModels)
|
val collectSubModelsParam = $(collectSubModels)
|
||||||
|
|
||||||
var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
|
val subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
|
||||||
Some(Array.fill[Model[_]](epm.length)(null))
|
Some(Array.fill[Model[_]](epm.length)(null))
|
||||||
} else None
|
} else None
|
||||||
|
|
||||||
|
@ -314,6 +314,11 @@ class TrainValidationSplitModel private[ml] (
|
||||||
override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = {
|
override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = {
|
||||||
new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
|
new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Since("3.0.0")
|
||||||
|
override def toString: String = {
|
||||||
|
s"TrainValidationSplitModel: uid=$uid, bestModel=$bestModel, trainRatio=${$(trainRatio)}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
|
|
|
@ -2767,7 +2767,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
|
||||||
|
|
||||||
test("toString") {
|
test("toString") {
|
||||||
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
|
val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
|
||||||
val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
|
val expected = "LogisticRegressionModel: uid=logReg, numClasses=2, numFeatures=3"
|
||||||
assert(model.toString === expected)
|
assert(model.toString === expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -594,7 +594,7 @@ class LogisticRegression(JavaProbabilisticClassifier, _LogisticRegressionParams,
|
||||||
>>> blorModel.intercept == model2.intercept
|
>>> blorModel.intercept == model2.intercept
|
||||||
True
|
True
|
||||||
>>> model2
|
>>> model2
|
||||||
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
|
LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
"""
|
"""
|
||||||
|
@ -1146,7 +1146,7 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
|
||||||
>>> model.numClasses
|
>>> model.numClasses
|
||||||
2
|
2
|
||||||
>>> print(model.toDebugString)
|
>>> print(model.toDebugString)
|
||||||
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
|
DecisionTreeClassificationModel...depth=1, numNodes=3...
|
||||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||||
>>> model.predict(test0.head().features)
|
>>> model.predict(test0.head().features)
|
||||||
0.0
|
0.0
|
||||||
|
@ -1183,7 +1183,7 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifie
|
||||||
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
|
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
|
||||||
>>> model3 = dt3.fit(td3)
|
>>> model3 = dt3.fit(td3)
|
||||||
>>> print(model3.toDebugString)
|
>>> print(model3.toDebugString)
|
||||||
DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
|
DecisionTreeClassificationModel...depth=1, numNodes=3...
|
||||||
|
|
||||||
.. versionadded:: 1.4.0
|
.. versionadded:: 1.4.0
|
||||||
"""
|
"""
|
||||||
|
@ -1394,7 +1394,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
|
||||||
>>> model.transform(test1).head().prediction
|
>>> model.transform(test1).head().prediction
|
||||||
1.0
|
1.0
|
||||||
>>> model.trees
|
>>> model.trees
|
||||||
[DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
|
[DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]
|
||||||
>>> rfc_path = temp_path + "/rfc"
|
>>> rfc_path = temp_path + "/rfc"
|
||||||
>>> rf.save(rfc_path)
|
>>> rf.save(rfc_path)
|
||||||
>>> rf2 = RandomForestClassifier.load(rfc_path)
|
>>> rf2 = RandomForestClassifier.load(rfc_path)
|
||||||
|
@ -1651,7 +1651,7 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
|
||||||
>>> model.totalNumNodes
|
>>> model.totalNumNodes
|
||||||
15
|
15
|
||||||
>>> print(model.toDebugString)
|
>>> print(model.toDebugString)
|
||||||
GBTClassificationModel (uid=...)...with 5 trees...
|
GBTClassificationModel...numTrees=5...
|
||||||
>>> gbtc_path = temp_path + "gbtc"
|
>>> gbtc_path = temp_path + "gbtc"
|
||||||
>>> gbt.save(gbtc_path)
|
>>> gbt.save(gbtc_path)
|
||||||
>>> gbt2 = GBTClassifier.load(gbtc_path)
|
>>> gbt2 = GBTClassifier.load(gbtc_path)
|
||||||
|
@ -1665,7 +1665,7 @@ class GBTClassifier(JavaProbabilisticClassifier, _GBTClassifierParams,
|
||||||
>>> model.treeWeights == model2.treeWeights
|
>>> model.treeWeights == model2.treeWeights
|
||||||
True
|
True
|
||||||
>>> model.trees
|
>>> model.trees
|
||||||
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
|
||||||
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
|
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
|
||||||
... ["indexed", "features"])
|
... ["indexed", "features"])
|
||||||
>>> model.evaluateEachIteration(validation)
|
>>> model.evaluateEachIteration(validation)
|
||||||
|
|
|
@ -800,7 +800,7 @@ class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLW
|
||||||
>>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
|
>>> dt3 = DecisionTreeRegressor(maxDepth=2, weightCol="weight", varianceCol="variance")
|
||||||
>>> model3 = dt3.fit(df3)
|
>>> model3 = dt3.fit(df3)
|
||||||
>>> print(model3.toDebugString)
|
>>> print(model3.toDebugString)
|
||||||
DecisionTreeRegressionModel (uid=...) of depth 1 with 3 nodes...
|
DecisionTreeRegressionModel...depth=1, numNodes=3...
|
||||||
|
|
||||||
.. versionadded:: 1.4.0
|
.. versionadded:: 1.4.0
|
||||||
"""
|
"""
|
||||||
|
@ -1018,7 +1018,7 @@ class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLW
|
||||||
>>> model.numFeatures
|
>>> model.numFeatures
|
||||||
1
|
1
|
||||||
>>> model.trees
|
>>> model.trees
|
||||||
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
|
||||||
>>> model.getNumTrees
|
>>> model.getNumTrees
|
||||||
2
|
2
|
||||||
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
||||||
|
@ -1265,7 +1265,7 @@ class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLRea
|
||||||
>>> model.treeWeights == model2.treeWeights
|
>>> model.treeWeights == model2.treeWeights
|
||||||
True
|
True
|
||||||
>>> model.trees
|
>>> model.trees
|
||||||
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
|
||||||
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
|
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
|
||||||
... ["label", "features"])
|
... ["label", "features"])
|
||||||
>>> model.evaluateEachIteration(validation, "squared")
|
>>> model.evaluateEachIteration(validation, "squared")
|
||||||
|
|
Loading…
Reference in a new issue