[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-20003
I found this issue while testing: ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform.
Currently `associationRules` in FPGrowthModel is a lazy val, so `setMinConfidence` in FPGrowthModel has no effect once `associationRules` has been computed.

This patch caches the association rules to avoid re-computation when `minConfidence` has not changed, but this makes FPGrowthModel somewhat stateful. Let me know if there are any concerns.

## How was this patch tested?

Added a new unit test, and strengthened the unit test for model save/load to exercise the cache mechanism.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #17336 from hhbyyh/fpmodelminconf.
This commit is contained in:
Yuhao Yang 2017-04-04 17:51:45 -07:00 committed by Joseph K. Bradley
parent a59759e6c0
commit b28bbffbad
2 changed files with 57 additions and 22 deletions

View file

@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/**
* Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
* Cache minConfidence and associationRules to avoid redundant computation for association rules
* during transform. The associationRules will only be re-computed when minConfidence changed.
*/
@transient private var _cachedMinConf: Double = Double.NaN
@transient private var _cachedRules: DataFrame = _
/**
* Get association rules fitted using the minConfidence. Returns a dataframe
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are Array[T] and "confidence" is Double.
*/
@Since("2.2.0")
@transient lazy val associationRules: DataFrame = {
AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
@transient def associationRules: DataFrame = {
  // Read the param once: the same value is both the cache key and the threshold used below.
  val requestedMinConf = $(minConfidence)
  // _cachedMinConf starts out as NaN, which compares unequal to every value, so the first
  // call always takes this branch and populates the cache.
  if (requestedMinConf != _cachedMinConf) {
    _cachedRules = AssociationRules.getAssociationRulesFromFP(
      freqItemsets, "items", "freq", requestedMinConf)
    _cachedMinConf = requestedMinConf
  }
  _cachedRules
}
/**

View file

@ -17,7 +17,7 @@
package org.apache.spark.ml.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
@ -85,24 +85,6 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
}
test("FPGrowth parameter check") {
  // Values set on the estimator and on the fitted model are read back through the getters.
  val estimator = new FPGrowth().setMinSupport(0.4567)
  val fitted = estimator.fit(dataset).setMinConfidence(0.5678)
  assert(estimator.getMinSupport === 0.4567)
  assert(fitted.getMinConfidence === 0.5678)
}
test("read/write") {
  // Compares the frequent itemsets of both models after sorting each by the "items" column.
  def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
    val itemsets = model.freqItemsets.sort("items").collect()
    val itemsets2 = model2.freqItemsets.sort("items").collect()
    assert(itemsets === itemsets2)
  }

  val estimator = new FPGrowth()
  testEstimatorAndModelReadWrite(
    estimator,
    dataset,
    FPGrowthSuite.allParamSettings,
    FPGrowthSuite.allParamSettings,
    checkModelData)
}
test("FPGrowth prediction should not contain duplicates") {
// This should generate rule 1 -> 3, 2 -> 3
val dataset = spark.createDataFrame(Seq(
@ -117,6 +99,44 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(prediction === Seq("3"))
}
test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
  val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
  val baselineRuleCount = model.associationRules.count()
  val baselinePrediction = model.transform(dataset)

  // After raising the threshold, the rule count must shrink and the transform output differ.
  model.setMinConfidence(0.8765)
  val strictRuleCount = model.associationRules.count()
  assert(baselineRuleCount > strictRuleCount)
  val strictRows = model.transform(dataset).collect().toSet
  assert(strictRows !== baselinePrediction.collect().toSet)

  // association rules should stay the same for same minConfidence
  model.setMinConfidence(0.1)
  val restoredRuleCount = model.associationRules.count()
  assert(baselineRuleCount === restoredRuleCount)
  val restoredRows = model.transform(dataset).collect().toSet
  assert(restoredRows === baselinePrediction.collect().toSet)
}
test("FPGrowth parameter check") {
  // Values set on the estimator and on the fitted model are read back through the getters,
  // then the fitted model is handed to the shared copy check.
  val estimator = new FPGrowth().setMinSupport(0.4567)
  val fitted = estimator.fit(dataset).setMinConfidence(0.5678)
  assert(estimator.getMinSupport === 0.4567)
  assert(fitted.getMinConfidence === 0.5678)
  MLTestingUtils.checkCopy(fitted)
}
test("read/write") {
  // Rows are compared as sets, so row order in the collected results does not matter.
  def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
    val itemsets = model.freqItemsets.collect().toSet
    val itemsets2 = model2.freqItemsets.collect().toSet
    assert(itemsets === itemsets2)

    val rules = model.associationRules.collect().toSet
    val rules2 = model2.associationRules.collect().toSet
    assert(rules === rules2)

    // Set a new minConfidence on both models and compare the rules again.
    val strictRules = model.setMinConfidence(0.9).associationRules.collect().toSet
    val strictRules2 = model2.setMinConfidence(0.9).associationRules.collect().toSet
    assert(strictRules === strictRules2)
  }

  val estimator = new FPGrowth()
  testEstimatorAndModelReadWrite(
    estimator,
    dataset,
    FPGrowthSuite.allParamSettings,
    FPGrowthSuite.allParamSettings,
    checkModelData)
}
}
object FPGrowthSuite {