From 20f85eca3d924aecd0fcf61cd516d9ac8e369dc1 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sat, 21 Dec 2013 14:54:13 -0500 Subject: [PATCH] Java stubs for ALSModel. --- .../spark/mllib/api/PythonMLLibAPI.scala | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala index 6472bf6367..4620cab175 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala @@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.recommendation._ import org.apache.spark.rdd.RDD import java.nio.ByteBuffer import java.nio.ByteOrder @@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable { ret.add(serializeDoubleMatrix(model.clusterCenters)) return ret } + /** Decodes one serialized rating: reads an Int user, an Int product and a Double rating, in that order, from ratingBytes using the platform's native byte order, and wraps them in a Rating. */ + private def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + return new Rating(user, product, rating) + } + + /** + * Java stub for Python mllib ALSModel.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. Each element of ratingsBytesJRDD is mapped through unpackRating and the resulting RDD is passed to ALS.train. + */ + def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.train(ratings, rank, iterations, lambda, blocks) + } + + /** + * Java stub for Python mllib ALSModel.trainImplicit(). 
This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. Each element of ratingsBytesJRDD is mapped through unpackRating and the resulting RDD is passed to ALS.trainImplicit. + */ + def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, + iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + } }