Java stubs for ALSModel.

This commit is contained in:
Tor Myklebust 2013-12-21 14:54:13 -05:00
parent 076fc16221
commit 20f85eca3d

View file

@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.recommendation._
import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable {
ret.add(serializeDoubleMatrix(model.clusterCenters))
return ret
}
private def unpackRating(ratingBytes: Array[Byte]): Rating = {
val bb = ByteBuffer.wrap(ratingBytes)
bb.order(ByteOrder.nativeOrder())
val user = bb.getInt()
val product = bb.getInt()
val rating = bb.getDouble()
return new Rating(user, product, rating)
}
/**
* Java stub for Python mllib ALSModel.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
* needs to be taken in the Python code to ensure it gets freed on exit; see
* the Py4J documentation.
*/
def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
return ALS.train(ratings, rank, iterations, lambda, blocks)
}
/**
* Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
}