[SPARK-30929][ML] ML, GraphX 3.0 QA: API: New Scala APIs, docs
### What changes were proposed in this pull request? Auditing the new ML Scala APIs introduced in 3.0 and fixing the issues found (missing/incorrect `@Since` annotations, visibility, and documentation). ### Why are the changes needed? To ensure the new public ML APIs are correctly annotated and documented before the 3.0 release. ### Does this PR introduce any user-facing change? Yes, some documentation changes. ### How was this patch tested? Existing tests. Closes #27818 from huaxingao/spark-30929. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
ef51ff9dc8
commit
b6b0343e3e
|
@ -186,7 +186,7 @@ class FMClassifier @Since("3.0.0") (
|
|||
@Since("3.0.0")
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
override protected[spark] def train(
|
||||
override protected def train(
|
||||
dataset: Dataset[_]
|
||||
): FMClassificationModel = instrumented { instr =>
|
||||
|
||||
|
|
|
@ -34,12 +34,13 @@ import org.apache.spark.sql.types._
|
|||
*/
|
||||
@Since("3.0.0")
|
||||
@Experimental
|
||||
class MultilabelClassificationEvaluator (override val uid: String)
|
||||
class MultilabelClassificationEvaluator @Since("3.0.0") (@Since("3.0.0") override val uid: String)
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol
|
||||
with DefaultParamsWritable {
|
||||
|
||||
import MultilabelClassificationEvaluator.supportedMetricNames
|
||||
|
||||
@Since("3.0.0")
|
||||
def this() = this(Identifiable.randomUID("mlcEval"))
|
||||
|
||||
/**
|
||||
|
@ -49,6 +50,7 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
* `"microF1Measure"`)
|
||||
* @group param
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
final val metricName: Param[String] = {
|
||||
val allowedParams = ParamValidators.inArray(supportedMetricNames)
|
||||
new Param(this, "metricName", "metric name in evaluation " +
|
||||
|
@ -56,13 +58,21 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
}
|
||||
|
||||
/** @group getParam */
|
||||
@Since("3.0.0")
|
||||
def getMetricName: String = $(metricName)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setMetricName(value: String): this.type = set(metricName, value)
|
||||
|
||||
setDefault(metricName -> "f1Measure")
|
||||
|
||||
/**
|
||||
* param for the class whose metric will be computed in `"precisionByLabel"`, `"recallByLabel"`,
|
||||
* `"f1MeasureByLabel"`.
|
||||
* @group param
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
final val metricLabel: DoubleParam = new DoubleParam(this, "metricLabel",
|
||||
"The class whose metric will be computed in " +
|
||||
s"${supportedMetricNames.filter(_.endsWith("ByLabel")).mkString("(", "|", ")")}. " +
|
||||
|
@ -70,6 +80,7 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
ParamValidators.gtEq(0.0))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("3.0.0")
|
||||
def getMetricLabel: Double = $(metricLabel)
|
||||
|
||||
/** @group setParam */
|
||||
|
@ -78,12 +89,14 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
setDefault(metricLabel -> 0.0)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setLabelCol(value: String): this.type = set(labelCol, value)
|
||||
|
||||
|
||||
@Since("3.0.0")
|
||||
override def evaluate(dataset: Dataset[_]): Double = {
|
||||
val schema = dataset.schema
|
||||
SchemaUtils.checkColumnTypes(schema, $(predictionCol),
|
||||
|
@ -113,6 +126,7 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
}
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
override def isLargerBetter: Boolean = {
|
||||
$(metricName) match {
|
||||
case "hammingLoss" => false
|
||||
|
@ -120,6 +134,7 @@ class MultilabelClassificationEvaluator (override val uid: String)
|
|||
}
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
override def copy(extra: ParamMap): MultilabelClassificationEvaluator = defaultCopy(extra)
|
||||
|
||||
@Since("3.0.0")
|
||||
|
@ -139,5 +154,6 @@ object MultilabelClassificationEvaluator
|
|||
"precisionByLabel", "recallByLabel", "f1MeasureByLabel",
|
||||
"microPrecision", "microRecall", "microF1Measure")
|
||||
|
||||
@Since("3.0.0")
|
||||
override def load(path: String): MultilabelClassificationEvaluator = super.load(path)
|
||||
}
|
||||
|
|
|
@ -33,11 +33,12 @@ import org.apache.spark.sql.types._
|
|||
*/
|
||||
@Experimental
|
||||
@Since("3.0.0")
|
||||
class RankingEvaluator (override val uid: String)
|
||||
class RankingEvaluator @Since("3.0.0") (@Since("3.0.0") override val uid: String)
|
||||
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
|
||||
|
||||
import RankingEvaluator.supportedMetricNames
|
||||
|
||||
@Since("3.0.0")
|
||||
def this() = this(Identifiable.randomUID("rankEval"))
|
||||
|
||||
/**
|
||||
|
@ -45,6 +46,7 @@ class RankingEvaluator (override val uid: String)
|
|||
* `"meanAveragePrecisionAtK"`, `"precisionAtK"`, `"ndcgAtK"`, `"recallAtK"`)
|
||||
* @group param
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
final val metricName: Param[String] = {
|
||||
val allowedParams = ParamValidators.inArray(supportedMetricNames)
|
||||
new Param(this, "metricName", "metric name in evaluation " +
|
||||
|
@ -52,9 +54,11 @@ class RankingEvaluator (override val uid: String)
|
|||
}
|
||||
|
||||
/** @group getParam */
|
||||
@Since("3.0.0")
|
||||
def getMetricName: String = $(metricName)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setMetricName(value: String): this.type = set(metricName, value)
|
||||
|
||||
setDefault(metricName -> "meanAveragePrecision")
|
||||
|
@ -64,6 +68,7 @@ class RankingEvaluator (override val uid: String)
|
|||
* `"ndcgAtK"`, `"recallAtK"`. Must be > 0. The default value is 10.
|
||||
* @group param
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
final val k = new IntParam(this, "k",
|
||||
"The ranking position value used in " +
|
||||
s"${supportedMetricNames.filter(_.endsWith("AtK")).mkString("(", "|", ")")} " +
|
||||
|
@ -71,20 +76,24 @@ class RankingEvaluator (override val uid: String)
|
|||
ParamValidators.gt(0))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("3.0.0")
|
||||
def getK: Int = $(k)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setK(value: Int): this.type = set(k, value)
|
||||
|
||||
setDefault(k -> 10)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setPredictionCol(value: String): this.type = set(predictionCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setLabelCol(value: String): this.type = set(labelCol, value)
|
||||
|
||||
|
||||
@Since("3.0.0")
|
||||
override def evaluate(dataset: Dataset[_]): Double = {
|
||||
val schema = dataset.schema
|
||||
SchemaUtils.checkColumnTypes(schema, $(predictionCol),
|
||||
|
@ -107,8 +116,10 @@ class RankingEvaluator (override val uid: String)
|
|||
}
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
override def isLargerBetter: Boolean = true
|
||||
|
||||
@Since("3.0.0")
|
||||
override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
|
||||
|
||||
@Since("3.0.0")
|
||||
|
@ -124,5 +135,6 @@ object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator] {
|
|||
private val supportedMetricNames = Array("meanAveragePrecision",
|
||||
"meanAveragePrecisionAtK", "precisionAtK", "ndcgAtK", "recallAtK")
|
||||
|
||||
@Since("3.0.0")
|
||||
override def load(path: String): RankingEvaluator = super.load(path)
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ private[feature] trait RobustScalerParams extends Params with HasInputCol with H
|
|||
* Note that NaN values are ignored in the computation of medians and ranges.
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
class RobustScaler (override val uid: String)
|
||||
class RobustScaler @Since("3.0.0") (@Since("3.0.0") override val uid: String)
|
||||
extends Estimator[RobustScalerModel] with RobustScalerParams with DefaultParamsWritable {
|
||||
|
||||
import RobustScaler._
|
||||
|
@ -186,7 +186,7 @@ class RobustScaler (override val uid: String)
|
|||
object RobustScaler extends DefaultParamsReadable[RobustScaler] {
|
||||
|
||||
// compute QuantileSummaries for each feature
|
||||
private[spark] def computeSummaries(
|
||||
private[ml] def computeSummaries(
|
||||
vectors: RDD[Vector],
|
||||
numFeatures: Int,
|
||||
relativeError: Double): RDD[(Int, QuantileSummaries)] = {
|
||||
|
@ -229,9 +229,9 @@ object RobustScaler extends DefaultParamsReadable[RobustScaler] {
|
|||
*/
|
||||
@Since("3.0.0")
|
||||
class RobustScalerModel private[ml] (
|
||||
override val uid: String,
|
||||
val range: Vector,
|
||||
val median: Vector)
|
||||
@Since("3.0.0") override val uid: String,
|
||||
@Since("3.0.0") val range: Vector,
|
||||
@Since("3.0.0") val median: Vector)
|
||||
extends Model[RobustScalerModel] with RobustScalerParams with MLWritable {
|
||||
|
||||
import RobustScalerModel._
|
||||
|
|
|
@ -194,8 +194,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
|
|||
}
|
||||
}
|
||||
|
||||
@Since("3.0.0")
|
||||
override def train(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
|
||||
override protected def train(
|
||||
dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
|
||||
val instances = extractAFTPoints(dataset)
|
||||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
|
|
@ -409,7 +409,7 @@ class FMRegressor @Since("3.0.0") (
|
|||
@Since("3.0.0")
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
override protected[spark] def train(
|
||||
override protected def train(
|
||||
dataset: Dataset[_]
|
||||
): FMRegressionModel = instrumented { instr =>
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
* (default = "")
|
||||
* @group param
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
final val leafCol: Param[String] =
|
||||
new Param[String](this, "leafCol", "Leaf indices column name. " +
|
||||
"Predicted leaf index of each instance in each tree by preorder")
|
||||
|
@ -139,9 +140,11 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
cacheNodeIds -> false, checkpointInterval -> 10)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
final def setLeafCol(value: String): this.type = set(leafCol, value)
|
||||
|
||||
/** @group getParam */
|
||||
@Since("3.0.0")
|
||||
final def getLeafCol: String = $(leafCol)
|
||||
|
||||
/** @group getParam */
|
||||
|
|
Loading…
Reference in a new issue