[SPARK-31734][ML][PYSPARK] Add weight support in ClusteringEvaluator
### What changes were proposed in this pull request? Add weight support in ClusteringEvaluator ### Why are the changes needed? Currently, BinaryClassificationEvaluator, RegressionEvaluator, and MulticlassClassificationEvaluator support instance weight, but ClusteringEvaluator doesn't, so we will add instance weight support in ClusteringEvaluator. ### Does this PR introduce _any_ user-facing change? Yes. ClusteringEvaluator.setWeightCol ### How was this patch tested? add new unit test Closes #28553 from huaxingao/weight_evaluator. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
7f36310500
commit
d4007776f2
|
@ -19,10 +19,11 @@ package org.apache.spark.ml.evaluation
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
|
||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.functions.col
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
||||
/**
|
||||
* Evaluator for clustering results.
|
||||
|
@ -34,7 +35,8 @@ import org.apache.spark.sql.functions.col
|
|||
*/
|
||||
@Since("2.3.0")
|
||||
class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
|
||||
extends Evaluator with HasPredictionCol with HasFeaturesCol with DefaultParamsWritable {
|
||||
extends Evaluator with HasPredictionCol with HasFeaturesCol with HasWeightCol
|
||||
with DefaultParamsWritable {
|
||||
|
||||
@Since("2.3.0")
|
||||
def this() = this(Identifiable.randomUID("cluEval"))
|
||||
|
@ -53,6 +55,10 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
|
|||
@Since("2.3.0")
|
||||
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.1.0")
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
/**
|
||||
* param for metric name in evaluation
|
||||
* (supports `"silhouette"` (default))
|
||||
|
@ -116,12 +122,26 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
|
|||
*/
|
||||
@Since("3.1.0")
|
||||
def getMetrics(dataset: Dataset[_]): ClusteringMetrics = {
|
||||
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol))
|
||||
SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))
|
||||
val schema = dataset.schema
|
||||
SchemaUtils.validateVectorCompatibleColumn(schema, $(featuresCol))
|
||||
SchemaUtils.checkNumericType(schema, $(predictionCol))
|
||||
if (isDefined(weightCol)) {
|
||||
SchemaUtils.checkNumericType(schema, $(weightCol))
|
||||
}
|
||||
|
||||
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
|
||||
|
||||
val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
|
||||
val df = dataset.select(col($(predictionCol)),
|
||||
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata))
|
||||
val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
|
||||
dataset.select(col($(predictionCol)),
|
||||
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
|
||||
lit(1.0).as(weightColName))
|
||||
} else {
|
||||
dataset.select(col($(predictionCol)),
|
||||
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
|
||||
col(weightColName).cast(DoubleType))
|
||||
}
|
||||
|
||||
val metrics = new ClusteringMetrics(df)
|
||||
metrics.setDistanceMeasure($(distanceMeasure))
|
||||
metrics
|
||||
|
|
|
@ -47,9 +47,9 @@ class ClusteringMetrics private[spark](dataset: Dataset[_]) {
|
|||
val columns = dataset.columns.toSeq
|
||||
if (distanceMeasure.equalsIgnoreCase("squaredEuclidean")) {
|
||||
SquaredEuclideanSilhouette.computeSilhouetteScore(
|
||||
dataset, columns(0), columns(1))
|
||||
dataset, columns(0), columns(1), columns(2))
|
||||
} else {
|
||||
CosineSilhouette.computeSilhouetteScore(dataset, columns(0), columns(1))
|
||||
CosineSilhouette.computeSilhouetteScore(dataset, columns(0), columns(1), columns(2))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -63,9 +63,10 @@ private[evaluation] abstract class Silhouette {
|
|||
def pointSilhouetteCoefficient(
|
||||
clusterIds: Set[Double],
|
||||
pointClusterId: Double,
|
||||
pointClusterNumOfPoints: Long,
|
||||
weightSum: Double,
|
||||
weight: Double,
|
||||
averageDistanceToCluster: (Double) => Double): Double = {
|
||||
if (pointClusterNumOfPoints == 1) {
|
||||
if (weightSum == weight) {
|
||||
// Single-element clusters have silhouette 0
|
||||
0.0
|
||||
} else {
|
||||
|
@ -77,8 +78,8 @@ private[evaluation] abstract class Silhouette {
|
|||
val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min
|
||||
// adjustment for excluding the node itself from the computation of the average dissimilarity
|
||||
val currentClusterDissimilarity =
|
||||
averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints /
|
||||
(pointClusterNumOfPoints - 1)
|
||||
averageDistanceToCluster(pointClusterId) * weightSum /
|
||||
(weightSum - weight)
|
||||
if (currentClusterDissimilarity < neighboringClusterDissimilarity) {
|
||||
1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
|
||||
} else if (currentClusterDissimilarity > neighboringClusterDissimilarity) {
|
||||
|
@ -92,8 +93,8 @@ private[evaluation] abstract class Silhouette {
|
|||
/**
|
||||
* Compute the mean Silhouette values of all samples.
|
||||
*/
|
||||
def overallScore(df: DataFrame, scoreColumn: Column): Double = {
|
||||
df.select(avg(scoreColumn)).collect()(0).getDouble(0)
|
||||
def overallScore(df: DataFrame, scoreColumn: Column, weightColumn: Column): Double = {
|
||||
df.select(sum(scoreColumn * weightColumn) / sum(weightColumn)).collect()(0).getDouble(0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -267,7 +268,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
}
|
||||
}
|
||||
|
||||
case class ClusterStats(featureSum: Vector, squaredNormSum: Double, numOfPoints: Long)
|
||||
case class ClusterStats(featureSum: Vector, squaredNormSum: Double, weightSum: Double)
|
||||
|
||||
/**
|
||||
* The method takes the input dataset and computes the aggregated values
|
||||
|
@ -277,6 +278,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
* @param predictionCol The name of the column which contains the predicted cluster id
|
||||
* for the point.
|
||||
* @param featuresCol The name of the column which contains the feature vector of the point.
|
||||
* @param weightCol The name of the column which contains the instance weight.
|
||||
* @return A [[scala.collection.immutable.Map]] which associates each cluster id
|
||||
* to a [[ClusterStats]] object (which contains the precomputed values `N`,
|
||||
* `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
|
||||
|
@ -284,36 +286,39 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
def computeClusterStats(
|
||||
df: DataFrame,
|
||||
predictionCol: String,
|
||||
featuresCol: String): Map[Double, ClusterStats] = {
|
||||
featuresCol: String,
|
||||
weightCol: String): Map[Double, ClusterStats] = {
|
||||
val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
|
||||
val clustersStatsRDD = df.select(
|
||||
col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
|
||||
col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"), col(weightCol))
|
||||
.rdd
|
||||
.map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) }
|
||||
.aggregateByKey[(DenseVector, Double, Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
|
||||
.map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2), row.getDouble(3))) }
|
||||
.aggregateByKey
|
||||
[(DenseVector, Double, Double)]((Vectors.zeros(numFeatures).toDense, 0.0, 0.0))(
|
||||
seqOp = {
|
||||
case (
|
||||
(featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long),
|
||||
(features, squaredNorm)
|
||||
(featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
|
||||
(features, squaredNorm, weight)
|
||||
) =>
|
||||
BLAS.axpy(1.0, features, featureSum)
|
||||
(featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
|
||||
require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
|
||||
BLAS.axpy(weight, features, featureSum)
|
||||
(featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
|
||||
},
|
||||
combOp = {
|
||||
case (
|
||||
(featureSum1, squaredNormSum1, numOfPoints1),
|
||||
(featureSum2, squaredNormSum2, numOfPoints2)
|
||||
(featureSum1, squaredNormSum1, weightSum1),
|
||||
(featureSum2, squaredNormSum2, weightSum2)
|
||||
) =>
|
||||
BLAS.axpy(1.0, featureSum2, featureSum1)
|
||||
(featureSum1, squaredNormSum1 + squaredNormSum2, numOfPoints1 + numOfPoints2)
|
||||
(featureSum1, squaredNormSum1 + squaredNormSum2, weightSum1 + weightSum2)
|
||||
}
|
||||
)
|
||||
|
||||
clustersStatsRDD
|
||||
.collectAsMap()
|
||||
.mapValues {
|
||||
case (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long) =>
|
||||
SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, numOfPoints)
|
||||
case (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double) =>
|
||||
SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, weightSum)
|
||||
}
|
||||
.toMap
|
||||
}
|
||||
|
@ -324,6 +329,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
* @param broadcastedClustersMap A map of the precomputed values for each cluster.
|
||||
* @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point.
|
||||
* @param clusterId The id of the cluster the current point belongs to.
|
||||
* @param weight The instance weight of the current point.
|
||||
* @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point.
|
||||
* @return The Silhouette for the point.
|
||||
*/
|
||||
|
@ -331,6 +337,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]],
|
||||
point: Vector,
|
||||
clusterId: Double,
|
||||
weight: Double,
|
||||
squaredNorm: Double): Double = {
|
||||
|
||||
def compute(targetClusterId: Double): Double = {
|
||||
|
@ -338,13 +345,14 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum)
|
||||
|
||||
squaredNorm +
|
||||
clusterStats.squaredNormSum / clusterStats.numOfPoints -
|
||||
2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints
|
||||
clusterStats.squaredNormSum / clusterStats.weightSum -
|
||||
2 * pointDotClusterFeaturesSum / clusterStats.weightSum
|
||||
}
|
||||
|
||||
pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
|
||||
clusterId,
|
||||
broadcastedClustersMap.value(clusterId).numOfPoints,
|
||||
broadcastedClustersMap.value(clusterId).weightSum,
|
||||
weight,
|
||||
compute)
|
||||
}
|
||||
|
||||
|
@ -355,12 +363,14 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
* @param predictionCol The name of the column which contains the predicted cluster id
|
||||
* for the point.
|
||||
* @param featuresCol The name of the column which contains the feature vector of the point.
|
||||
* @param weightCol The name of the column which contains instance weight.
|
||||
* @return The average of the Silhouette values of the clustered data.
|
||||
*/
|
||||
def computeSilhouetteScore(
|
||||
dataset: Dataset[_],
|
||||
predictionCol: String,
|
||||
featuresCol: String): Double = {
|
||||
featuresCol: String,
|
||||
weightCol: String): Double = {
|
||||
SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext)
|
||||
|
||||
val squaredNormUDF = udf {
|
||||
|
@ -370,7 +380,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
|
||||
// compute aggregate values for clusters needed by the algorithm
|
||||
val clustersStatsMap = SquaredEuclideanSilhouette
|
||||
.computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol)
|
||||
.computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol, weightCol)
|
||||
|
||||
// Silhouette is reasonable only when the number of clusters is greater then 1
|
||||
assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
|
||||
|
@ -378,12 +388,12 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
|
|||
val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
|
||||
|
||||
val computeSilhouetteCoefficientUDF = udf {
|
||||
computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
|
||||
computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double, _: Double)
|
||||
}
|
||||
|
||||
val silhouetteScore = overallScore(dfWithSquaredNorm,
|
||||
computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType),
|
||||
col("squaredNorm")))
|
||||
col(weightCol), col("squaredNorm")), col(weightCol))
|
||||
|
||||
bClustersStatsMap.destroy()
|
||||
|
||||
|
@ -472,30 +482,35 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
* about a cluster which are needed by the algorithm.
|
||||
*
|
||||
* @param df The DataFrame which contains the input data
|
||||
* @param featuresCol The name of the column which contains the feature vector of the point.
|
||||
* @param predictionCol The name of the column which contains the predicted cluster id
|
||||
* for the point.
|
||||
* @param weightCol The name of the column which contains the instance weight.
|
||||
* @return A [[scala.collection.immutable.Map]] which associates each cluster id to a
|
||||
* its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`).
|
||||
*/
|
||||
def computeClusterStats(
|
||||
df: DataFrame,
|
||||
featuresCol: String,
|
||||
predictionCol: String): Map[Double, (Vector, Long)] = {
|
||||
predictionCol: String,
|
||||
weightCol: String): Map[Double, (Vector, Double)] = {
|
||||
val numFeatures = MetadataUtils.getNumFeatures(df, featuresCol)
|
||||
val clustersStatsRDD = df.select(
|
||||
col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
|
||||
col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName), col(weightCol))
|
||||
.rdd
|
||||
.map { row => (row.getDouble(0), row.getAs[Vector](1)) }
|
||||
.aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))(
|
||||
.map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) }
|
||||
.aggregateByKey[(DenseVector, Double)]((Vectors.zeros(numFeatures).toDense, 0.0))(
|
||||
seqOp = {
|
||||
case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) =>
|
||||
BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum)
|
||||
(normalizedFeaturesSum, numOfPoints + 1)
|
||||
case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
|
||||
(normalizedFeatures, weight)) =>
|
||||
require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
|
||||
BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
|
||||
(normalizedFeaturesSum, weightSum + weight)
|
||||
},
|
||||
combOp = {
|
||||
case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) =>
|
||||
case ((normalizedFeaturesSum1, weightSum1), (normalizedFeaturesSum2, weightSum2)) =>
|
||||
BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1)
|
||||
(normalizedFeaturesSum1, numOfPoints1 + numOfPoints2)
|
||||
(normalizedFeaturesSum1, weightSum1 + weightSum2)
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -511,11 +526,13 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
* @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the
|
||||
* normalized features of the current point.
|
||||
* @param clusterId The id of the cluster the current point belongs to.
|
||||
* @param weight The instance weight of the current point.
|
||||
*/
|
||||
def computeSilhouetteCoefficient(
|
||||
broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]],
|
||||
broadcastedClustersMap: Broadcast[Map[Double, (Vector, Double)]],
|
||||
normalizedFeatures: Vector,
|
||||
clusterId: Double): Double = {
|
||||
clusterId: Double,
|
||||
weight: Double): Double = {
|
||||
|
||||
def compute(targetClusterId: Double): Double = {
|
||||
val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId)
|
||||
|
@ -525,6 +542,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
|
||||
clusterId,
|
||||
broadcastedClustersMap.value(clusterId)._2,
|
||||
weight,
|
||||
compute)
|
||||
}
|
||||
|
||||
|
@ -535,12 +553,14 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
* @param predictionCol The name of the column which contains the predicted cluster id
|
||||
* for the point.
|
||||
* @param featuresCol The name of the column which contains the feature vector of the point.
|
||||
* @param weightCol The name of the column which contains the instance weight.
|
||||
* @return The average of the Silhouette values of the clustered data.
|
||||
*/
|
||||
def computeSilhouetteScore(
|
||||
dataset: Dataset[_],
|
||||
predictionCol: String,
|
||||
featuresCol: String): Double = {
|
||||
featuresCol: String,
|
||||
weightCol: String): Double = {
|
||||
val normalizeFeatureUDF = udf {
|
||||
features: Vector => {
|
||||
val norm = Vectors.norm(features, 2.0)
|
||||
|
@ -553,7 +573,7 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
|
||||
// compute aggregate values for clusters needed by the algorithm
|
||||
val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol,
|
||||
predictionCol)
|
||||
predictionCol, weightCol)
|
||||
|
||||
// Silhouette is reasonable only when the number of clusters is greater then 1
|
||||
assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
|
||||
|
@ -561,12 +581,12 @@ private[evaluation] object CosineSilhouette extends Silhouette {
|
|||
val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
|
||||
|
||||
val computeSilhouetteCoefficientUDF = udf {
|
||||
computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double)
|
||||
computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
|
||||
}
|
||||
|
||||
val silhouetteScore = overallScore(dfWithNormalizedFeatures,
|
||||
computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName),
|
||||
col(predictionCol).cast(DoubleType)))
|
||||
col(predictionCol).cast(DoubleType), col(weightCol)), col(weightCol))
|
||||
|
||||
bClustersStatsMap.destroy()
|
||||
|
||||
|
|
|
@ -19,12 +19,13 @@ package org.apache.spark.ml.evaluation
|
|||
|
||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||
import org.apache.spark.ml.attribute.AttributeGroup
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.functions.lit
|
||||
|
||||
|
||||
class ClusteringEvaluatorSuite
|
||||
|
@ -161,4 +162,44 @@ class ClusteringEvaluatorSuite
|
|||
|
||||
assert(evaluator.evaluate(irisDataset) == silhouetteScoreCosin)
|
||||
}
|
||||
|
||||
test("test weight support") {
|
||||
Seq("squaredEuclidean", "cosine").foreach { distanceMeasure =>
|
||||
val evaluator1 = new ClusteringEvaluator()
|
||||
.setFeaturesCol("features")
|
||||
.setPredictionCol("label")
|
||||
.setDistanceMeasure(distanceMeasure)
|
||||
|
||||
val evaluator2 = new ClusteringEvaluator()
|
||||
.setFeaturesCol("features")
|
||||
.setPredictionCol("label")
|
||||
.setDistanceMeasure(distanceMeasure)
|
||||
.setWeightCol("weight")
|
||||
|
||||
Seq(0.25, 1.0, 10.0, 99.99).foreach { w =>
|
||||
var score1 = evaluator1.evaluate(irisDataset)
|
||||
var score2 = evaluator2.evaluate(irisDataset.withColumn("weight", lit(w)))
|
||||
assert(score1 ~== score2 relTol 1e-6)
|
||||
|
||||
score1 = evaluator1.evaluate(newIrisDataset)
|
||||
score2 = evaluator2.evaluate(newIrisDataset.withColumn("weight", lit(w)))
|
||||
assert(score1 ~== score2 relTol 1e-6)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("single-element clusters with weight") {
|
||||
val singleItemClusters = spark.createDataFrame(spark.sparkContext.parallelize(Array(
|
||||
(0.0, Vectors.dense(5.1, 3.5, 1.4, 0.2), 6.0),
|
||||
(1.0, Vectors.dense(7.0, 3.2, 4.7, 1.4), 0.25),
|
||||
(2.0, Vectors.dense(6.3, 3.3, 6.0, 2.5), 9.99)))).toDF("label", "features", "weight")
|
||||
Seq("squaredEuclidean", "cosine").foreach { distanceMeasure =>
|
||||
val evaluator = new ClusteringEvaluator()
|
||||
.setFeaturesCol("features")
|
||||
.setPredictionCol("label")
|
||||
.setDistanceMeasure(distanceMeasure)
|
||||
.setWeightCol("weight")
|
||||
assert(evaluator.evaluate(singleItemClusters) === 0.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -654,7 +654,7 @@ class MultilabelClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
|
||||
class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, HasWeightCol,
|
||||
JavaMLReadable, JavaMLWritable):
|
||||
"""
|
||||
Evaluator for Clustering results, which expects two input
|
||||
|
@ -677,6 +677,18 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
|
|||
ClusteringEvaluator...
|
||||
>>> evaluator.evaluate(dataset)
|
||||
0.9079...
|
||||
>>> featureAndPredictionsWithWeight = map(lambda x: (Vectors.dense(x[0]), x[1], x[2]),
|
||||
... [([0.0, 0.5], 0.0, 2.5), ([0.5, 0.0], 0.0, 2.5), ([10.0, 11.0], 1.0, 2.5),
|
||||
... ([10.5, 11.5], 1.0, 2.5), ([1.0, 1.0], 0.0, 2.5), ([8.0, 6.0], 1.0, 2.5)])
|
||||
>>> dataset = spark.createDataFrame(
|
||||
... featureAndPredictionsWithWeight, ["features", "prediction", "weight"])
|
||||
>>> evaluator = ClusteringEvaluator()
|
||||
>>> evaluator.setPredictionCol("prediction")
|
||||
ClusteringEvaluator...
|
||||
>>> evaluator.setWeightCol("weight")
|
||||
ClusteringEvaluator...
|
||||
>>> evaluator.evaluate(dataset)
|
||||
0.9079...
|
||||
>>> ce_path = temp_path + "/ce"
|
||||
>>> evaluator.save(ce_path)
|
||||
>>> evaluator2 = ClusteringEvaluator.load(ce_path)
|
||||
|
@ -694,10 +706,10 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
|
|||
|
||||
@keyword_only
|
||||
def __init__(self, predictionCol="prediction", featuresCol="features",
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean"):
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
|
||||
"""
|
||||
__init__(self, predictionCol="prediction", featuresCol="features", \
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean")
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
|
||||
"""
|
||||
super(ClusteringEvaluator, self).__init__()
|
||||
self._java_obj = self._new_java_obj(
|
||||
|
@ -709,10 +721,10 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
|
|||
@keyword_only
|
||||
@since("2.3.0")
|
||||
def setParams(self, predictionCol="prediction", featuresCol="features",
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean"):
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None):
|
||||
"""
|
||||
setParams(self, predictionCol="prediction", featuresCol="features", \
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean")
|
||||
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
|
||||
Sets params for clustering evaluator.
|
||||
"""
|
||||
kwargs = self._input_kwargs
|
||||
|
@ -758,6 +770,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
|
|||
"""
|
||||
return self._set(predictionCol=value)
|
||||
|
||||
@since("3.1.0")
|
||||
def setWeightCol(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`weightCol`.
|
||||
"""
|
||||
return self._set(weightCol=value)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class RankingEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
|
||||
|
|
Loading…
Reference in a new issue