[SPARK-15617][ML][DOC] Clarify that fMeasure in MulticlassMetrics is "micro" f1_score
## What changes were proposed in this pull request? 1, del precision,recall in `ml.MulticlassClassificationEvaluator` 2, update user guide for `mlllib.weightedFMeasure` ## How was this patch tested? local build Author: Ruifeng Zheng <ruifengz@foxmail.com> Closes #13390 from zhengruifeng/clarify_f1.
This commit is contained in:
parent
2ca563cc45
commit
2099e05f93
|
@ -140,7 +140,7 @@ definitions of positive and negative labels is straightforward.
|
|||
#### Label based metrics
|
||||
|
||||
Opposed to binary classification where there are only two possible labels, multiclass classification problems have many
|
||||
possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all
|
||||
possible labels and so the concept of label-based metrics is introduced. Accuracy measures precision across all
|
||||
labels - the number of times any class was predicted correctly (true positives) normalized by the number of data
|
||||
points. Precision by label considers only one class, and measures the number of time a specific label was predicted
|
||||
correctly normalized by the number of times that label appears in the output.
|
||||
|
@ -182,20 +182,10 @@ $$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.
|
|||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Overall Precision</td>
|
||||
<td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
||||
<td>Accuracy</td>
|
||||
<td>$ACC = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
||||
\mathbf{y}_i\right)$</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Overall Recall</td>
|
||||
<td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
|
||||
\mathbf{y}_i\right)$</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Overall F1-measure</td>
|
||||
<td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}
|
||||
{PPV + TPR}\right)$</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Precision by label</td>
|
||||
<td>$PPV(\ell) = \frac{TP}{TP + FP} =
|
||||
|
|
|
@ -39,16 +39,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
|||
def this() = this(Identifiable.randomUID("mcEval"))
|
||||
|
||||
/**
|
||||
* param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
|
||||
* `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
|
||||
* param for metric name in evaluation (supports `"f1"` (default), `"weightedPrecision"`,
|
||||
* `"weightedRecall"`, `"accuracy"`)
|
||||
* @group param
|
||||
*/
|
||||
@Since("1.5.0")
|
||||
val metricName: Param[String] = {
|
||||
val allowedParams = ParamValidators.inArray(Array("f1", "precision",
|
||||
"recall", "weightedPrecision", "weightedRecall", "accuracy"))
|
||||
val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision",
|
||||
"weightedRecall", "accuracy"))
|
||||
new Param(this, "metricName", "metric name in evaluation " +
|
||||
"(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams)
|
||||
"(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams)
|
||||
}
|
||||
|
||||
/** @group getParam */
|
||||
|
@ -82,8 +82,6 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
|
|||
val metrics = new MulticlassMetrics(predictionAndLabels)
|
||||
val metric = $(metricName) match {
|
||||
case "f1" => metrics.weightedFMeasure
|
||||
case "precision" => metrics.accuracy
|
||||
case "recall" => metrics.accuracy
|
||||
case "weightedPrecision" => metrics.weightedPrecision
|
||||
case "weightedRecall" => metrics.weightedRecall
|
||||
case "accuracy" => metrics.accuracy
|
||||
|
|
|
@ -33,7 +33,7 @@ class MulticlassClassificationEvaluatorSuite
|
|||
val evaluator = new MulticlassClassificationEvaluator()
|
||||
.setPredictionCol("myPrediction")
|
||||
.setLabelCol("myLabel")
|
||||
.setMetricName("recall")
|
||||
.setMetricName("accuracy")
|
||||
testDefaultReadWrite(evaluator)
|
||||
}
|
||||
|
||||
|
|
|
@ -258,9 +258,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
|
|||
>>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
|
||||
>>> evaluator.evaluate(dataset)
|
||||
0.66...
|
||||
>>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
|
||||
0.66...
|
||||
>>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
|
||||
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
|
||||
0.66...
|
||||
|
||||
.. versionadded:: 1.5.0
|
||||
|
|
Loading…
Reference in a new issue