[SPARK-9903] [MLLIB] skip local processing in PrefixSpan if there are no small prefixes

It is possible for every prefix to keep growing until it reaches the maximum pattern length, in which case no small prefixes remain and the final local processing step is unnecessary, so we skip it. cc feynmanliang

Author: Xiangrui Meng <meng@databricks.com>

Closes #8136 from mengxr/SPARK-9903.
This commit is contained in:
Xiangrui Meng 2015-08-12 20:44:40 -07:00
parent d2d5e7fe2d
commit d7053bea98

View file

@ -282,25 +282,30 @@ object PrefixSpan extends Logging {
largePrefixes = newLargePrefixes
}
// Switch to local processing.
val bcSmallPrefixes = sc.broadcast(smallPrefixes)
val distributedFreqPattern = postfixes.flatMap { postfix =>
bcSmallPrefixes.value.values.map { prefix =>
(prefix.id, postfix.project(prefix).compressed)
}.filter(_._2.nonEmpty)
}.groupByKey().flatMap { case (id, projPostfixes) =>
val prefix = bcSmallPrefixes.value(id)
val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length)
// TODO: We collect projected postfixes into memory. We should also compare the performance
// TODO: of keeping them on shuffle files.
localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) =>
(prefix.items ++ pattern, count)
var freqPatterns = sc.parallelize(localFreqPatterns, 1)
val numSmallPrefixes = smallPrefixes.size
logInfo(s"number of small prefixes for local processing: $numSmallPrefixes")
if (numSmallPrefixes > 0) {
// Switch to local processing.
val bcSmallPrefixes = sc.broadcast(smallPrefixes)
val distributedFreqPattern = postfixes.flatMap { postfix =>
bcSmallPrefixes.value.values.map { prefix =>
(prefix.id, postfix.project(prefix).compressed)
}.filter(_._2.nonEmpty)
}.groupByKey().flatMap { case (id, projPostfixes) =>
val prefix = bcSmallPrefixes.value(id)
val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length)
// TODO: We collect projected postfixes into memory. We should also compare the performance
// TODO: of keeping them on shuffle files.
localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) =>
(prefix.items ++ pattern, count)
}
}
// Union local frequent patterns and distributed ones.
freqPatterns = freqPatterns ++ distributedFreqPattern
}
// Union local frequent patterns and distributed ones.
val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern)
.persist(StorageLevel.MEMORY_AND_DISK)
freqPatterns
}