[SPARK-14412][ML][PYSPARK] Add StorageLevel params to ALS
`mllib` `ALS` supports `setIntermediateRDDStorageLevel` and `setFinalRDDStorageLevel`. This PR adds these as Params in `ml` `ALS`. They are put in group **expertParam** since few users will need them. ## How was this patch tested? New test cases in `ALSSuite` and `tests.py`. cc yanboliang jkbradley sethah rishabhbhardwaj Author: Nick Pentreath <nickp@za.ibm.com> Closes #12660 from MLnick/SPARK-14412-als-storage-params.
This commit is contained in:
parent
d7755cfd07
commit
90fa2c6e7f
|
@ -22,7 +22,7 @@ import java.io.IOException
|
||||||
|
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.reflect.ClassTag
|
import scala.reflect.ClassTag
|
||||||
import scala.util.Sorting
|
import scala.util.{Sorting, Try}
|
||||||
import scala.util.hashing.byteswap64
|
import scala.util.hashing.byteswap64
|
||||||
|
|
||||||
import com.github.fommil.netlib.BLAS.{getInstance => blas}
|
import com.github.fommil.netlib.BLAS.{getInstance => blas}
|
||||||
|
@ -153,12 +153,42 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
def getNonnegative: Boolean = $(nonnegative)
|
def getNonnegative: Boolean = $(nonnegative)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Param for StorageLevel for intermediate RDDs. Pass in a string representation of
|
||||||
|
* [[StorageLevel]]. Cannot be "NONE".
|
||||||
|
* Default: "MEMORY_AND_DISK".
|
||||||
|
*
|
||||||
|
* @group expertParam
|
||||||
|
*/
|
||||||
|
val intermediateRDDStorageLevel = new Param[String](this, "intermediateRDDStorageLevel",
|
||||||
|
"StorageLevel for intermediate RDDs. Cannot be 'NONE'. Default: 'MEMORY_AND_DISK'.",
|
||||||
|
(s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")
|
||||||
|
|
||||||
|
/** @group expertGetParam */
|
||||||
|
def getIntermediateRDDStorageLevel: String = $(intermediateRDDStorageLevel)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Param for StorageLevel for ALS model factor RDDs. Pass in a string representation of
|
||||||
|
* [[StorageLevel]].
|
||||||
|
* Default: "MEMORY_AND_DISK".
|
||||||
|
*
|
||||||
|
* @group expertParam
|
||||||
|
*/
|
||||||
|
val finalRDDStorageLevel = new Param[String](this, "finalRDDStorageLevel",
|
||||||
|
"StorageLevel for ALS model factor RDDs. Default: 'MEMORY_AND_DISK'.",
|
||||||
|
(s: String) => Try(StorageLevel.fromString(s)).isSuccess)
|
||||||
|
|
||||||
|
/** @group expertGetParam */
|
||||||
|
def getFinalRDDStorageLevel: String = $(finalRDDStorageLevel)
|
||||||
|
|
||||||
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
|
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
|
||||||
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
|
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
|
||||||
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
|
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
|
||||||
|
intermediateRDDStorageLevel -> "MEMORY_AND_DISK", finalRDDStorageLevel -> "MEMORY_AND_DISK")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validates and transforms the input schema.
|
* Validates and transforms the input schema.
|
||||||
|
*
|
||||||
* @param schema input schema
|
* @param schema input schema
|
||||||
* @return output schema
|
* @return output schema
|
||||||
*/
|
*/
|
||||||
|
@ -374,8 +404,21 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
|
||||||
@Since("1.3.0")
|
@Since("1.3.0")
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
|
/** @group expertSetParam */
|
||||||
|
@Since("2.0.0")
|
||||||
|
def setIntermediateRDDStorageLevel(value: String): this.type = {
|
||||||
|
set(intermediateRDDStorageLevel, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @group expertSetParam */
|
||||||
|
@Since("2.0.0")
|
||||||
|
def setFinalRDDStorageLevel(value: String): this.type = {
|
||||||
|
set(finalRDDStorageLevel, value)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets both numUserBlocks and numItemBlocks to the specific value.
|
* Sets both numUserBlocks and numItemBlocks to the specific value.
|
||||||
|
*
|
||||||
* @group setParam
|
* @group setParam
|
||||||
*/
|
*/
|
||||||
@Since("1.3.0")
|
@Since("1.3.0")
|
||||||
|
@ -403,6 +446,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
|
||||||
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
|
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
|
||||||
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
|
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
|
||||||
alpha = $(alpha), nonnegative = $(nonnegative),
|
alpha = $(alpha), nonnegative = $(nonnegative),
|
||||||
|
intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateRDDStorageLevel)),
|
||||||
|
finalRDDStorageLevel = StorageLevel.fromString($(finalRDDStorageLevel)),
|
||||||
checkpointInterval = $(checkpointInterval), seed = $(seed))
|
checkpointInterval = $(checkpointInterval), seed = $(seed))
|
||||||
val userDF = userFactors.toDF("id", "features")
|
val userDF = userFactors.toDF("id", "features")
|
||||||
val itemDF = itemFactors.toDF("id", "features")
|
val itemDF = itemFactors.toDF("id", "features")
|
||||||
|
@ -754,7 +799,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
|
||||||
* ratings are associated with srcIds(i).
|
* ratings are associated with srcIds(i).
|
||||||
* @param dstEncodedIndices encoded dst indices
|
* @param dstEncodedIndices encoded dst indices
|
||||||
* @param ratings ratings
|
* @param ratings ratings
|
||||||
*
|
|
||||||
* @see [[LocalIndexEncoder]]
|
* @see [[LocalIndexEncoder]]
|
||||||
*/
|
*/
|
||||||
private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
|
private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
|
||||||
|
@ -850,7 +894,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
|
||||||
* @param ratings raw ratings
|
* @param ratings raw ratings
|
||||||
* @param srcPart partitioner for src IDs
|
* @param srcPart partitioner for src IDs
|
||||||
* @param dstPart partitioner for dst IDs
|
* @param dstPart partitioner for dst IDs
|
||||||
*
|
|
||||||
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
|
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
|
||||||
*/
|
*/
|
||||||
private def partitionRatings[ID: ClassTag](
|
private def partitionRatings[ID: ClassTag](
|
||||||
|
@ -899,6 +942,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
|
* Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
|
||||||
|
*
|
||||||
* @param encoder encoder for dst indices
|
* @param encoder encoder for dst indices
|
||||||
*/
|
*/
|
||||||
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
|
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
|
||||||
|
@ -1099,6 +1143,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates in-blocks and out-blocks from rating blocks.
|
* Creates in-blocks and out-blocks from rating blocks.
|
||||||
|
*
|
||||||
* @param prefix prefix for in/out-block names
|
* @param prefix prefix for in/out-block names
|
||||||
* @param ratingBlocks rating blocks
|
* @param ratingBlocks rating blocks
|
||||||
* @param srcPart partitioner for src IDs
|
* @param srcPart partitioner for src IDs
|
||||||
|
@ -1187,7 +1232,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
|
||||||
* @param implicitPrefs whether to use implicit preference
|
* @param implicitPrefs whether to use implicit preference
|
||||||
* @param alpha the alpha constant in the implicit preference formulation
|
* @param alpha the alpha constant in the implicit preference formulation
|
||||||
* @param solver solver for least squares problems
|
* @param solver solver for least squares problems
|
||||||
*
|
|
||||||
* @return dst factors
|
* @return dst factors
|
||||||
*/
|
*/
|
||||||
private def computeFactors[ID](
|
private def computeFactors[ID](
|
||||||
|
|
|
@ -33,7 +33,9 @@ import org.apache.spark.mllib.linalg.Vectors
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
|
||||||
class ALSSuite
|
class ALSSuite
|
||||||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
|
||||||
|
@ -198,6 +200,7 @@ class ALSSuite
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates an explicit feedback dataset for testing ALS.
|
* Generates an explicit feedback dataset for testing ALS.
|
||||||
|
*
|
||||||
* @param numUsers number of users
|
* @param numUsers number of users
|
||||||
* @param numItems number of items
|
* @param numItems number of items
|
||||||
* @param rank rank
|
* @param rank rank
|
||||||
|
@ -238,6 +241,7 @@ class ALSSuite
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates an implicit feedback dataset for testing ALS.
|
* Generates an implicit feedback dataset for testing ALS.
|
||||||
|
*
|
||||||
* @param numUsers number of users
|
* @param numUsers number of users
|
||||||
* @param numItems number of items
|
* @param numItems number of items
|
||||||
* @param rank rank
|
* @param rank rank
|
||||||
|
@ -286,6 +290,7 @@ class ALSSuite
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
|
* Generates random user/item factors, with i.i.d. values drawn from U(a, b).
|
||||||
|
*
|
||||||
* @param size number of users/items
|
* @param size number of users/items
|
||||||
* @param rank number of features
|
* @param rank number of features
|
||||||
* @param random random number generator
|
* @param random random number generator
|
||||||
|
@ -311,6 +316,7 @@ class ALSSuite
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test ALS using the given training/test splits and parameters.
|
* Test ALS using the given training/test splits and parameters.
|
||||||
|
*
|
||||||
* @param training training dataset
|
* @param training training dataset
|
||||||
* @param test test dataset
|
* @param test test dataset
|
||||||
* @param rank rank of the matrix factorization
|
* @param rank rank of the matrix factorization
|
||||||
|
@ -514,6 +520,77 @@ class ALSSuite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ALSStorageSuite
|
||||||
|
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
|
||||||
|
|
||||||
|
test("invalid storage params") {
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
new ALS().setIntermediateRDDStorageLevel("foo")
|
||||||
|
}
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
new ALS().setIntermediateRDDStorageLevel("NONE")
|
||||||
|
}
|
||||||
|
intercept[IllegalArgumentException] {
|
||||||
|
new ALS().setFinalRDDStorageLevel("foo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("default and non-default storage params set correct RDD StorageLevels") {
|
||||||
|
val sqlContext = this.sqlContext
|
||||||
|
import sqlContext.implicits._
|
||||||
|
val data = Seq(
|
||||||
|
(0, 0, 1.0),
|
||||||
|
(0, 1, 2.0),
|
||||||
|
(1, 2, 3.0),
|
||||||
|
(1, 0, 2.0)
|
||||||
|
).toDF("user", "item", "rating")
|
||||||
|
val als = new ALS().setMaxIter(1).setRank(1)
|
||||||
|
// add listener to check intermediate RDD default storage levels
|
||||||
|
val defaultListener = new IntermediateRDDStorageListener
|
||||||
|
sc.addSparkListener(defaultListener)
|
||||||
|
val model = als.fit(data)
|
||||||
|
// check final factor RDD default storage levels
|
||||||
|
val defaultFactorRDDs = sc.getPersistentRDDs.collect {
|
||||||
|
case (id, rdd) if rdd.name == "userFactors" || rdd.name == "itemFactors" =>
|
||||||
|
rdd.name -> (id, rdd.getStorageLevel)
|
||||||
|
}.toMap
|
||||||
|
defaultFactorRDDs.foreach { case (_, (id, level)) =>
|
||||||
|
assert(level == StorageLevel.MEMORY_AND_DISK)
|
||||||
|
}
|
||||||
|
defaultListener.storageLevels.foreach(level => assert(level == StorageLevel.MEMORY_AND_DISK))
|
||||||
|
|
||||||
|
// add listener to check intermediate RDD non-default storage levels
|
||||||
|
val nonDefaultListener = new IntermediateRDDStorageListener
|
||||||
|
sc.addSparkListener(nonDefaultListener)
|
||||||
|
val nonDefaultModel = als
|
||||||
|
.setFinalRDDStorageLevel("MEMORY_ONLY")
|
||||||
|
.setIntermediateRDDStorageLevel("DISK_ONLY")
|
||||||
|
.fit(data)
|
||||||
|
// check final factor RDD non-default storage levels
|
||||||
|
val levels = sc.getPersistentRDDs.collect {
|
||||||
|
case (id, rdd) if rdd.name == "userFactors" && rdd.id != defaultFactorRDDs("userFactors")._1
|
||||||
|
|| rdd.name == "itemFactors" && rdd.id != defaultFactorRDDs("itemFactors")._1 =>
|
||||||
|
rdd.getStorageLevel
|
||||||
|
}
|
||||||
|
levels.foreach(level => assert(level == StorageLevel.MEMORY_ONLY))
|
||||||
|
nonDefaultListener.storageLevels.foreach(level => assert(level == StorageLevel.DISK_ONLY))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class IntermediateRDDStorageListener extends SparkListener {
|
||||||
|
|
||||||
|
val storageLevels: mutable.ArrayBuffer[StorageLevel] = mutable.ArrayBuffer()
|
||||||
|
|
||||||
|
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
|
||||||
|
val stageLevels = stageCompleted.stageInfo.rddInfos.collect {
|
||||||
|
case info if info.name.contains("Blocks") || info.name.contains("Factors-") =>
|
||||||
|
info.storageLevel
|
||||||
|
}
|
||||||
|
storageLevels ++= stageLevels
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
object ALSSuite {
|
object ALSSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -539,6 +616,8 @@ object ALSSuite {
|
||||||
"implicitPrefs" -> true,
|
"implicitPrefs" -> true,
|
||||||
"alpha" -> 0.9,
|
"alpha" -> 0.9,
|
||||||
"nonnegative" -> true,
|
"nonnegative" -> true,
|
||||||
"checkpointInterval" -> 20
|
"checkpointInterval" -> 20,
|
||||||
|
"intermediateRDDStorageLevel" -> "MEMORY_ONLY",
|
||||||
|
"finalRDDStorageLevel" -> "MEMORY_AND_DISK_SER"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,21 +119,35 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
|
||||||
nonnegative = Param(Params._dummy(), "nonnegative",
|
nonnegative = Param(Params._dummy(), "nonnegative",
|
||||||
"whether to use nonnegative constraint for least squares",
|
"whether to use nonnegative constraint for least squares",
|
||||||
typeConverter=TypeConverters.toBoolean)
|
typeConverter=TypeConverters.toBoolean)
|
||||||
|
intermediateRDDStorageLevel = Param(Params._dummy(), "intermediateRDDStorageLevel",
|
||||||
|
"StorageLevel for intermediate RDDs. Cannot be 'NONE'. " +
|
||||||
|
"Default: 'MEMORY_AND_DISK'.",
|
||||||
|
typeConverter=TypeConverters.toString)
|
||||||
|
finalRDDStorageLevel = Param(Params._dummy(), "finalRDDStorageLevel",
|
||||||
|
"StorageLevel for ALS model factor RDDs. " +
|
||||||
|
"Default: 'MEMORY_AND_DISK'.",
|
||||||
|
typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
||||||
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
||||||
ratingCol="rating", nonnegative=False, checkpointInterval=10):
|
ratingCol="rating", nonnegative=False, checkpointInterval=10,
|
||||||
|
intermediateRDDStorageLevel="MEMORY_AND_DISK",
|
||||||
|
finalRDDStorageLevel="MEMORY_AND_DISK"):
|
||||||
"""
|
"""
|
||||||
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
|
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
|
||||||
implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
|
implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
|
||||||
ratingCol="rating", nonnegative=false, checkpointInterval=10)
|
ratingCol="rating", nonnegative=false, checkpointInterval=10, \
|
||||||
|
intermediateRDDStorageLevel="MEMORY_AND_DISK", \
|
||||||
|
finalRDDStorageLevel="MEMORY_AND_DISK")
|
||||||
"""
|
"""
|
||||||
super(ALS, self).__init__()
|
super(ALS, self).__init__()
|
||||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
|
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
|
||||||
self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
||||||
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
||||||
ratingCol="rating", nonnegative=False, checkpointInterval=10)
|
ratingCol="rating", nonnegative=False, checkpointInterval=10,
|
||||||
|
intermediateRDDStorageLevel="MEMORY_AND_DISK",
|
||||||
|
finalRDDStorageLevel="MEMORY_AND_DISK")
|
||||||
kwargs = self.__init__._input_kwargs
|
kwargs = self.__init__._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -141,11 +155,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
|
||||||
@since("1.4.0")
|
@since("1.4.0")
|
||||||
def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
|
||||||
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
|
||||||
ratingCol="rating", nonnegative=False, checkpointInterval=10):
|
ratingCol="rating", nonnegative=False, checkpointInterval=10,
|
||||||
|
intermediateRDDStorageLevel="MEMORY_AND_DISK",
|
||||||
|
finalRDDStorageLevel="MEMORY_AND_DISK"):
|
||||||
"""
|
"""
|
||||||
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
|
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
|
||||||
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
|
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
|
||||||
ratingCol="rating", nonnegative=False, checkpointInterval=10)
|
ratingCol="rating", nonnegative=False, checkpointInterval=10, \
|
||||||
|
intermediateRDDStorageLevel="MEMORY_AND_DISK", \
|
||||||
|
finalRDDStorageLevel="MEMORY_AND_DISK")
|
||||||
Sets params for ALS.
|
Sets params for ALS.
|
||||||
"""
|
"""
|
||||||
kwargs = self.setParams._input_kwargs
|
kwargs = self.setParams._input_kwargs
|
||||||
|
@ -297,6 +315,36 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
|
||||||
"""
|
"""
|
||||||
return self.getOrDefault(self.nonnegative)
|
return self.getOrDefault(self.nonnegative)
|
||||||
|
|
||||||
|
@since("2.0.0")
|
||||||
|
def setIntermediateRDDStorageLevel(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`intermediateRDDStorageLevel`.
|
||||||
|
"""
|
||||||
|
self._set(intermediateRDDStorageLevel=value)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@since("2.0.0")
|
||||||
|
def getIntermediateRDDStorageLevel(self):
|
||||||
|
"""
|
||||||
|
Gets the value of intermediateRDDStorageLevel or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.intermediateRDDStorageLevel)
|
||||||
|
|
||||||
|
@since("2.0.0")
|
||||||
|
def setFinalRDDStorageLevel(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`finalRDDStorageLevel`.
|
||||||
|
"""
|
||||||
|
self._set(finalRDDStorageLevel=value)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@since("2.0.0")
|
||||||
|
def getFinalRDDStorageLevel(self):
|
||||||
|
"""
|
||||||
|
Gets the value of finalRDDStorageLevel or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.finalRDDStorageLevel)
|
||||||
|
|
||||||
|
|
||||||
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -50,12 +50,15 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvalu
|
||||||
from pyspark.ml.feature import *
|
from pyspark.ml.feature import *
|
||||||
from pyspark.ml.param import Param, Params, TypeConverters
|
from pyspark.ml.param import Param, Params, TypeConverters
|
||||||
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
|
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
|
||||||
|
from pyspark.ml.recommendation import ALS
|
||||||
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
|
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
|
||||||
from pyspark.ml.tuning import *
|
from pyspark.ml.tuning import *
|
||||||
from pyspark.ml.wrapper import JavaParams
|
from pyspark.ml.wrapper import JavaParams
|
||||||
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
|
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
|
||||||
from pyspark.sql import DataFrame, SQLContext, Row
|
from pyspark.sql import DataFrame, SQLContext, Row
|
||||||
from pyspark.sql.functions import rand
|
from pyspark.sql.functions import rand
|
||||||
|
from pyspark.sql.utils import IllegalArgumentException
|
||||||
|
from pyspark.storagelevel import *
|
||||||
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
|
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -999,6 +1002,30 @@ class HashingTFTest(PySparkTestCase):
|
||||||
": expected " + str(expected[i]) + ", got " + str(features[i]))
|
": expected " + str(expected[i]) + ", got " + str(features[i]))
|
||||||
|
|
||||||
|
|
||||||
|
class ALSTest(PySparkTestCase):
|
||||||
|
|
||||||
|
def test_storage_levels(self):
|
||||||
|
sqlContext = SQLContext(self.sc)
|
||||||
|
df = sqlContext.createDataFrame(
|
||||||
|
[(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
|
||||||
|
["user", "item", "rating"])
|
||||||
|
als = ALS().setMaxIter(1).setRank(1)
|
||||||
|
# test default params
|
||||||
|
als.fit(df)
|
||||||
|
self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
|
||||||
|
self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_AND_DISK")
|
||||||
|
self.assertEqual(als.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
|
||||||
|
self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "MEMORY_AND_DISK")
|
||||||
|
# test non-default params
|
||||||
|
als.setIntermediateRDDStorageLevel("MEMORY_ONLY_2")
|
||||||
|
als.setFinalRDDStorageLevel("DISK_ONLY")
|
||||||
|
als.fit(df)
|
||||||
|
self.assertEqual(als.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
|
||||||
|
self.assertEqual(als._java_obj.getIntermediateRDDStorageLevel(), "MEMORY_ONLY_2")
|
||||||
|
self.assertEqual(als.getFinalRDDStorageLevel(), "DISK_ONLY")
|
||||||
|
self.assertEqual(als._java_obj.getFinalRDDStorageLevel(), "DISK_ONLY")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from pyspark.ml.tests import *
|
from pyspark.ml.tests import *
|
||||||
if xmlrunner:
|
if xmlrunner:
|
||||||
|
|
Loading…
Reference in a new issue