[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-20003 I was doing some testing and found the issue. ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val, and `setMinConfidence` in FPGrowthModel has no impact once associationRules has been computed. I cache the associationRules to avoid re-computation when `minConfidence` is unchanged, but this makes FPGrowthModel somewhat stateful. Let me know if there's any concern. ## How was this patch tested? A new unit test, and I strengthened the unit test for model save/load to verify the caching mechanism. Author: Yuhao Yang <yuhao.yang@intel.com> Closes #17336 from hhbyyh/fpmodelminconf.
This commit is contained in:
parent
a59759e6c0
commit
b28bbffbad
|
@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
|
|||
/** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
 * Cache minConfidence and associationRules to avoid redundant computation for association rules
 * during transform. The associationRules will only be re-computed when minConfidence changed.
 *
 * NOTE(review): access is not synchronized — concurrent calls to associationRules while
 * setMinConfidence is changing could observe a stale pair; presumably single-threaded driver
 * use is assumed — confirm against callers.
 */
@transient private var _cachedMinConf: Double = Double.NaN

@transient private var _cachedRules: DataFrame = _

/**
 * Get association rules fitted using the minConfidence. Returns a dataframe
 * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
 * "consequent" are Array[T] and "confidence" is Double.
 */
@Since("2.2.0")
@transient def associationRules: DataFrame = {
  // _cachedMinConf starts as NaN, and NaN == x is always false, so the first
  // call always computes the rules; subsequent calls reuse the cache until
  // minConfidence is changed via setMinConfidence.
  if ($(minConfidence) == _cachedMinConf) {
    _cachedRules
  } else {
    _cachedRules = AssociationRules
      .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
    _cachedMinConf = $(minConfidence)
    _cachedRules
  }
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.apache.spark.ml.fpm
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.util.DefaultReadWriteTest
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -85,24 +85,6 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
|||
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
|
||||
}
|
||||
|
||||
test("FPGrowth parameter check") {
  // Params set on the estimator and on the fitted model must both be retained.
  val estimator = new FPGrowth().setMinSupport(0.4567)
  val fittedModel = estimator.fit(dataset).setMinConfidence(0.5678)
  assert(estimator.getMinSupport === 0.4567)
  assert(fittedModel.getMinConfidence === 0.5678)
}
|
||||
|
||||
test("read/write") {
  // Frequent itemsets must survive the save/load round trip unchanged.
  def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
    val expected = model.freqItemsets.sort("items").collect()
    val actual = model2.freqItemsets.sort("items").collect()
    assert(expected === actual)
  }
  val fPGrowth = new FPGrowth()
  testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
    FPGrowthSuite.allParamSettings, checkModelData)
}
|
||||
|
||||
test("FPGrowth prediction should not contain duplicates") {
|
||||
// This should generate rule 1 -> 3, 2 -> 3
|
||||
val dataset = spark.createDataFrame(Seq(
|
||||
|
@ -117,6 +99,44 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
|
|||
|
||||
assert(prediction === Seq("3"))
|
||||
}
|
||||
|
||||
test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
  val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
  val baselineRuleCount = model.associationRules.count()
  val baselinePrediction = model.transform(dataset)

  // Raising minConfidence must prune rules and change the transform output.
  model.setMinConfidence(0.8765)
  assert(baselineRuleCount > model.associationRules.count())
  assert(!model.transform(dataset).collect().toSet.equals(baselinePrediction.collect().toSet))

  // association rules should stay the same for same minConfidence
  model.setMinConfidence(0.1)
  assert(baselineRuleCount === model.associationRules.count())
  assert(model.transform(dataset).collect().toSet.equals(baselinePrediction.collect().toSet))
}
|
||||
|
||||
test("FPGrowth parameter check") {
  // Params set on the estimator and on the fitted model must both be retained.
  val estimator = new FPGrowth().setMinSupport(0.4567)
  val fittedModel = estimator.fit(dataset).setMinConfidence(0.5678)
  assert(estimator.getMinSupport === 0.4567)
  assert(fittedModel.getMinConfidence === 0.5678)
  // copy() must produce a usable, independent model instance.
  MLTestingUtils.checkCopy(fittedModel)
}
|
||||
|
||||
test("read/write") {
  def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
    // Frequent itemsets must survive the save/load round trip.
    assert(model.freqItemsets.collect().toSet.equals(
      model2.freqItemsets.collect().toSet))
    // Association rules at the saved minConfidence must match after load ...
    assert(model.associationRules.collect().toSet.equals(
      model2.associationRules.collect().toSet))
    // ... and recomputing at a new minConfidence must match too, which
    // exercises the rule cache invalidation on both models.
    assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
      model2.setMinConfidence(0.9).associationRules.collect().toSet))
  }
  val fPGrowth = new FPGrowth()
  testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
    FPGrowthSuite.allParamSettings, checkModelData)
}
|
||||
}
|
||||
|
||||
object FPGrowthSuite {
|
||||
|
|
Loading…
Reference in a new issue