[SPARK-12247][ML][DOC] Documentation for spark.ml's ALS and collaborative filtering in general
This documents the implementation of ALS in `spark.ml` with example code in Scala, Java and Python.

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #10411 from BenFradet/SPARK-12247.
parent 827ed1c067
commit 00c72d27bf
@@ -1,100 +0,0 @@
-0::Movie 0::Romance|Comedy
-1::Movie 1::Action|Anime
-2::Movie 2::Romance|Thriller
-3::Movie 3::Action|Romance
-4::Movie 4::Anime|Comedy
-5::Movie 5::Action|Action
-6::Movie 6::Action|Comedy
-7::Movie 7::Anime|Comedy
-8::Movie 8::Comedy|Action
-9::Movie 9::Anime|Thriller
-10::Movie 10::Action|Anime
-11::Movie 11::Action|Anime
-12::Movie 12::Anime|Comedy
-13::Movie 13::Thriller|Action
-14::Movie 14::Anime|Comedy
-15::Movie 15::Comedy|Thriller
-16::Movie 16::Anime|Romance
-17::Movie 17::Thriller|Action
-18::Movie 18::Action|Comedy
-19::Movie 19::Anime|Romance
-20::Movie 20::Action|Anime
-21::Movie 21::Romance|Thriller
-22::Movie 22::Romance|Romance
-23::Movie 23::Comedy|Comedy
-24::Movie 24::Anime|Action
-25::Movie 25::Comedy|Comedy
-26::Movie 26::Anime|Romance
-27::Movie 27::Anime|Anime
-28::Movie 28::Thriller|Anime
-29::Movie 29::Anime|Romance
-30::Movie 30::Thriller|Romance
-31::Movie 31::Thriller|Romance
-32::Movie 32::Comedy|Anime
-33::Movie 33::Comedy|Comedy
-34::Movie 34::Anime|Anime
-35::Movie 35::Action|Thriller
-36::Movie 36::Anime|Romance
-37::Movie 37::Romance|Anime
-38::Movie 38::Thriller|Romance
-39::Movie 39::Romance|Comedy
-40::Movie 40::Action|Anime
-41::Movie 41::Comedy|Thriller
-42::Movie 42::Comedy|Action
-43::Movie 43::Thriller|Anime
-44::Movie 44::Anime|Action
-45::Movie 45::Comedy|Romance
-46::Movie 46::Comedy|Action
-47::Movie 47::Romance|Comedy
-48::Movie 48::Action|Comedy
-49::Movie 49::Romance|Romance
-50::Movie 50::Comedy|Romance
-51::Movie 51::Action|Action
-52::Movie 52::Thriller|Action
-53::Movie 53::Action|Action
-54::Movie 54::Romance|Thriller
-55::Movie 55::Anime|Romance
-56::Movie 56::Comedy|Action
-57::Movie 57::Action|Anime
-58::Movie 58::Thriller|Romance
-59::Movie 59::Thriller|Comedy
-60::Movie 60::Anime|Comedy
-61::Movie 61::Comedy|Action
-62::Movie 62::Comedy|Romance
-63::Movie 63::Romance|Thriller
-64::Movie 64::Romance|Action
-65::Movie 65::Anime|Romance
-66::Movie 66::Comedy|Action
-67::Movie 67::Thriller|Anime
-68::Movie 68::Thriller|Romance
-69::Movie 69::Action|Comedy
-70::Movie 70::Thriller|Thriller
-71::Movie 71::Action|Comedy
-72::Movie 72::Thriller|Romance
-73::Movie 73::Comedy|Action
-74::Movie 74::Action|Action
-75::Movie 75::Action|Action
-76::Movie 76::Comedy|Comedy
-77::Movie 77::Comedy|Comedy
-78::Movie 78::Comedy|Comedy
-79::Movie 79::Thriller|Thriller
-80::Movie 80::Comedy|Anime
-81::Movie 81::Comedy|Anime
-82::Movie 82::Romance|Anime
-83::Movie 83::Comedy|Thriller
-84::Movie 84::Anime|Action
-85::Movie 85::Thriller|Anime
-86::Movie 86::Romance|Anime
-87::Movie 87::Thriller|Thriller
-88::Movie 88::Romance|Thriller
-89::Movie 89::Action|Anime
-90::Movie 90::Anime|Romance
-91::Movie 91::Anime|Thriller
-92::Movie 92::Action|Comedy
-93::Movie 93::Romance|Thriller
-94::Movie 94::Thriller|Comedy
-95::Movie 95::Action|Action
-96::Movie 96::Thriller|Romance
-97::Movie 97::Thriller|Thriller
-98::Movie 98::Thriller|Comedy
-99::Movie 99::Thriller|Romance
@@ -6,5 +6,7 @@
   url: ml-classification-regression.html
 - text: Clustering
   url: ml-clustering.html
+- text: Collaborative filtering
+  url: ml-collaborative-filtering.html
 - text: Advanced topics
   url: ml-advanced.html
148
docs/ml-collaborative-filtering.md
Normal file
@@ -0,0 +1,148 @@
+---
+layout: global
+title: Collaborative Filtering - spark.ml
+displayTitle: Collaborative Filtering - spark.ml
+---
+
+* Table of contents
+{:toc}
+
+## Collaborative filtering
+
+[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
+is commonly used for recommender systems. These techniques aim to fill in the
+missing entries of a user-item association matrix. `spark.ml` currently supports
+model-based collaborative filtering, in which users and products are described
+by a small set of latent factors that can be used to predict missing entries.
+`spark.ml` uses the [alternating least squares
+(ALS)](http://dl.acm.org/citation.cfm?id=1608614)
+algorithm to learn these latent factors. The implementation in `spark.ml` has the
+following parameters:
+
+* *numBlocks* is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10).
+* *rank* is the number of latent factors in the model (defaults to 10).
+* *maxIter* is the maximum number of iterations to run (defaults to 10).
+* *regParam* specifies the regularization parameter in ALS (defaults to 0.1).
+* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for
+  *implicit feedback* data (defaults to `false` which means using *explicit feedback*).
+* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the
+  *baseline* confidence in preference observations (defaults to 1.0).
+* *nonnegative* specifies whether or not to use nonnegative constraints for least squares (defaults to `false`).
+
+### Explicit vs. implicit feedback
+
+The standard approach to matrix factorization based collaborative filtering treats
+the entries in the user-item matrix as *explicit* preferences given by the user to the item,
+for example, users giving ratings to movies.
+
+It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
+clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken
+from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
+Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
+as numbers representing the *strength* in observations of user actions (such as the number of clicks,
+or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
+confidence in observed user preferences, rather than explicit ratings given to items. The model
+then tries to find latent factors that can be used to predict the expected preference of a user for
+an item.
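
A minimal sketch of the mapping described above, assuming the formulation of the cited paper: a binary preference that is 1 when any interaction was observed, and a confidence of `1 + alpha * r` for an observed strength `r`. The helper name `preference_and_confidence` and the click counts are hypothetical, and this is plain Python for illustration rather than the `spark.ml` implementation.

```python
def preference_and_confidence(strength, alpha=1.0):
    """Map an observed interaction strength r >= 0 to (preference, confidence).

    Assumed formulation: preference is 1 if r > 0 else 0,
    and confidence is 1 + alpha * r.
    """
    preference = 1.0 if strength > 0 else 0.0
    confidence = 1.0 + alpha * strength
    return preference, confidence


# Hypothetical click counts for one user, keyed by movie id.
clicks = {10: 3.0, 11: 0.0, 12: 1.0}
weighted = {movie: preference_and_confidence(r, alpha=1.0)
            for movie, r in clicks.items()}
```

Under this assumption an unobserved pair still contributes, with preference 0 and the baseline confidence of 1, while repeated interactions raise the confidence in proportion to `alpha`.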
+
+### Scaling of the regularization parameter
+
+We scale the regularization parameter `regParam` in solving each least squares problem by
+the number of ratings the user generated in updating user factors,
+or the number of ratings the product received in updating product factors.
+This approach is named "ALS-WR" and discussed in the paper
+"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
+It makes `regParam` less dependent on the scale of the dataset, so we can apply the
+best parameter learned from a sampled subset to the full dataset and expect similar performance.
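
To illustrate why this scaling helps, here is a rank-1 sketch in plain Python. It is an illustrative assumption, not the `spark.ml` solver, and `solve_user_factor` is a hypothetical name. With a single latent factor, the regularized least squares problem for one user has a closed form, and multiplying the penalty by the number of ratings keeps the solution stable when the same kind of data simply becomes more plentiful.

```python
def solve_user_factor(item_factors, ratings, reg_param):
    """Rank-1 ridge solve for one user with ALS-WR style scaling.

    Minimises sum_i (r_i - x * y_i)^2 + reg_param * n * x^2,
    where n is the number of ratings the user generated.
    """
    n = len(ratings)
    numerator = sum(y * r for y, r in zip(item_factors, ratings))
    denominator = sum(y * y for y in item_factors) + reg_param * n
    return numerator / denominator


# A small sample and the same data repeated twice (a "larger dataset").
sample = solve_user_factor([1.0, 2.0], [2.0, 4.0], reg_param=0.1)
full = solve_user_factor([1.0, 2.0] * 2, [2.0, 4.0] * 2, reg_param=0.1)
```

Because the numerator, the sum of squares and the scaled penalty all double together, `sample` and `full` agree, which is the sense in which a `regParam` tuned on a subset carries over.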
+
+## Examples
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+In the following example, we load rating data from the
+[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
+consisting of a user, a movie, a rating and a timestamp.
+We then train an ALS model which assumes, by default, that the ratings are
+explicit (`implicitPrefs` is `false`).
+We evaluate the recommendation model by measuring the root-mean-square error of
+rating prediction.
+
+Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.ml.recommendation.ALS)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/ALSExample.scala %}
+
+If the rating matrix is derived from another source of information (i.e. it is
+inferred from other signals), you can set `implicitPrefs` to `true` to get
+better results:
+
+{% highlight scala %}
+val als = new ALS()
+  .setMaxIter(5)
+  .setRegParam(0.01)
+  .setImplicitPrefs(true)
+  .setUserCol("userId")
+  .setItemCol("movieId")
+  .setRatingCol("rating")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+In the following example, we load rating data from the
+[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
+consisting of a user, a movie, a rating and a timestamp.
+We then train an ALS model which assumes, by default, that the ratings are
+explicit (`implicitPrefs` is `false`).
+We evaluate the recommendation model by measuring the root-mean-square error of
+rating prediction.
+
+Refer to the [`ALS` Java docs](api/java/org/apache/spark/ml/recommendation/ALS.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaALSExample.java %}
+
+If the rating matrix is derived from another source of information (i.e. it is
+inferred from other signals), you can set `implicitPrefs` to `true` to get
+better results:
+
+{% highlight java %}
+ALS als = new ALS()
+  .setMaxIter(5)
+  .setRegParam(0.01)
+  .setImplicitPrefs(true)
+  .setUserCol("userId")
+  .setItemCol("movieId")
+  .setRatingCol("rating");
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+In the following example, we load rating data from the
+[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row
+consisting of a user, a movie, a rating and a timestamp.
+We then train an ALS model which assumes, by default, that the ratings are
+explicit (`implicitPrefs` is `False`).
+We evaluate the recommendation model by measuring the root-mean-square error of
+rating prediction.
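
As a rough, Spark-free sketch of the two steps described above, the snippet below parses one `::`-separated row and computes a root-mean-square error over (rating, prediction) pairs. The names `parse_rating` and `rmse` are hypothetical, and skipping pairs whose prediction is NaN is an assumption made for the illustration, not a statement about what `RegressionEvaluator` does.

```python
import math


def parse_rating(line):
    """Split 'user::movie::rating::timestamp' into typed fields."""
    fields = line.split("::")
    if len(fields) != 4:
        raise ValueError("Each line must contain 4 fields")
    return int(fields[0]), int(fields[1]), float(fields[2]), int(fields[3])


def rmse(pairs):
    """Root-mean-square error over (rating, prediction) pairs.

    Pairs with a NaN prediction are skipped (an assumption of this sketch).
    """
    squared = [(r - p) ** 2 for r, p in pairs if not math.isnan(p)]
    if not squared:
        raise ValueError("no valid predictions to evaluate")
    return math.sqrt(sum(squared) / len(squared))
```

For example, `rmse([(3.0, 2.0), (1.0, 3.0)])` averages the squared errors 1 and 4 and takes the square root of 2.5.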
+
+Refer to the [`ALS` Python docs](api/python/pyspark.ml.html#pyspark.ml.recommendation.ALS)
+for more details on the API.
+
+{% include_example python/ml/als_example.py %}
+
+If the rating matrix is derived from another source of information (i.e. it is
+inferred from other signals), you can set `implicitPrefs` to `True` to get
+better results:
+
+{% highlight python %}
+als = ALS(maxIter=5, regParam=0.01, implicitPrefs=True,
+          userCol="userId", itemCol="movieId", ratingCol="rating")
+{% endhighlight %}
+
+</div>
+</div>
@@ -31,17 +31,18 @@ following parameters:
 ### Explicit vs. implicit feedback
 
 The standard approach to matrix factorization based collaborative filtering treats
-the entries in the user-item matrix as *explicit* preferences given by the user to the item.
+the entries in the user-item matrix as *explicit* preferences given by the user to the item,
+for example, users giving ratings to movies.
 
 It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
 clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken
-from
-[Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
-Essentially instead of trying to model the matrix of ratings directly, this approach treats the data
-as a combination of binary preferences and *confidence values*. The ratings are then related to the
-level of confidence in observed user preferences, rather than explicit ratings given to items. The
-model then tries to find latent factors that can be used to predict the expected preference of a
-user for an item.
+from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
+Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
+as numbers representing the *strength* in observations of user actions (such as the number of clicks,
+or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
+confidence in observed user preferences, rather than explicit ratings given to items. The model
+then tries to find latent factors that can be used to predict the expected preference of a user for
+an item.
 
 ### Scaling of the regularization parameter
 
@@ -50,9 +51,8 @@ the number of ratings the user generated in updating user factors,
 or the number of ratings the product received in updating product factors.
 This approach is named "ALS-WR" and discussed in the paper
 "[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
-It makes `lambda` less dependent on the scale of the dataset.
-So we can apply the best parameter learned from a sampled subset to the full dataset
-and expect similar performance.
+It makes `lambda` less dependent on the scale of the dataset, so we can apply the
+best parameter learned from a sampled subset to the full dataset and expect similar performance.
 
 ## Examples
 
@@ -64,11 +64,11 @@ We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.rec
 method which assumes ratings are explicit. We evaluate the
 recommendation model by measuring the Mean Squared Error of rating prediction.
 
-Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API.
+Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for more details on the API.
 
 {% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %}
 
-If the rating matrix is derived from another source of information (e.g., it is inferred from
+If the rating matrix is derived from another source of information (i.e. it is inferred from
 other signals), you can use the `trainImplicit` method to get better results.
 
 {% highlight scala %}
@@ -85,7 +85,7 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a
 calling `.rdd()` on your `JavaRDD` object. A self-contained application example
 that is equivalent to the provided example in Scala is given below:
 
-Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API.
+Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for more details on the API.
 
 {% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
 </div>
@@ -99,7 +99,7 @@ Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.rec
 
 {% include_example python/mllib/recommendation_example.py %}
 
-If the rating matrix is derived from other source of information (i.e., it is inferred from other
+If the rating matrix is derived from other source of information (i.e. it is inferred from other
 signals), you can use the trainImplicit method to get better results.
 
 {% highlight python %}
@@ -71,6 +71,7 @@ We list major functionality from both below, with links to detailed guides.
 * [Extracting, transforming and selecting features](ml-features.html)
 * [Classification and regression](ml-classification-regression.html)
 * [Clustering](ml-clustering.html)
+* [Collaborative filtering](ml-collaborative-filtering.html)
 * [Advanced topics](ml-advanced.html)
 
 Some techniques are not available yet in spark.ml, most notably dimensionality reduction
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import java.io.Serializable;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.recommendation.ALS;
+import org.apache.spark.ml.recommendation.ALSModel;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+// $example off$
+
+public class JavaALSExample {
+
+  // $example on$
+  public static class Rating implements Serializable {
+    private int userId;
+    private int movieId;
+    private float rating;
+    private long timestamp;
+
+    public Rating() {}
+
+    public Rating(int userId, int movieId, float rating, long timestamp) {
+      this.userId = userId;
+      this.movieId = movieId;
+      this.rating = rating;
+      this.timestamp = timestamp;
+    }
+
+    public int getUserId() {
+      return userId;
+    }
+
+    public int getMovieId() {
+      return movieId;
+    }
+
+    public float getRating() {
+      return rating;
+    }
+
+    public long getTimestamp() {
+      return timestamp;
+    }
+
+    public static Rating parseRating(String str) {
+      String[] fields = str.split("::");
+      if (fields.length != 4) {
+        throw new IllegalArgumentException("Each line must contain 4 fields");
+      }
+      int userId = Integer.parseInt(fields[0]);
+      int movieId = Integer.parseInt(fields[1]);
+      float rating = Float.parseFloat(fields[2]);
+      long timestamp = Long.parseLong(fields[3]);
+      return new Rating(userId, movieId, rating, timestamp);
+    }
+  }
+  // $example off$
+
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("JavaALSExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext sqlContext = new SQLContext(jsc);
+
+    // $example on$
+    JavaRDD<Rating> ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+      .map(new Function<String, Rating>() {
+        public Rating call(String str) {
+          return Rating.parseRating(str);
+        }
+      });
+    DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
+    DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
+    DataFrame training = splits[0];
+    DataFrame test = splits[1];
+
+    // Build the recommendation model using ALS on the training data
+    ALS als = new ALS()
+      .setMaxIter(5)
+      .setRegParam(0.01)
+      .setUserCol("userId")
+      .setItemCol("movieId")
+      .setRatingCol("rating");
+    ALSModel model = als.fit(training);
+
+    // Evaluate the model by computing the RMSE on the test data
+    DataFrame rawPredictions = model.transform(test);
+    DataFrame predictions = rawPredictions
+      .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
+      .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
+
+    RegressionEvaluator evaluator = new RegressionEvaluator()
+      .setMetricName("rmse")
+      .setLabelCol("rating")
+      .setPredictionCol("prediction");
+    Double rmse = evaluator.evaluate(predictions);
+    System.out.println("Root-mean-square error = " + rmse);
+    // $example off$
+    jsc.stop();
+  }
+}
57
examples/src/main/python/ml/als_example.py
Normal file
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+
+# $example on$
+import math
+
+from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.recommendation import ALS
+from pyspark.sql import Row
+# $example off$
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="ALSExample")
+    sqlContext = SQLContext(sc)
+
+    # $example on$
+    lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+    parts = lines.map(lambda l: l.split("::"))
+    ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
+                                         rating=float(p[2]), timestamp=int(p[3])))
+    ratings = sqlContext.createDataFrame(ratingsRDD)
+    (training, test) = ratings.randomSplit([0.8, 0.2])
+
+    # Build the recommendation model using ALS on the training data
+    als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating")
+    model = als.fit(training)
+
+    # Evaluate the model by computing the RMSE on the test data
+    rawPredictions = model.transform(test)
+    predictions = rawPredictions\
+        .withColumn("rating", rawPredictions.rating.cast("double"))\
+        .withColumn("prediction", rawPredictions.prediction.cast("double"))
+    evaluator =\
+        RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
+    rmse = evaluator.evaluate(predictions)
+    print("Root-mean-square error = " + str(rmse))
+    # $example off$
+    sc.stop()
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.recommendation.ALS
+// $example off$
+import org.apache.spark.sql.SQLContext
+// $example on$
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
+// $example off$
+
+object ALSExample {
+
+  // $example on$
+  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+  object Rating {
+    def parseRating(str: String): Rating = {
+      val fields = str.split("::")
+      assert(fields.size == 4)
+      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+    }
+  }
+  // $example off$
+
+  def main(args: Array[String]) {
+    val conf = new SparkConf().setAppName("ALSExample")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // $example on$
+    val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
+      .map(Rating.parseRating)
+      .toDF()
+    val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))
+
+    // Build the recommendation model using ALS on the training data
+    val als = new ALS()
+      .setMaxIter(5)
+      .setRegParam(0.01)
+      .setUserCol("userId")
+      .setItemCol("movieId")
+      .setRatingCol("rating")
+    val model = als.fit(training)
+
+    // Evaluate the model by computing the RMSE on the test data
+    val predictions = model.transform(test)
+      .withColumn("rating", col("rating").cast(DoubleType))
+      .withColumn("prediction", col("prediction").cast(DoubleType))
+
+    val evaluator = new RegressionEvaluator()
+      .setMetricName("rmse")
+      .setLabelCol("rating")
+      .setPredictionCol("prediction")
+    val rmse = evaluator.evaluate(predictions)
+    println(s"Root-mean-square error = $rmse")
+    // $example off$
+    sc.stop()
+  }
+}
+// scalastyle:on println
+
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.ml
-
-import scopt.OptionParser
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.recommendation.ALS
-import org.apache.spark.sql.{Row, SQLContext}
-
-/**
- * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
- * Run with
- * {{{
- * bin/run-example ml.MovieLensALS
- * }}}
- */
-object MovieLensALS {
-
-  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
-
-  object Rating {
-    def parseRating(str: String): Rating = {
-      val fields = str.split("::")
-      assert(fields.size == 4)
-      Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
-    }
-  }
-
-  case class Movie(movieId: Int, title: String, genres: Seq[String])
-
-  object Movie {
-    def parseMovie(str: String): Movie = {
-      val fields = str.split("::")
-      assert(fields.size == 3)
-      Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
-    }
-  }
-
-  case class Params(
-      ratings: String = null,
-      movies: String = null,
-      maxIter: Int = 10,
-      regParam: Double = 0.1,
-      rank: Int = 10,
-      numBlocks: Int = 10) extends AbstractParams[Params]
-
-  def main(args: Array[String]) {
-    val defaultParams = Params()
-
-    val parser = new OptionParser[Params]("MovieLensALS") {
-      head("MovieLensALS: an example app for ALS on MovieLens data.")
-      opt[String]("ratings")
-        .required()
-        .text("path to a MovieLens dataset of ratings")
-        .action((x, c) => c.copy(ratings = x))
-      opt[String]("movies")
-        .required()
-        .text("path to a MovieLens dataset of movies")
-        .action((x, c) => c.copy(movies = x))
-      opt[Int]("rank")
-        .text(s"rank, default: ${defaultParams.rank}")
-        .action((x, c) => c.copy(rank = x))
-      opt[Int]("maxIter")
-        .text(s"max number of iterations, default: ${defaultParams.maxIter}")
-        .action((x, c) => c.copy(maxIter = x))
-      opt[Double]("regParam")
-        .text(s"regularization parameter, default: ${defaultParams.regParam}")
-        .action((x, c) => c.copy(regParam = x))
-      opt[Int]("numBlocks")
-        .text(s"number of blocks, default: ${defaultParams.numBlocks}")
-        .action((x, c) => c.copy(numBlocks = x))
-      note(
-        """
-          |Example command line to run this app:
-          |
-          | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
-          |  examples/target/scala-*/spark-examples-*.jar \
-          |  --rank 10 --maxIter 15 --regParam 0.1 \
-          |  --movies data/mllib/als/sample_movielens_movies.txt \
-          |  --ratings data/mllib/als/sample_movielens_ratings.txt
-        """.stripMargin)
-    }
-
-    parser.parse(args, defaultParams).map { params =>
-      run(params)
-    } getOrElse {
-      System.exit(1)
-    }
-  }
-
-  def run(params: Params) {
-    val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
-    val sc = new SparkContext(conf)
-    val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._
-
-    val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
-
-    val numRatings = ratings.count()
-    val numUsers = ratings.map(_.userId).distinct().count()
-    val numMovies = ratings.map(_.movieId).distinct().count()
-
-    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
-
-    val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
-    val training = splits(0).cache()
-    val test = splits(1).cache()
-
-    val numTraining = training.count()
-    val numTest = test.count()
-    println(s"Training: $numTraining, test: $numTest.")
-
-    ratings.unpersist(blocking = false)
-
-    val als = new ALS()
-      .setUserCol("userId")
-      .setItemCol("movieId")
-      .setRank(params.rank)
-      .setMaxIter(params.maxIter)
-      .setRegParam(params.regParam)
-      .setNumBlocks(params.numBlocks)
-
-    val model = als.fit(training.toDF())
-
-    val predictions = model.transform(test.toDF()).cache()
-
-    // Evaluate the model.
-    // TODO: Create an evaluator to compute RMSE.
-    val mse = predictions.select("rating", "prediction").rdd
-      .flatMap { case Row(rating: Float, prediction: Float) =>
-        val err = rating.toDouble - prediction
-        val err2 = err * err
-        if (err2.isNaN) {
-          None
-        } else {
-          Some(err2)
-        }
-      }.mean()
-    val rmse = math.sqrt(mse)
-    println(s"Test RMSE = $rmse.")
-
-    // Inspect false positives.
-    // Note: We reference columns in 2 ways:
-    //  (1) predictions("movieId") lets us specify the movieId column in the predictions
-    //      DataFrame, rather than the movieId column in the movies DataFrame.
-    //  (2) $"userId" specifies the userId column in the predictions DataFrame.
-    //      We could also write predictions("userId") but do not have to since
-    //      the movies DataFrame does not have a column "userId."
-    val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF()
-    val falsePositives = predictions.join(movies)
-      .where((predictions("movieId") === movies("movieId"))
-        && ($"rating" <= 1) && ($"prediction" >= 4))
-      .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction")
-    val numFalsePositives = falsePositives.count()
-    println(s"Found $numFalsePositives false positives")
-    if (numFalsePositives > 0) {
-      println(s"Example false positives:")
-      falsePositives.limit(100).collect().foreach(println)
-    }
-
-    sc.stop()
-  }
-}
-// scalastyle:on println
@@ -496,7 +496,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
     }
 
     /**
-     * Solves a nonnegative least squares problem with L2 regularizatin:
+     * Solves a nonnegative least squares problem with L2 regularization:
      *
      *   min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^
      *   subject to x >= 0
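
For intuition about the objective in the comment above, here is a sketch of the one-dimensional case in plain Python, where `x` is a single scalar. It assumes `n` is the number of observations and uses the hypothetical name `nonnegative_ridge_1d`; it illustrates the objective only and is not the algorithm the Scala solver uses. In one dimension the constrained minimum is the unconstrained ridge solution clipped at zero.

```python
def nonnegative_ridge_1d(a, b, lam, n):
    """Minimise sum_i (a_i * x - b_i)^2 + lam * n * x^2 subject to x >= 0.

    Assumes the denominator is positive (some a_i is nonzero or lam * n > 0).
    """
    numerator = sum(ai * bi for ai, bi in zip(a, b))
    denominator = sum(ai * ai for ai in a) + lam * n
    # The objective is convex in x, so clipping the unconstrained
    # minimiser at zero gives the constrained minimiser.
    return max(0.0, numerator / denominator)
```

With `a = [1.0]` and `b = [-3.0]` the unconstrained solution is negative, so the constraint is active and the result is `0.0`.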