[SPARK-31803][ML] Make sure instance weight is not negative
### What changes were proposed in this pull request?
In the algorithms that support instance weights, add checks to make sure that no instance weight is negative.

### Why are the changes needed?
Instance weights have to be >= 0.0.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Manually tested.

Closes #28621 from huaxingao/weight_check.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
765105b6f1
commit
50492c0bd3
|
@ -19,6 +19,7 @@ package org.apache.spark.ml
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params
|
|||
val w = this match {
|
||||
case p: HasWeightCol =>
|
||||
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
|
||||
col($(p.weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.json4s.DefaultFormats
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.PredictorParams
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
|
@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") (
|
|||
}
|
||||
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") (
|
|||
import spark.implicits._
|
||||
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") (
|
|||
|
||||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
|
||||
import org.apache.spark.ml.linalg._
|
||||
import org.apache.spark.ml.param._
|
||||
|
@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") (
|
|||
instr.logNumFeatures(numFeatures)
|
||||
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") (
|
|||
|
||||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
|
|||
col($(rawPredictionCol)),
|
||||
col($(labelCol)).cast(DoubleType),
|
||||
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
|
||||
else col($(weightCol)).cast(DoubleType)).rdd.map {
|
||||
else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
|
||||
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
|
||||
(rawPrediction(1), label, weight)
|
||||
case Row(rawPrediction: Double, label: Double, weight: Double) =>
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
|
||||
import org.apache.spark.ml.util._
|
||||
|
@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
|
|||
} else {
|
||||
dataset.select(col($(predictionCol)),
|
||||
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
|
||||
col(weightColName).cast(DoubleType))
|
||||
checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
|
||||
}
|
||||
|
||||
val metrics = new ClusteringMetrics(df)
|
||||
|
|
|
@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
(featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
|
||||
(features, squaredNorm, weight)
|
||||
) =>
|
||||
require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
|
||||
BLAS.axpy(weight, features, featureSum)
|
||||
(featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
|
||||
},
|
||||
|
@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
seqOp = {
|
||||
case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
|
||||
(normalizedFeatures, weight)) =>
|
||||
require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
|
||||
BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
|
||||
(normalizedFeaturesSum, weightSum + weight)
|
||||
},
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
|||
SchemaUtils.checkNumericType(schema, $(labelCol))
|
||||
|
||||
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
|
||||
col($(weightCol)).cast(DoubleType)
|
||||
checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
|
||||
} else {
|
||||
lit(1.0)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.evaluation
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
|
||||
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
|
||||
|
@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
|
|||
|
||||
val predictionAndLabelsWithWeights = dataset
|
||||
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
|
||||
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
|
||||
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
|
||||
else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
|
||||
.rdd
|
||||
.map { case Row(prediction: Double, label: Double, weight: Double) =>
|
||||
(prediction, label, weight) }
|
||||
|
|
|
@ -71,4 +71,10 @@ object functions {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds a UDF that validates an instance weight and passes it through unchanged.
 *
 * The wrapped function takes a single `Double`, fails the `require` check when the
 * value is not `>= 0`, and otherwise returns the same value. Callers in this change
 * apply it to the weight column expression in place of the raw column.
 *
 * Visibility is restricted to the `ml` package.
 */
private[ml] def checkNonNegativeWeight = udf {
|
||||
value: Double =>
|
||||
// `value >= 0` is also false for NaN, so NaN weights are rejected together with negatives.
require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
|
||||
value
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
|
|||
import org.apache.spark.ml.PredictorParams
|
||||
import org.apache.spark.ml.attribute._
|
||||
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
|
||||
import org.apache.spark.ml.optim._
|
||||
import org.apache.spark.ml.param._
|
||||
|
@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
"GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
|
||||
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )
|
||||
|
||||
val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
|
||||
val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol)))
|
||||
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
|
||||
|
||||
val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.{Estimator, Model}
|
||||
import org.apache.spark.ml.functions.checkNonNegativeWeight
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
|
@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
|
|||
} else {
|
||||
col($(featuresCol))
|
||||
}
|
||||
val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
|
||||
val w =
|
||||
if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)
|
||||
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
|
||||
case Row(label: Double, feature: Double, weight: Double) =>
|
||||
(label, feature, weight)
|
||||
case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue