[SPARK-12247][ML][DOC] Documentation for spark.ml's ALS and collaborative filtering in general

This documents the implementation of ALS in `spark.ml` with example code in Scala, Java, and Python.

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #10411 from BenFradet/SPARK-12247.
BenFradet 2016-02-16 13:03:28 +00:00 committed by Sean Owen
parent 827ed1c067
commit 00c72d27bf
10 changed files with 431 additions and 298 deletions

data/mllib/als/sample_movielens_movies.txt (deleted)

@@ -1,100 +0,0 @@
0::Movie 0::Romance|Comedy
1::Movie 1::Action|Anime
2::Movie 2::Romance|Thriller
3::Movie 3::Action|Romance
4::Movie 4::Anime|Comedy
5::Movie 5::Action|Action
6::Movie 6::Action|Comedy
7::Movie 7::Anime|Comedy
8::Movie 8::Comedy|Action
9::Movie 9::Anime|Thriller
10::Movie 10::Action|Anime
11::Movie 11::Action|Anime
12::Movie 12::Anime|Comedy
13::Movie 13::Thriller|Action
14::Movie 14::Anime|Comedy
15::Movie 15::Comedy|Thriller
16::Movie 16::Anime|Romance
17::Movie 17::Thriller|Action
18::Movie 18::Action|Comedy
19::Movie 19::Anime|Romance
20::Movie 20::Action|Anime
21::Movie 21::Romance|Thriller
22::Movie 22::Romance|Romance
23::Movie 23::Comedy|Comedy
24::Movie 24::Anime|Action
25::Movie 25::Comedy|Comedy
26::Movie 26::Anime|Romance
27::Movie 27::Anime|Anime
28::Movie 28::Thriller|Anime
29::Movie 29::Anime|Romance
30::Movie 30::Thriller|Romance
31::Movie 31::Thriller|Romance
32::Movie 32::Comedy|Anime
33::Movie 33::Comedy|Comedy
34::Movie 34::Anime|Anime
35::Movie 35::Action|Thriller
36::Movie 36::Anime|Romance
37::Movie 37::Romance|Anime
38::Movie 38::Thriller|Romance
39::Movie 39::Romance|Comedy
40::Movie 40::Action|Anime
41::Movie 41::Comedy|Thriller
42::Movie 42::Comedy|Action
43::Movie 43::Thriller|Anime
44::Movie 44::Anime|Action
45::Movie 45::Comedy|Romance
46::Movie 46::Comedy|Action
47::Movie 47::Romance|Comedy
48::Movie 48::Action|Comedy
49::Movie 49::Romance|Romance
50::Movie 50::Comedy|Romance
51::Movie 51::Action|Action
52::Movie 52::Thriller|Action
53::Movie 53::Action|Action
54::Movie 54::Romance|Thriller
55::Movie 55::Anime|Romance
56::Movie 56::Comedy|Action
57::Movie 57::Action|Anime
58::Movie 58::Thriller|Romance
59::Movie 59::Thriller|Comedy
60::Movie 60::Anime|Comedy
61::Movie 61::Comedy|Action
62::Movie 62::Comedy|Romance
63::Movie 63::Romance|Thriller
64::Movie 64::Romance|Action
65::Movie 65::Anime|Romance
66::Movie 66::Comedy|Action
67::Movie 67::Thriller|Anime
68::Movie 68::Thriller|Romance
69::Movie 69::Action|Comedy
70::Movie 70::Thriller|Thriller
71::Movie 71::Action|Comedy
72::Movie 72::Thriller|Romance
73::Movie 73::Comedy|Action
74::Movie 74::Action|Action
75::Movie 75::Action|Action
76::Movie 76::Comedy|Comedy
77::Movie 77::Comedy|Comedy
78::Movie 78::Comedy|Comedy
79::Movie 79::Thriller|Thriller
80::Movie 80::Comedy|Anime
81::Movie 81::Comedy|Anime
82::Movie 82::Romance|Anime
83::Movie 83::Comedy|Thriller
84::Movie 84::Anime|Action
85::Movie 85::Thriller|Anime
86::Movie 86::Romance|Anime
87::Movie 87::Thriller|Thriller
88::Movie 88::Romance|Thriller
89::Movie 89::Action|Anime
90::Movie 90::Anime|Romance
91::Movie 91::Anime|Thriller
92::Movie 92::Action|Comedy
93::Movie 93::Romance|Thriller
94::Movie 94::Thriller|Comedy
95::Movie 95::Action|Action
96::Movie 96::Thriller|Romance
97::Movie 97::Thriller|Thriller
98::Movie 98::Thriller|Comedy
99::Movie 99::Thriller|Romance

docs/_data/menu-ml.yaml

@@ -6,5 +6,7 @@
  url: ml-classification-regression.html
- text: Clustering
  url: ml-clustering.html
- text: Collaborative filtering
  url: ml-collaborative-filtering.html
- text: Advanced topics
  url: ml-advanced.html

docs/ml-collaborative-filtering.md (new file)

@@ -0,0 +1,148 @@
---
layout: global
title: Collaborative Filtering - spark.ml
displayTitle: Collaborative Filtering - spark.ml
---
* Table of contents
{:toc}

## Collaborative filtering

[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
is commonly used for recommender systems. These techniques aim to fill in the
missing entries of a user-item association matrix. `spark.ml` currently supports
model-based collaborative filtering, in which users and products are described
by a small set of latent factors that can be used to predict missing entries.
`spark.ml` uses the [alternating least squares
(ALS)](http://dl.acm.org/citation.cfm?id=1608614)
algorithm to learn these latent factors. The implementation in `spark.ml` has the
following parameters (a configuration sketch follows the list):

* *numBlocks* is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10).
* *rank* is the number of latent factors in the model (defaults to 10).
* *maxIter* is the maximum number of iterations to run (defaults to 10).
* *regParam* specifies the regularization parameter in ALS (defaults to 0.1).
* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for
*implicit feedback* data (defaults to `false` which means using *explicit feedback*).
* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the
*baseline* confidence in preference observations (defaults to 1.0).
* *nonnegative* specifies whether or not to use nonnegative constraints for least squares (defaults to `false`).
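
The parameters above map directly onto setters on the `ALS` estimator. The following is a minimal
configuration sketch (not one of the example files in this commit) using the documented defaults:

{% highlight scala %}
import org.apache.spark.ml.recommendation.ALS

// Sketch: construct an ALS estimator with the default values listed above.
// setNumBlocks sets both the user and the item block counts.
val als = new ALS()
  .setNumBlocks(10)
  .setRank(10)
  .setMaxIter(10)
  .setRegParam(0.1)
  .setImplicitPrefs(false)
  .setAlpha(1.0)
  .setNonnegative(false)
{% endhighlight %}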

### Explicit vs. implicit feedback

The standard approach to matrix factorization-based collaborative filtering treats
the entries in the user-item matrix as *explicit* preferences given by the user to the item,
for example, users giving ratings to movies.

It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
clicks, purchases, likes, shares, etc.). The approach used in `spark.ml` to deal with such data is taken
from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
as numbers representing the *strength* of observations of user actions (such as the number of clicks,
or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
confidence in observed user preferences, rather than explicit ratings given to items. The model
then tries to find latent factors that can be used to predict the expected preference of a user for
an item.
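
As a sketch of the cited formulation (notation follows the paper, not identifiers in the Spark
source), the implicit-feedback variant minimizes a confidence-weighted objective over the user
factors `x_u` and item factors `y_i`:

`\[
\min_{x_*, y_*} \sum_{u,i} c_{ui} \left( p_{ui} - x_u^T y_i \right)^2
+ \lambda \left( \sum_u \|x_u\|^2 + \sum_i \|y_i\|^2 \right),
\qquad c_{ui} = 1 + \alpha r_{ui},
\qquad p_{ui} = \begin{cases} 1 & r_{ui} > 0 \\ 0 & \text{otherwise} \end{cases}
\]`

where `r_ui` is the observed strength for user `u` and item `i`, `alpha` governs the baseline
confidence, and `lambda` corresponds to `regParam`.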

### Scaling of the regularization parameter

We scale the regularization parameter `regParam` in solving each least squares problem by
the number of ratings the user generated (when updating user factors)
or the number of ratings the product received (when updating product factors).
This approach is named "ALS-WR" and is discussed in the paper
"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
It makes `regParam` less dependent on the scale of the dataset, so we can apply the
best parameter learned from a sampled subset to the full dataset and expect similar performance.
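
As a sketch, the explicit-feedback objective with this "ALS-WR" scaling (notation ours, not the
Spark source's) is:

`\[
\min_{x_*, y_*} \sum_{(u,i) \in R} \left( r_{ui} - x_u^T y_i \right)^2
+ \lambda \left( \sum_u n_u \|x_u\|^2 + \sum_i n_i \|y_i\|^2 \right)
\]`

where `n_u` is the number of ratings by user `u` and `n_i` is the number of ratings of item `i`,
which is why a `regParam` tuned on a sample tends to transfer to the full dataset.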

## Examples

<div class="codetabs">
<div data-lang="scala" markdown="1">
In the following example, we load rating data from the
[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
consisting of a user, a movie, a rating and a timestamp.
We then train an ALS model which assumes, by default, that the ratings are
explicit (`implicitPrefs` is `false`).
We evaluate the recommendation model by measuring the root-mean-square error of
rating prediction.
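
For reference, the root-mean-square error reported by the evaluator is the standard definition
(not code from this commit):

`\[
\mathrm{RMSE} = \sqrt{\frac{1}{|T|} \sum_{(u,i) \in T} \left( r_{ui} - \hat{r}_{ui} \right)^2}
\]`

over the test set `T` of held-out ratings, where the hat denotes the predicted rating.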
Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.ml.recommendation.ALS)
for more details on the API.
{% include_example scala/org/apache/spark/examples/ml/ALSExample.scala %}
If the rating matrix is derived from another source of information (i.e. it is
inferred from other signals), you can set `implicitPrefs` to `true` to get
better results:
{% highlight scala %}
val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setImplicitPrefs(true)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
In the following example, we load rating data from the
[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
consisting of a user, a movie, a rating and a timestamp.
We then train an ALS model which assumes, by default, that the ratings are
explicit (`implicitPrefs` is `false`).
We evaluate the recommendation model by measuring the root-mean-square error of
rating prediction.
Refer to the [`ALS` Java docs](api/java/org/apache/spark/ml/recommendation/ALS.html)
for more details on the API.
{% include_example java/org/apache/spark/examples/ml/JavaALSExample.java %}
If the rating matrix is derived from another source of information (i.e. it is
inferred from other signals), you can set `implicitPrefs` to `true` to get
better results:
{% highlight java %}
ALS als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.01)
  .setImplicitPrefs(true)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating");
{% endhighlight %}
</div>
<div data-lang="python" markdown="1">
In the following example, we load rating data from the
[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
consisting of a user, a movie, a rating and a timestamp.
We then train an ALS model which assumes, by default, that the ratings are
explicit (`implicitPrefs` is `False`).
We evaluate the recommendation model by measuring the root-mean-square error of
rating prediction.
Refer to the [`ALS` Python docs](api/python/pyspark.ml.html#pyspark.ml.recommendation.ALS)
for more details on the API.
{% include_example python/ml/als_example.py %}
If the rating matrix is derived from another source of information (i.e. it is
inferred from other signals), you can set `implicitPrefs` to `True` to get
better results:
{% highlight python %}
als = ALS(maxIter=5, regParam=0.01, implicitPrefs=True,
userCol="userId", itemCol="movieId", ratingCol="rating")
{% endhighlight %}
</div>
</div>

docs/mllib-collaborative-filtering.md

@@ -31,17 +31,18 @@ following parameters:
### Explicit vs. implicit feedback
The standard approach to matrix factorization based collaborative filtering treats
the entries in the user-item matrix as *explicit* preferences given by the user to the item.
the entries in the user-item matrix as *explicit* preferences given by the user to the item,
for example, users giving ratings to movies.
It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken
from
[Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
Essentially instead of trying to model the matrix of ratings directly, this approach treats the data
as a combination of binary preferences and *confidence values*. The ratings are then related to the
level of confidence in observed user preferences, rather than explicit ratings given to items. The
model then tries to find latent factors that can be used to predict the expected preference of a
user for an item.
from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
as numbers representing the *strength* of observations of user actions (such as the number of clicks,
or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
confidence in observed user preferences, rather than explicit ratings given to items. The model
then tries to find latent factors that can be used to predict the expected preference of a user for
an item.
### Scaling of the regularization parameter
@@ -50,9 +51,8 @@ the number of ratings the user generated in updating user factors,
or the number of ratings the product received in updating product factors.
This approach is named "ALS-WR" and discussed in the paper
"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
It makes `lambda` less dependent on the scale of the dataset.
So we can apply the best parameter learned from a sampled subset to the full dataset
and expect similar performance.
It makes `lambda` less dependent on the scale of the dataset, so we can apply the
best parameter learned from a sampled subset to the full dataset and expect similar performance.
## Examples
@@ -64,11 +64,11 @@ We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.rec
method which assumes ratings are explicit. We evaluate the
recommendation model by measuring the Mean Squared Error of rating prediction.
Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API.
Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for more details on the API.
{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %}
If the rating matrix is derived from another source of information (e.g., it is inferred from
If the rating matrix is derived from another source of information (i.e. it is inferred from
other signals), you can use the `trainImplicit` method to get better results.
{% highlight scala %}
@@ -85,7 +85,7 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a
calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API.
Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for more details on the API.
{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
</div>
@@ -99,7 +99,7 @@ Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.rec
{% include_example python/mllib/recommendation_example.py %}
If the rating matrix is derived from other source of information (i.e., it is inferred from other
If the rating matrix is derived from other source of information (i.e. it is inferred from other
signals), you can use the trainImplicit method to get better results.
{% highlight python %}

docs/ml-guide.md

@@ -71,6 +71,7 @@ We list major functionality from both below, with links to detailed guides.
* [Extracting, transforming and selecting features](ml-features.html)
* [Classification and regression](ml-classification-regression.html)
* [Clustering](ml-clustering.html)
* [Collaborative filtering](ml-collaborative-filtering.html)
* [Advanced topics](ml-advanced.html)
Some techniques are not available yet in spark.ml, most notably dimensionality reduction

examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java (new file)

@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;

// $example on$
import java.io.Serializable;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataTypes;
// $example off$

public class JavaALSExample {

  // $example on$
  public static class Rating implements Serializable {
    private int userId;
    private int movieId;
    private float rating;
    private long timestamp;

    public Rating() {}

    public Rating(int userId, int movieId, float rating, long timestamp) {
      this.userId = userId;
      this.movieId = movieId;
      this.rating = rating;
      this.timestamp = timestamp;
    }

    public int getUserId() {
      return userId;
    }

    public int getMovieId() {
      return movieId;
    }

    public float getRating() {
      return rating;
    }

    public long getTimestamp() {
      return timestamp;
    }

    public static Rating parseRating(String str) {
      String[] fields = str.split("::");
      if (fields.length != 4) {
        throw new IllegalArgumentException("Each line must contain 4 fields");
      }
      int userId = Integer.parseInt(fields[0]);
      int movieId = Integer.parseInt(fields[1]);
      float rating = Float.parseFloat(fields[2]);
      long timestamp = Long.parseLong(fields[3]);
      return new Rating(userId, movieId, rating, timestamp);
    }
  }
  // $example off$

  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaALSExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // $example on$
    JavaRDD<Rating> ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt")
      .map(new Function<String, Rating>() {
        public Rating call(String str) {
          return Rating.parseRating(str);
        }
      });
    DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
    DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
    DataFrame training = splits[0];
    DataFrame test = splits[1];

    // Build the recommendation model using ALS on the training data
    ALS als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating");
    ALSModel model = als.fit(training);

    // Evaluate the model by computing the RMSE on the test data
    DataFrame rawPredictions = model.transform(test);
    DataFrame predictions = rawPredictions
      .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
      .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

    RegressionEvaluator evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction");
    Double rmse = evaluator.evaluate(predictions);
    System.out.println("Root-mean-square error = " + rmse);
    // $example off$
    jsc.stop();
  }
}

examples/src/main/python/ml/als_example.py (new file)

@@ -0,0 +1,57 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import print_function

from pyspark import SparkContext
from pyspark.sql import SQLContext

# $example on$
import math

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row
# $example off$

if __name__ == "__main__":
    sc = SparkContext(appName="ALSExample")
    sqlContext = SQLContext(sc)

    # $example on$
    # Each line of the ratings file has the form userId::movieId::rating::timestamp
    lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
    parts = lines.map(lambda l: l.split("::"))
    ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
                                         rating=float(p[2]), timestamp=int(p[3])))
    ratings = sqlContext.createDataFrame(ratingsRDD)
    (training, test) = ratings.randomSplit([0.8, 0.2])

    # Build the recommendation model using ALS on the training data
    als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating")
    model = als.fit(training)

    # Evaluate the model by computing the RMSE on the test data;
    # cast rating and prediction to double for the evaluator
    rawPredictions = model.transform(test)
    predictions = rawPredictions\
        .withColumn("rating", rawPredictions.rating.cast("double"))\
        .withColumn("prediction", rawPredictions.prediction.cast("double"))
    evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                    predictionCol="prediction")
    rmse = evaluator.evaluate(predictions)
    print("Root-mean-square error = " + str(rmse))
    # $example off$
    sc.stop()

examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala (new file)

@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
// $example off$
import org.apache.spark.sql.SQLContext
// $example on$
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
// $example off$

object ALSExample {

  // $example on$
  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  object Rating {
    def parseRating(str: String): Rating = {
      val fields = str.split("::")
      assert(fields.size == 4)
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
    }
  }
  // $example off$

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("ALSExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // $example on$
    val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
      .map(Rating.parseRating)
      .toDF()
    val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

    // Build the recommendation model using ALS on the training data
    val als = new ALS()
      .setMaxIter(5)
      .setRegParam(0.01)
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRatingCol("rating")
    val model = als.fit(training)

    // Evaluate the model by computing the RMSE on the test data
    val predictions = model.transform(test)
      .withColumn("rating", col("rating").cast(DoubleType))
      .withColumn("prediction", col("prediction").cast(DoubleType))
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")
    // $example off$
    sc.stop()
  }
}
// scalastyle:on println

examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala (deleted)

@@ -1,182 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// scalastyle:off println
package org.apache.spark.examples.ml

import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.{Row, SQLContext}

/**
 * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
 * Run with
 * {{{
 * bin/run-example ml.MovieLensALS
 * }}}
 */
object MovieLensALS {

  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  object Rating {
    def parseRating(str: String): Rating = {
      val fields = str.split("::")
      assert(fields.size == 4)
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
    }
  }

  case class Movie(movieId: Int, title: String, genres: Seq[String])

  object Movie {
    def parseMovie(str: String): Movie = {
      val fields = str.split("::")
      assert(fields.size == 3)
      Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
    }
  }

  case class Params(
      ratings: String = null,
      movies: String = null,
      maxIter: Int = 10,
      regParam: Double = 0.1,
      rank: Int = 10,
      numBlocks: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("MovieLensALS") {
      head("MovieLensALS: an example app for ALS on MovieLens data.")
      opt[String]("ratings")
        .required()
        .text("path to a MovieLens dataset of ratings")
        .action((x, c) => c.copy(ratings = x))
      opt[String]("movies")
        .required()
        .text("path to a MovieLens dataset of movies")
        .action((x, c) => c.copy(movies = x))
      opt[Int]("rank")
        .text(s"rank, default: ${defaultParams.rank}")
        .action((x, c) => c.copy(rank = x))
      opt[Int]("maxIter")
        .text(s"max number of iterations, default: ${defaultParams.maxIter}")
        .action((x, c) => c.copy(maxIter = x))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      opt[Int]("numBlocks")
        .text(s"number of blocks, default: ${defaultParams.numBlocks}")
        .action((x, c) => c.copy(numBlocks = x))
      note(
        """
          |Example command line to run this app:
          |
          | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
          |  examples/target/scala-*/spark-examples-*.jar \
          |  --rank 10 --maxIter 15 --regParam 0.1 \
          |  --movies data/mllib/als/sample_movielens_movies.txt \
          |  --ratings data/mllib/als/sample_movielens_ratings.txt
        """.stripMargin)
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    } getOrElse {
      System.exit(1)
    }
  }

  def run(params: Params) {
    val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()

    val numRatings = ratings.count()
    val numUsers = ratings.map(_.userId).distinct().count()
    val numMovies = ratings.map(_.movieId).distinct().count()

    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")

    val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

    ratings.unpersist(blocking = false)

    val als = new ALS()
      .setUserCol("userId")
      .setItemCol("movieId")
      .setRank(params.rank)
      .setMaxIter(params.maxIter)
      .setRegParam(params.regParam)
      .setNumBlocks(params.numBlocks)

    val model = als.fit(training.toDF())

    val predictions = model.transform(test.toDF()).cache()

    // Evaluate the model.
    // TODO: Create an evaluator to compute RMSE.
    val mse = predictions.select("rating", "prediction").rdd
      .flatMap { case Row(rating: Float, prediction: Float) =>
        val err = rating.toDouble - prediction
        val err2 = err * err
        if (err2.isNaN) {
          None
        } else {
          Some(err2)
        }
      }.mean()
    val rmse = math.sqrt(mse)
    println(s"Test RMSE = $rmse.")

    // Inspect false positives.
    // Note: We reference columns in 2 ways:
    //  (1) predictions("movieId") lets us specify the movieId column in the predictions
    //      DataFrame, rather than the movieId column in the movies DataFrame.
    //  (2) $"userId" specifies the userId column in the predictions DataFrame.
    //      We could also write predictions("userId") but do not have to since
    //      the movies DataFrame does not have a column "userId."
    val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF()
    val falsePositives = predictions.join(movies)
      .where((predictions("movieId") === movies("movieId"))
        && ($"rating" <= 1) && ($"prediction" >= 4))
      .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction")
    val numFalsePositives = falsePositives.count()
    println(s"Found $numFalsePositives false positives")
    if (numFalsePositives > 0) {
      println(s"Example false positives:")
      falsePositives.limit(100).collect().foreach(println)
    }

    sc.stop()
  }
}
// scalastyle:on println

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

@@ -496,7 +496,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
/**
* Solves a nonnegative least squares problem with L2 regularizatin:
* Solves a nonnegative least squares problem with L2 regularization:
*
* min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^
* subject to x >= 0