Java stubs for ALSModel.
This commit is contained in:
parent
076fc16221
commit
20f85eca3d
|
@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD
|
|||
import org.apache.spark.mllib.regression._
|
||||
import org.apache.spark.mllib.classification._
|
||||
import org.apache.spark.mllib.clustering._
|
||||
import org.apache.spark.mllib.recommendation._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
@ -194,4 +195,37 @@ class PythonMLLibAPI extends Serializable {
|
|||
ret.add(serializeDoubleMatrix(model.clusterCenters))
|
||||
return ret
|
||||
}
|
||||
|
||||
private def unpackRating(ratingBytes: Array[Byte]): Rating = {
|
||||
val bb = ByteBuffer.wrap(ratingBytes)
|
||||
bb.order(ByteOrder.nativeOrder())
|
||||
val user = bb.getInt()
|
||||
val product = bb.getInt()
|
||||
val rating = bb.getDouble()
|
||||
return new Rating(user, product, rating)
|
||||
}
|
||||
|
||||
/**
|
||||
* Java stub for Python mllib ALSModel.train(). This stub returns a handle
|
||||
* to the Java object instead of the content of the Java object. Extra care
|
||||
* needs to be taken in the Python code to ensure it gets freed on exit; see
|
||||
* the Py4J documentation.
|
||||
*/
|
||||
def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
|
||||
iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
|
||||
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
|
||||
return ALS.train(ratings, rank, iterations, lambda, blocks)
|
||||
}
|
||||
|
||||
/**
|
||||
* Java stub for Python mllib ALSModel.trainImplicit(). This stub returns a
|
||||
* handle to the Java object instead of the content of the Java object.
|
||||
* Extra care needs to be taken in the Python code to ensure it gets freed on
|
||||
* exit; see the Py4J documentation.
|
||||
*/
|
||||
def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
|
||||
iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
|
||||
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
|
||||
return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue