[SPARK-7212] [MLLIB] Add sequence learning flag
Support mining of ordered frequent item sequences. Author: Feynman Liang <fliang@databricks.com> Closes #6997 from feynmanliang/fp-sequence and squashes the following commits: 7c14e15 [Feynman Liang] Improve scalatests with R code and Seq 0d3e4b6 [Feynman Liang] Fix python test ce987cb [Feynman Liang] Backwards compatibility aux constructor 34ef8f2 [Feynman Liang] Fix failing test due to reverse orderering f04bd50 [Feynman Liang] Naming, add ordered to FreqItemsets, test ordering using Seq 648d4d4 [Feynman Liang] Test case for frequent item sequences 252a36a [Feynman Liang] Add sequence learning flag
This commit is contained in:
parent
00a9d22bd6
commit
25f574eb9a
|
@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel
|
|||
* :: Experimental ::
|
||||
*
|
||||
* Model trained by [[FPGrowth]], which holds frequent itemsets.
|
||||
* @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
|
||||
* @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]]
|
||||
* @tparam Item item type
|
||||
*/
|
||||
@Experimental
|
||||
|
@ -62,13 +62,14 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex
|
|||
@Experimental
|
||||
class FPGrowth private (
|
||||
private var minSupport: Double,
|
||||
private var numPartitions: Int) extends Logging with Serializable {
|
||||
private var numPartitions: Int,
|
||||
private var ordered: Boolean) extends Logging with Serializable {
|
||||
|
||||
/**
|
||||
* Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same
|
||||
* as the input data}.
|
||||
* as the input data, ordered: `false`}.
|
||||
*/
|
||||
def this() = this(0.3, -1)
|
||||
def this() = this(0.3, -1, false)
|
||||
|
||||
/**
|
||||
* Sets the minimal support level (default: `0.3`).
|
||||
|
@ -86,6 +87,15 @@ class FPGrowth private (
|
|||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine
|
||||
* itemsets).
|
||||
*/
|
||||
def setOrdered(ordered: Boolean): this.type = {
|
||||
this.ordered = ordered
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes an FP-Growth model that contains frequent itemsets.
|
||||
* @param data input data set, each element contains a transaction
|
||||
|
@ -155,7 +165,7 @@ class FPGrowth private (
|
|||
.flatMap { case (part, tree) =>
|
||||
tree.extract(minCount, x => partitioner.getPartition(x) == part)
|
||||
}.map { case (ranks, count) =>
|
||||
new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
|
||||
new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,9 +181,12 @@ class FPGrowth private (
|
|||
itemToRank: Map[Item, Int],
|
||||
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
|
||||
val output = mutable.Map.empty[Int, Array[Int]]
|
||||
// Filter the basket by frequent items pattern and sort their ranks.
|
||||
// Filter the basket by frequent items pattern
|
||||
val filtered = transaction.flatMap(itemToRank.get)
|
||||
ju.Arrays.sort(filtered)
|
||||
if (!this.ordered) {
|
||||
ju.Arrays.sort(filtered)
|
||||
}
|
||||
// Generate conditional transactions
|
||||
val n = filtered.length
|
||||
var i = n - 1
|
||||
while (i >= 0) {
|
||||
|
@ -198,9 +211,18 @@ object FPGrowth {
|
|||
* Frequent itemset.
|
||||
* @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
|
||||
* @param freq frequency
|
||||
* @param ordered indicates if items represents an itemset (false) or sequence (true)
|
||||
* @tparam Item item type
|
||||
*/
|
||||
class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
|
||||
class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean)
|
||||
extends Serializable {
|
||||
|
||||
/**
|
||||
* Auxillary constructor, assumes unordered by default.
|
||||
*/
|
||||
def this(items: Array[Item], freq: Long) {
|
||||
this(items, freq, false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns items in a Java List.
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
|||
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
|
||||
test("FP-Growth using String type") {
|
||||
test("FP-Growth frequent itemsets using String type") {
|
||||
val transactions = Seq(
|
||||
"r z h k p",
|
||||
"z y x w v u t s",
|
||||
|
@ -38,12 +38,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val model6 = fpg
|
||||
.setMinSupport(0.9)
|
||||
.setNumPartitions(1)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model6.freqItemsets.count() === 0)
|
||||
|
||||
val model3 = fpg
|
||||
.setMinSupport(0.5)
|
||||
.setNumPartitions(2)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
|
||||
(itemset.items.toSet, itemset.freq)
|
||||
|
@ -61,17 +63,59 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val model2 = fpg
|
||||
.setMinSupport(0.3)
|
||||
.setNumPartitions(4)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model2.freqItemsets.count() === 54)
|
||||
|
||||
val model1 = fpg
|
||||
.setMinSupport(0.1)
|
||||
.setNumPartitions(8)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model1.freqItemsets.count() === 625)
|
||||
}
|
||||
|
||||
test("FP-Growth using Int type") {
|
||||
test("FP-Growth frequent sequences using String type"){
|
||||
val transactions = Seq(
|
||||
"r z h k p",
|
||||
"z y x w v u t s",
|
||||
"s x o n r",
|
||||
"x z y m t s q e",
|
||||
"z",
|
||||
"x z y r q t p")
|
||||
.map(_.split(" "))
|
||||
val rdd = sc.parallelize(transactions, 2).cache()
|
||||
|
||||
val fpg = new FPGrowth()
|
||||
|
||||
val model1 = fpg
|
||||
.setMinSupport(0.5)
|
||||
.setNumPartitions(2)
|
||||
.setOrdered(true)
|
||||
.run(rdd)
|
||||
|
||||
/*
|
||||
Use the following R code to verify association rules using arulesSequences package.
|
||||
|
||||
data = read_baskets("path", info = c("sequenceID","eventID","SIZE"))
|
||||
freqItemSeq = cspade(data, parameter = list(support = 0.5))
|
||||
resSeq = as(freqItemSeq, "data.frame")
|
||||
resSeq$support = resSeq$support * length(transactions)
|
||||
names(resSeq)[names(resSeq) == "support"] = "freq"
|
||||
resSeq
|
||||
*/
|
||||
val expected = Set(
|
||||
(Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L),
|
||||
(Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L),
|
||||
(Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L)
|
||||
)
|
||||
val freqItemseqs1 = model1.freqItemsets.collect().map { itemset =>
|
||||
(itemset.items.toSeq, itemset.freq)
|
||||
}.toSet
|
||||
assert(freqItemseqs1 == expected)
|
||||
}
|
||||
|
||||
test("FP-Growth frequent itemsets using Int type") {
|
||||
val transactions = Seq(
|
||||
"1 2 3",
|
||||
"1 2 3 4",
|
||||
|
@ -88,12 +132,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val model6 = fpg
|
||||
.setMinSupport(0.9)
|
||||
.setNumPartitions(1)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model6.freqItemsets.count() === 0)
|
||||
|
||||
val model3 = fpg
|
||||
.setMinSupport(0.5)
|
||||
.setNumPartitions(2)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
|
||||
"frequent itemsets should use primitive arrays")
|
||||
|
@ -109,12 +155,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val model2 = fpg
|
||||
.setMinSupport(0.3)
|
||||
.setNumPartitions(4)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model2.freqItemsets.count() === 15)
|
||||
|
||||
val model1 = fpg
|
||||
.setMinSupport(0.1)
|
||||
.setNumPartitions(8)
|
||||
.setOrdered(false)
|
||||
.run(rdd)
|
||||
assert(model1.freqItemsets.count() === 65)
|
||||
}
|
||||
|
|
|
@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper):
|
|||
>>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
|
||||
>>> rdd = sc.parallelize(data, 2)
|
||||
>>> model = FPGrowth.train(rdd, 0.6, 2)
|
||||
>>> sorted(model.freqItemsets().collect())
|
||||
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
|
||||
>>> sorted(model.freqItemsets().collect(), key=lambda x: x.items)
|
||||
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ...
|
||||
"""
|
||||
|
||||
def freqItemsets(self):
|
||||
|
|
Loading…
Reference in a new issue