[SPARK-15617][ML][DOC] Clarify that fMeasure in MulticlassMetrics is "micro" f1_score
## What changes were proposed in this pull request? 1. Delete the `precision` and `recall` metrics in `ml.MulticlassClassificationEvaluator`. 2. Update the user guide for `mllib.weightedFMeasure`. ## How was this patch tested? Local build. Author: Ruifeng Zheng <ruifengz@foxmail.com> Closes #13390 from zhengruifeng/clarify_f1.
This commit is contained in:
parent
2ca563cc45
commit
2099e05f93
|
@ -140,7 +140,7 @@ definitions of positive and negative labels is straightforward.
|
||||||
#### Label based metrics
|
#### Label based metrics
|
||||||
|
|
||||||
Opposed to binary classification where there are only two possible labels, multiclass classification problems have many
|
Opposed to binary classification where there are only two possible labels, multiclass classification problems have many
|
||||||
possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all
|
possible labels and so the concept of label-based metrics is introduced. Accuracy measures precision across all
|
||||||
labels - the number of times any class was predicted correctly (true positives) normalized by the number of data
|
labels - the number of times any class was predicted correctly (true positives) normalized by the number of data
|
||||||
points. Precision by label considers only one class, and measures the number of times a specific label was predicted
|
points. Precision by label considers only one class, and measures the number of times a specific label was predicted
|
||||||
correctly normalized by the number of times that label appears in the output.
|
correctly normalized by the number of times that label appears in the output.
|
||||||
|
@ -182,20 +182,10 @@ $$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td>Overall Precision</td>
|
<td>Accuracy</td>
|
||||||
<td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
<td>$ACC = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
||||||
\mathbf{y}_i\right)$</td>
|
\mathbf{y}_i\right)$</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
|
||||||
<td>Overall Recall</td>
|
|
||||||
<td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
|
||||||
\mathbf{y}_i\right)$</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>Overall F1-measure</td>
|
|
||||||
<td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}
|
|
||||||
{PPV + TPR}\right)$</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
<tr>
|
||||||
<td>Precision by label</td>
|
<td>Precision by label</td>
|
||||||
<td>$PPV(\ell) = \frac{TP}{TP + FP} =
|
<td>$PPV(\ell) = \frac{TP}{TP + FP} =
|
||||||
|
|
|
@ -39,16 +39,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
||||||
def this() = this(Identifiable.randomUID("mcEval"))
|
def this() = this(Identifiable.randomUID("mcEval"))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
|
* param for metric name in evaluation (supports `"f1"` (default), `"weightedPrecision"`,
|
||||||
* `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
|
* `"weightedRecall"`, `"accuracy"`)
|
||||||
* @group param
|
* @group param
|
||||||
*/
|
*/
|
||||||
@Since("1.5.0")
|
@Since("1.5.0")
|
||||||
val metricName: Param[String] = {
|
val metricName: Param[String] = {
|
||||||
val allowedParams = ParamValidators.inArray(Array("f1", "precision",
|
val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision",
|
||||||
"recall", "weightedPrecision", "weightedRecall", "accuracy"))
|
"weightedRecall", "accuracy"))
|
||||||
new Param(this, "metricName", "metric name in evaluation " +
|
new Param(this, "metricName", "metric name in evaluation " +
|
||||||
"(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams)
|
"(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
|
@ -82,8 +82,6 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
||||||
val metrics = new MulticlassMetrics(predictionAndLabels)
|
val metrics = new MulticlassMetrics(predictionAndLabels)
|
||||||
val metric = $(metricName) match {
|
val metric = $(metricName) match {
|
||||||
case "f1" => metrics.weightedFMeasure
|
case "f1" => metrics.weightedFMeasure
|
||||||
case "precision" => metrics.accuracy
|
|
||||||
case "recall" => metrics.accuracy
|
|
||||||
case "weightedPrecision" => metrics.weightedPrecision
|
case "weightedPrecision" => metrics.weightedPrecision
|
||||||
case "weightedRecall" => metrics.weightedRecall
|
case "weightedRecall" => metrics.weightedRecall
|
||||||
case "accuracy" => metrics.accuracy
|
case "accuracy" => metrics.accuracy
|
||||||
|
|
|
@ -33,7 +33,7 @@ class MulticlassClassificationEvaluatorSuite
|
||||||
val evaluator = new MulticlassClassificationEvaluator()
|
val evaluator = new MulticlassClassificationEvaluator()
|
||||||
.setPredictionCol("myPrediction")
|
.setPredictionCol("myPrediction")
|
||||||
.setLabelCol("myLabel")
|
.setLabelCol("myLabel")
|
||||||
.setMetricName("recall")
|
.setMetricName("accuracy")
|
||||||
testDefaultReadWrite(evaluator)
|
testDefaultReadWrite(evaluator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -258,9 +258,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
|
||||||
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
|
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
|
||||||
>>> evaluator.evaluate(dataset)
|
>>> evaluator.evaluate(dataset)
|
||||||
0.66...
|
0.66...
|
||||||
>>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
|
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
|
||||||
0.66...
|
|
||||||
>>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
|
|
||||||
0.66...
|
0.66...
|
||||||
|
|
||||||
.. versionadded:: 1.5.0
|
.. versionadded:: 1.5.0
|
||||||
|
|
Loading…
Reference in a new issue