# [SPARK-14412][ML][PYSPARK] Add StorageLevel params to ALS

`mllib` `ALS` supports `setIntermediateRDDStorageLevel` and `setFinalRDDStorageLevel`. This PR adds these as Params in `ml` `ALS`. They are placed in the **expertParam** group, since few users will need to set them.
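For illustration, a minimal sketch of the user-facing API this adds (the `ratings` DataFrame here is an assumption, not part of the patch):

```scala
import org.apache.spark.ml.recommendation.ALS

val als = new ALS()
  .setMaxIter(10)
  .setRank(10)
  // New expert params: StorageLevels are passed as strings and validated up front.
  .setIntermediateRDDStorageLevel("MEMORY_ONLY")    // "NONE" is rejected for intermediate RDDs
  .setFinalRDDStorageLevel("MEMORY_AND_DISK_SER")
val model = als.fit(ratings)  // `ratings`: assumed DataFrame with user/item/rating columns
```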

## How was this patch tested?

New test cases in `ALSSuite` and `tests.py`.

cc yanboliang jkbradley sethah rishabhbhardwaj

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12660 from MLnick/SPARK-14412-als-storage-params.
Nick Pentreath authored on 2016-04-29 22:01:41 -07:00, committed by Xiangrui Meng
commit 90fa2c6e7f
parent d7755cfd07
4 changed files with 209 additions and 11 deletions

**mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala**

@@ -22,7 +22,7 @@ import java.io.IOException
 import scala.collection.mutable
 import scala.reflect.ClassTag
-import scala.util.Sorting
+import scala.util.{Sorting, Try}
 import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -153,12 +153,42 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   /** @group getParam */
   def getNonnegative: Boolean = $(nonnegative)
 
+  /**
+   * Param for StorageLevel for intermediate RDDs. Pass in a string representation of
+   * [[StorageLevel]]. Cannot be "NONE".
+   * Default: "MEMORY_AND_DISK".
+   *
+   * @group expertParam
+   */
+  val intermediateRDDStorageLevel = new Param[String](this, "intermediateRDDStorageLevel",
+    "StorageLevel for intermediate RDDs. Cannot be 'NONE'. Default: 'MEMORY_AND_DISK'.",
+    (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")
+
+  /** @group expertGetParam */
+  def getIntermediateRDDStorageLevel: String = $(intermediateRDDStorageLevel)
+
+  /**
+   * Param for StorageLevel for ALS model factor RDDs. Pass in a string representation of
+   * [[StorageLevel]].
+   * Default: "MEMORY_AND_DISK".
+   *
+   * @group expertParam
+   */
+  val finalRDDStorageLevel = new Param[String](this, "finalRDDStorageLevel",
+    "StorageLevel for ALS model factor RDDs. Default: 'MEMORY_AND_DISK'.",
+    (s: String) => Try(StorageLevel.fromString(s)).isSuccess)
+
+  /** @group expertGetParam */
+  def getFinalRDDStorageLevel: String = $(finalRDDStorageLevel)
+
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
-    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
+    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
+    intermediateRDDStorageLevel -> "MEMORY_AND_DISK", finalRDDStorageLevel -> "MEMORY_AND_DISK")
 
   /**
    * Validates and transforms the input schema.
+   *
    * @param schema input schema
    * @return output schema
    */
@@ -374,8 +404,21 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
   @Since("1.3.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group expertSetParam */
+  @Since("2.0.0")
+  def setIntermediateRDDStorageLevel(value: String): this.type = {
+    set(intermediateRDDStorageLevel, value)
+  }
+
+  /** @group expertSetParam */
+  @Since("2.0.0")
+  def setFinalRDDStorageLevel(value: String): this.type = {
+    set(finalRDDStorageLevel, value)
+  }
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
+   *
    * @group setParam
    */
   @Since("1.3.0")
@@ -403,6 +446,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
       numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
+      intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateRDDStorageLevel)),
+      finalRDDStorageLevel = StorageLevel.fromString($(finalRDDStorageLevel)),
       checkpointInterval = $(checkpointInterval), seed = $(seed))
     val userDF = userFactors.toDF("id", "features")
     val itemDF = itemFactors.toDF("id", "features")
@@ -754,7 +799,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
    *                   ratings are associated with srcIds(i).
    * @param dstEncodedIndices encoded dst indices
    * @param ratings ratings
-   *
    * @see [[LocalIndexEncoder]]
    */
   private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
@@ -850,7 +894,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
    * @param ratings raw ratings
    * @param srcPart partitioner for src IDs
    * @param dstPart partitioner for dst IDs
-   *
    * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
    */
   private def partitionRatings[ID: ClassTag](
@@ -899,6 +942,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
 
   /**
    * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+   *
    * @param encoder encoder for dst indices
    */
   private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
@@ -1099,6 +1143,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
 
   /**
    * Creates in-blocks and out-blocks from rating blocks.
+   *
    * @param prefix prefix for in/out-block names
    * @param ratingBlocks rating blocks
    * @param srcPart partitioner for src IDs
@@ -1187,7 +1232,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
    * @param implicitPrefs whether to use implicit preference
    * @param alpha the alpha constant in the implicit preference formulation
    * @param solver solver for least squares problems
-   *
    * @return dst factors
    */
   private def computeFactors[ID](
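A note on the validators above: `StorageLevel.fromString` throws `IllegalArgumentException` on unrecognized names, so wrapping it in `Try` turns parsing into the boolean check a `Param` needs; the intermediate param additionally rejects `"NONE"` since those RDDs must actually be persisted. A standalone sketch of the same check (not code from this patch):

```scala
import scala.util.Try

import org.apache.spark.storage.StorageLevel

// Mirrors the Param validators: accept any name StorageLevel.fromString can
// parse, optionally rejecting "NONE" (as required for the intermediate RDDs).
def isValidStorageLevel(name: String, allowNone: Boolean): Boolean =
  Try(StorageLevel.fromString(name)).isSuccess && (allowNone || name != "NONE")

isValidStorageLevel("MEMORY_AND_DISK", allowNone = false)  // true
isValidStorageLevel("NONE", allowNone = false)             // false: must persist
isValidStorageLevel("foo", allowNone = true)               // false: fromString throws
```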

**mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala**

@@ -33,7 +33,9 @@ import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.storage.StorageLevel
 
 class ALSSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
@@ -198,6 +200,7 @@ class ALSSuite
 
   /**
    * Generates an explicit feedback dataset for testing ALS.
+   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -238,6 +241,7 @@ class ALSSuite
 
   /**
    * Generates an implicit feedback dataset for testing ALS.
+   *
    * @param numUsers number of users
    * @param numItems number of items
    * @param rank rank
@@ -286,6 +290,7 @@ class ALSSuite
 
   /**
    * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
+   *
    * @param size number of users/items
    * @param rank number of features
    * @param random random number generator
@@ -311,6 +316,7 @@ class ALSSuite
 
   /**
    * Test ALS using the given training/test splits and parameters.
+   *
   * @param training training dataset
    * @param test test dataset
    * @param rank rank of the matrix factorization
@@ -514,6 +520,77 @@ class ALSSuite
   }
 }
 
+class ALSStorageSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
+
+  test("invalid storage params") {
+    intercept[IllegalArgumentException] {
+      new ALS().setIntermediateRDDStorageLevel("foo")
+    }
+    intercept[IllegalArgumentException] {
+      new ALS().setIntermediateRDDStorageLevel("NONE")
+    }
+    intercept[IllegalArgumentException] {
+      new ALS().setFinalRDDStorageLevel("foo")
+    }
+  }
+
+  test("default and non-default storage params set correct RDD StorageLevels") {
+    val sqlContext = this.sqlContext
+    import sqlContext.implicits._
+    val data = Seq(
+      (0, 0, 1.0),
+      (0, 1, 2.0),
+      (1, 2, 3.0),
+      (1, 0, 2.0)
+    ).toDF("user", "item", "rating")
+    val als = new ALS().setMaxIter(1).setRank(1)
+    // add listener to check intermediate RDD default storage levels
+    val defaultListener = new IntermediateRDDStorageListener
+    sc.addSparkListener(defaultListener)
+    val model = als.fit(data)
+    // check final factor RDD default storage levels
+    val defaultFactorRDDs = sc.getPersistentRDDs.collect {
+      case (id, rdd) if rdd.name == "userFactors" || rdd.name == "itemFactors" =>
+        rdd.name -> (id, rdd.getStorageLevel)
+    }.toMap
+    defaultFactorRDDs.foreach { case (_, (id, level)) =>
+      assert(level == StorageLevel.MEMORY_AND_DISK)
+    }
+    defaultListener.storageLevels.foreach(level => assert(level == StorageLevel.MEMORY_AND_DISK))
+
+    // add listener to check intermediate RDD non-default storage levels
+    val nonDefaultListener = new IntermediateRDDStorageListener
+    sc.addSparkListener(nonDefaultListener)
+    val nonDefaultModel = als
+      .setFinalRDDStorageLevel("MEMORY_ONLY")
+      .setIntermediateRDDStorageLevel("DISK_ONLY")
+      .fit(data)
+    // check final factor RDD non-default storage levels
+    val levels = sc.getPersistentRDDs.collect {
+      case (id, rdd) if rdd.name == "userFactors" && rdd.id != defaultFactorRDDs("userFactors")._1
+        || rdd.name == "itemFactors" && rdd.id != defaultFactorRDDs("itemFactors")._1 =>
+        rdd.getStorageLevel
+    }
+    levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
+    nonDefaultListener.storageLevels.foreach(level => assert(level == StorageLevel.DISK_ONLY))
+  }
+}
+
+private class IntermediateRDDStorageListener extends SparkListener {
+
+  val storageLevels: mutable.ArrayBuffer[StorageLevel] = mutable.ArrayBuffer()
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+    val stageLevels = stageCompleted.stageInfo.rddInfos.collect {
+      case info if info.name.contains("Blocks") || info.name.contains("Factors-") =>
+        info.storageLevel
+    }
+    storageLevels ++= stageLevels
+  }
+}
+
 object ALSSuite {
 
   /**
@@ -539,6 +616,8 @@ object ALSSuite {
     "implicitPrefs" -> true,
     "alpha" -> 0.9,
     "nonnegative" -> true,
-    "checkpointInterval" -> 20
+    "checkpointInterval" -> 20,
+    "intermediateRDDStorageLevel" -> "MEMORY_ONLY",
+    "finalRDDStorageLevel" -> "MEMORY_AND_DISK_SER"
   )
 }
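The suite checks the levels from two angles: final factor RDDs are looked up by name in the `SparkContext` registry, while intermediate levels are collected per stage by the `IntermediateRDDStorageListener` in the diff above. A condensed sketch of the registry check (assuming a model was just fit with default params on a test `SparkContext` named `sc`):

```scala
import org.apache.spark.storage.StorageLevel

// ALS names its persisted factor RDDs "userFactors" and "itemFactors",
// so their storage levels can be read back after fit().
sc.getPersistentRDDs.values
  .filter(rdd => rdd.name == "userFactors" || rdd.name == "itemFactors")
  .foreach(rdd => assert(rdd.getStorageLevel == StorageLevel.MEMORY_AND_DISK))
```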

**python/pyspark/ml/recommendation.py**

@@ -119,21 +119,35 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     nonnegative = Param(Params._dummy(), "nonnegative",
                         "whether to use nonnegative constraint for least squares",
                         typeConverter=TypeConverters.toBoolean)
+    intermediateRDDStorageLevel = Param(Params._dummy(), "intermediateRDDStorageLevel",
+                                        "StorageLevel for intermediate RDDs. Cannot be 'NONE'. " +
+                                        "Default: 'MEMORY_AND_DISK'.",
+                                        typeConverter=TypeConverters.toString)
+    finalRDDStorageLevel = Param(Params._dummy(), "finalRDDStorageLevel",
+                                 "StorageLevel for ALS model factor RDDs. " +
+                                 "Default: 'MEMORY_AND_DISK'.",
+                                 typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
-                 ratingCol="rating", nonnegative=False, checkpointInterval=10):
+                 ratingCol="rating", nonnegative=False, checkpointInterval=10,
+                 intermediateRDDStorageLevel="MEMORY_AND_DISK",
+                 finalRDDStorageLevel="MEMORY_AND_DISK"):
         """
         __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                  implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
-                 ratingCol="rating", nonnegative=false, checkpointInterval=10)
+                 ratingCol="rating", nonnegative=false, checkpointInterval=10, \
+                 intermediateRDDStorageLevel="MEMORY_AND_DISK", \
+                 finalRDDStorageLevel="MEMORY_AND_DISK")
         """
         super(ALS, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
         self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                          implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
-                         ratingCol="rating", nonnegative=False, checkpointInterval=10)
+                         ratingCol="rating", nonnegative=False, checkpointInterval=10,
+                         intermediateRDDStorageLevel="MEMORY_AND_DISK",
+                         finalRDDStorageLevel="MEMORY_AND_DISK")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -141,11 +155,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     @since("1.4.0")
     def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                   implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
-                  ratingCol="rating", nonnegative=False, checkpointInterval=10):
+                  ratingCol="rating", nonnegative=False, checkpointInterval=10,
+                  intermediateRDDStorageLevel="MEMORY_AND_DISK",
+                  finalRDDStorageLevel="MEMORY_AND_DISK"):
         """
         setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                   implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
-                  ratingCol="rating", nonnegative=False, checkpointInterval=10)
+                  ratingCol="rating", nonnegative=False, checkpointInterval=10, \
+                  intermediateRDDStorageLevel="MEMORY_AND_DISK", \
+                  finalRDDStorageLevel="MEMORY_AND_DISK")
         Sets params for ALS.
         """
         kwargs = self.setParams._input_kwargs
@@ -297,6 +315,36 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
         """
         return self.getOrDefault(self.nonnegative)
 
+    @since("2.0.0")
+    def setIntermediateRDDStorageLevel(self, value):
+        """
+        Sets the value of :py:attr:`intermediateRDDStorageLevel`.
+        """
+        self._set(intermediateRDDStorageLevel=value)
+        return self
+
+    @since("2.0.0")
+    def getIntermediateRDDStorageLevel(self):
+        """
+        Gets the value of intermediateRDDStorageLevel or its default value.
+        """
+        return self.getOrDefault(self.intermediateRDDStorageLevel)
+
+    @since("2.0.0")
+    def setFinalRDDStorageLevel(self, value):
+        """
+        Sets the value of :py:attr:`finalRDDStorageLevel`.
+        """
+        self._set(finalRDDStorageLevel=value)
+        return self
+
+    @since("2.0.0")
+    def getFinalRDDStorageLevel(self):
+        """
+        Gets the value of finalRDDStorageLevel or its default value.
+        """
+        return self.getOrDefault(self.finalRDDStorageLevel)
+
 
 class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """

**python/pyspark/ml/tests.py**

@@ -50,12 +50,15 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvalu
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.recommendation import ALS
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
+from pyspark.sql.utils import IllegalArgumentException
+from pyspark.storagelevel import *
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -999,6 +1002,30 @@ class HashingTFTest(PySparkTestCase):
                             ": expected " + str(expected[i]) + ", got " + str(features[i]))
 
 
+class ALSTest(PySparkTestCase):
+
+    def test_storage_levels(self):
+        sqlContext = SQLContext(self.sc)
+        df = sqlContext.createDataFrame(
+            [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+            ["user", "item", "rating"])
+        als = ALS().setMaxIter(1).setRank(1)
+        # test default params
+        als.fit(df)
+        self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
+        self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
+        self.assertEqual(als.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
+        self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
+        # test non-default params
+        als.setIntermediateRDDStorageLevel("MEMORY_ONLY_2")
+        als.setFinalRDDStorageLevel("DISK_ONLY")
+        als.fit(df)
+        self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
+        self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
+        self.assertEqual(als.getFinalRDDStorageLevel(), "DISK_ONLY")
+        self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "DISK_ONLY")
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner: