2015-05-08 20:24:32 -04:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2018-03-08 06:38:34 -05:00
|
|
|
import sys
|
|
|
|
|
2016-04-20 13:32:01 -04:00
|
|
|
from pyspark import since, keyword_only
|
2020-08-30 22:23:31 -04:00
|
|
|
from pyspark.ml.param.shared import HasPredictionCol, HasBlockSize, HasMaxIter, HasRegParam, \
|
|
|
|
HasCheckpointInterval, HasSeed
|
2015-05-08 20:24:32 -04:00
|
|
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel
|
2016-06-13 22:59:53 -04:00
|
|
|
from pyspark.ml.common import inherit_doc
|
2020-08-30 22:23:31 -04:00
|
|
|
from pyspark.ml.param import Params, TypeConverters, Param
|
|
|
|
from pyspark.ml.util import JavaMLWritable, JavaMLReadable
|
2015-05-08 20:24:32 -04:00
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ALS', 'ALSModel']
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class _ALSModelParams(HasPredictionCol, HasBlockSize):
    """
    Params shared by :py:class:`ALS` and :py:class:`ALSModel`.

    .. versionadded:: 3.0.0
    """

    # Column holding the user ids.
    userCol = Param(
        Params._dummy(),
        "userCol",
        "column name for user ids. Ids must be within "
        "the integer value range.",
        typeConverter=TypeConverters.toString,
    )

    # Column holding the item ids.
    itemCol = Param(
        Params._dummy(),
        "itemCol",
        "column name for item ids. Ids must be within "
        "the integer value range.",
        typeConverter=TypeConverters.toString,
    )

    # How ids unseen during training are handled at prediction time.
    coldStartStrategy = Param(
        Params._dummy(),
        "coldStartStrategy",
        "strategy for dealing with "
        "unknown or new users/items at prediction time. This may be useful "
        "in cross-validation or production scenarios, for handling "
        "user/item ids the model has not seen in the training data. "
        "Supported values: 'nan', 'drop'.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args):
        super(_ALSModelParams, self).__init__(*args)
        # Only blockSize gets a default here; the remaining defaults are
        # registered by the estimator-level params class.
        defaults = {"blockSize": 4096}
        self._setDefault(**defaults)

    @since("1.4.0")
    def getUserCol(self):
        """
        Returns the current userCol, or its default when it is not set.
        """
        return self.getOrDefault(self.userCol)

    @since("1.4.0")
    def getItemCol(self):
        """
        Returns the current itemCol, or its default when it is not set.
        """
        return self.getOrDefault(self.itemCol)

    @since("2.2.0")
    def getColdStartStrategy(self):
        """
        Returns the current coldStartStrategy, or its default when it is not set.
        """
        return self.getOrDefault(self.coldStartStrategy)
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed):
    """
    Params specific to the :py:class:`ALS` estimator.

    .. versionadded:: 3.0.0
    """

    rank = Param(
        Params._dummy(),
        "rank",
        "rank of the factorization",
        typeConverter=TypeConverters.toInt,
    )
    numUserBlocks = Param(
        Params._dummy(),
        "numUserBlocks",
        "number of user blocks",
        typeConverter=TypeConverters.toInt,
    )
    numItemBlocks = Param(
        Params._dummy(),
        "numItemBlocks",
        "number of item blocks",
        typeConverter=TypeConverters.toInt,
    )
    implicitPrefs = Param(
        Params._dummy(),
        "implicitPrefs",
        "whether to use implicit preference",
        typeConverter=TypeConverters.toBoolean,
    )
    alpha = Param(
        Params._dummy(),
        "alpha",
        "alpha for implicit preference",
        typeConverter=TypeConverters.toFloat,
    )

    ratingCol = Param(
        Params._dummy(),
        "ratingCol",
        "column name for ratings",
        typeConverter=TypeConverters.toString,
    )
    nonnegative = Param(
        Params._dummy(),
        "nonnegative",
        "whether to use nonnegative constraint for least squares",
        typeConverter=TypeConverters.toBoolean,
    )
    intermediateStorageLevel = Param(
        Params._dummy(),
        "intermediateStorageLevel",
        "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
        typeConverter=TypeConverters.toString,
    )
    finalStorageLevel = Param(
        Params._dummy(),
        "finalStorageLevel",
        "StorageLevel for ALS model factors.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args):
        super(_ALSParams, self).__init__(*args)
        # Defaults for every estimator-level param, listed in the order in
        # which they are handed to _setDefault.
        defaults = {
            "rank": 10,
            "maxIter": 10,
            "regParam": 0.1,
            "numUserBlocks": 10,
            "numItemBlocks": 10,
            "implicitPrefs": False,
            "alpha": 1.0,
            "userCol": "user",
            "itemCol": "item",
            "ratingCol": "rating",
            "nonnegative": False,
            "checkpointInterval": 10,
            "intermediateStorageLevel": "MEMORY_AND_DISK",
            "finalStorageLevel": "MEMORY_AND_DISK",
            "coldStartStrategy": "nan",
        }
        self._setDefault(**defaults)

    @since("1.4.0")
    def getRank(self):
        """
        Returns the current rank, or its default when it is not set.
        """
        return self.getOrDefault(self.rank)

    @since("1.4.0")
    def getNumUserBlocks(self):
        """
        Returns the current numUserBlocks, or its default when it is not set.
        """
        return self.getOrDefault(self.numUserBlocks)

    @since("1.4.0")
    def getNumItemBlocks(self):
        """
        Returns the current numItemBlocks, or its default when it is not set.
        """
        return self.getOrDefault(self.numItemBlocks)

    @since("1.4.0")
    def getImplicitPrefs(self):
        """
        Returns the current implicitPrefs, or its default when it is not set.
        """
        return self.getOrDefault(self.implicitPrefs)

    @since("1.4.0")
    def getAlpha(self):
        """
        Returns the current alpha, or its default when it is not set.
        """
        return self.getOrDefault(self.alpha)

    @since("1.4.0")
    def getRatingCol(self):
        """
        Returns the current ratingCol, or its default when it is not set.
        """
        return self.getOrDefault(self.ratingCol)

    @since("1.4.0")
    def getNonnegative(self):
        """
        Returns the current nonnegative, or its default when it is not set.
        """
        return self.getOrDefault(self.nonnegative)

    @since("2.0.0")
    def getIntermediateStorageLevel(self):
        """
        Returns the current intermediateStorageLevel, or its default when it is not set.
        """
        return self.getOrDefault(self.intermediateStorageLevel)

    @since("2.0.0")
    def getFinalStorageLevel(self):
        """
        Returns the current finalStorageLevel, or its default when it is not set.
        """
        return self.getOrDefault(self.finalStorageLevel)
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
|
|
|
|
class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
|
2015-05-08 20:24:32 -04:00
|
|
|
"""
|
|
|
|
Alternating Least Squares (ALS) matrix factorization.
|
|
|
|
|
|
|
|
ALS attempts to estimate the ratings matrix `R` as the product of
|
|
|
|
two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
|
|
|
|
these approximations are called 'factor' matrices. The general
|
|
|
|
approach is iterative. During each iteration, one of the factor
|
|
|
|
matrices is held constant, while the other is solved for using least
|
|
|
|
squares. The newly-solved factor matrix is then held constant while
|
|
|
|
solving for the other factor matrix.
|
|
|
|
|
|
|
|
This is a blocked implementation of the ALS factorization algorithm
|
|
|
|
that groups the two sets of factors (referred to as "users" and
|
|
|
|
"products") into blocks and reduces communication by only sending
|
|
|
|
one copy of each user vector to each product block on each
|
|
|
|
iteration, and only for the product blocks that need that user's
|
|
|
|
feature vector. This is achieved by pre-computing some information
|
|
|
|
about the ratings matrix to determine the "out-links" of each user
|
|
|
|
(which blocks of products it will contribute to) and "in-link"
|
|
|
|
information for each product (which of the feature vectors it
|
|
|
|
receives from each user block it will depend on). This allows us to
|
|
|
|
send only an array of feature vectors between each user block and
|
|
|
|
product block, and have the product block find the users' ratings
|
|
|
|
and update the products based on these messages.
|
|
|
|
|
|
|
|
For implicit preference data, the algorithm used is based on
|
2016-05-09 04:11:17 -04:00
|
|
|
`"Collaborative Filtering for Implicit Feedback Datasets",
|
2018-11-25 18:43:55 -05:00
|
|
|
<https://doi.org/10.1109/ICDM.2008.22>`_, adapted for the blocked
|
2015-05-08 20:24:32 -04:00
|
|
|
approach used here.
|
|
|
|
|
|
|
|
Essentially instead of finding the low-rank approximations to the
|
|
|
|
rating matrix `R`, this finds the approximations for a preference
|
|
|
|
matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
|
|
|
|
The ratings then act as 'confidence' values related to strength of
|
|
|
|
indicated user preferences rather than explicit ratings given to
|
|
|
|
items.
|
|
|
|
|
2020-11-09 19:33:48 -05:00
|
|
|
.. versionadded:: 1.4.0
|
2019-09-18 10:22:13 -04:00
|
|
|
|
2020-11-09 19:33:48 -05:00
|
|
|
Notes
|
|
|
|
-----
|
|
|
|
The input rating dataframe to the ALS implementation should be deterministic.
|
|
|
|
Nondeterministic data can cause failure during fitting ALS model.
|
|
|
|
For example, an order-sensitive operation like sampling after a repartition makes
|
|
|
|
dataframe output nondeterministic, like `df.repartition(2).sample(False, 0.5, 1618)`.
|
|
|
|
Checkpointing sampled dataframe or adding a sort before sampling can help make the
|
|
|
|
dataframe deterministic.
|
|
|
|
|
|
|
|
Examples
|
|
|
|
--------
|
2016-05-23 21:14:48 -04:00
|
|
|
>>> df = spark.createDataFrame(
|
2015-05-29 01:38:38 -04:00
|
|
|
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
|
|
|
|
... ["user", "item", "rating"])
|
2019-10-27 23:36:10 -04:00
|
|
|
>>> als = ALS(rank=10, seed=0)
|
|
|
|
>>> als.setMaxIter(5)
|
|
|
|
ALS...
|
|
|
|
>>> als.getMaxIter()
|
|
|
|
5
|
|
|
|
>>> als.setRegParam(0.1)
|
|
|
|
ALS...
|
|
|
|
>>> als.getRegParam()
|
|
|
|
0.1
|
|
|
|
>>> als.clear(als.regParam)
|
2015-05-08 20:24:32 -04:00
|
|
|
>>> model = als.fit(df)
|
2020-02-09 00:14:30 -05:00
|
|
|
>>> model.getBlockSize()
|
|
|
|
4096
|
2019-10-08 02:05:09 -04:00
|
|
|
>>> model.getUserCol()
|
|
|
|
'user'
|
[SPARK-29867][ML][PYTHON] Add __repr__ in Python ML Models
### What changes were proposed in this pull request?
Add ```__repr__``` in Python ML Models
### Why are the changes needed?
In Python ML Models, some of them have ```__repr__```, others don't. In the doctest, when calling Model.setXXX, some of the Models print out the xxxModel... correctly, while others can't because they lack the ```__repr__``` method. For example:
```
>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
>>> model = gm.fit(df)
>>> model.setPredictionCol("newPrediction")
GaussianMixture...
```
After the change, the above code will become the following:
```
>>> gm = GaussianMixture(k=3, tol=0.0001, seed=10)
>>> model = gm.fit(df)
>>> model.setPredictionCol("newPrediction")
GaussianMixtureModel...
```
### Does this PR introduce any user-facing change?
Yes.
### How was this patch tested?
doctest
Closes #26489 from huaxingao/spark-29876.
Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2019-11-16 00:44:39 -05:00
|
|
|
>>> model.setUserCol("user")
|
|
|
|
ALSModel...
|
2019-10-08 02:05:09 -04:00
|
|
|
>>> model.getItemCol()
|
|
|
|
'item'
|
|
|
|
>>> model.setPredictionCol("newPrediction")
|
|
|
|
ALS...
|
2015-05-29 01:38:38 -04:00
|
|
|
>>> model.rank
|
|
|
|
10
|
|
|
|
>>> model.userFactors.orderBy("id").collect()
|
|
|
|
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
|
2016-05-23 21:14:48 -04:00
|
|
|
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
|
2015-05-08 20:24:32 -04:00
|
|
|
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
|
|
|
|
>>> predictions[0]
|
[SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib
### What changes were proposed in this pull request?
Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.
The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focuses on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.
Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, and takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.
A table summarising which version gets loaded in which case:
```
| | BLAS.nativeBLAS | BLAS.javaBLAS |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a | 1. dev.ludovic.netlib.blas.VectorizedBLAS |
| | wrapper for com.github.fommil:all | (JDK16+, relies on the Vector API, requires |
| | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | `--add-modules=jdk.incubator.vector` on JDK16) |
| | relies on the Foreign Linker API, requires | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) |
| | `--add-modules=jdk.incubator.foreign | 3. dev.ludovic.netlib.blas.JavaBLAS |
| | -Dforeign.restricted=warn`) | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a |
| | 3. fails to load, falls back to BLAS.javaBLAS in | wrapper for com.github.fommil:core |
| | org.apache.spark.ml.linalg.BLAS | |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | 1. dev.ludovic.netlib.blas.VectorizedBLAS |
| | relies on the Foreign Linker API, requires | (JDK16+, relies on the Vector API, requires |
| | `--add-modules=jdk.incubator.foreign | `--add-modules=jdk.incubator.vector` on JDK16) |
| | -Dforeign.restricted=warn`) | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) |
| | 2. fails to load, falls back to BLAS.javaBLAS in | 3. dev.ludovic.netlib.blas.JavaBLAS |
| | org.apache.spark.ml.linalg.BLAS | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a |
| | | wrapper for com.github.fommil:core |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```
### Why are the changes needed?
Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available.
### Does this PR introduce _any_ user-facing change?
No, all changes are transparent to the user.
### How was this patch tested?
The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:
#### JDK8:
```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 223 232 8 448.0 2.2 1.0X
[info] java 221 228 7 453.0 2.2 1.0X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 122 128 4 821.2 1.2 1.0X
[info] java 122 128 4 822.3 1.2 1.0X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 109 112 2 921.4 1.1 1.0X
[info] java 70 74 3 1423.5 0.7 1.5X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 98 2 1046.1 1.0 1.0X
[info] java 47 49 2 2121.7 0.5 2.0X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 184 195 8 544.3 1.8 1.0X
[info] java 185 196 7 539.5 1.9 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 99 104 4 1011.9 1.0 1.0X
[info] java 99 104 4 1010.4 1.0 1.0X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 947.2 1.1 1.0X
[info] java 0 0 0 1584.8 0.6 1.7X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 867.4 1.2 1.0X
[info] java 1 1 0 865.0 1.2 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 485.9 2.1 1.0X
[info] java 1 1 0 486.8 2.1 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1843.0 0.5 1.0X
[info] java 0 0 0 2690.6 0.4 1.5X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1214.7 0.8 1.0X
[info] java 0 0 0 2536.8 0.4 2.1X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1895.9 0.5 1.0X
[info] java 0 0 0 2961.1 0.3 1.6X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1223.4 0.8 1.0X
[info] java 0 0 0 3091.4 0.3 2.5X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 560 575 20 1787.1 0.6 1.0X
[info] java 226 232 5 4432.4 0.2 2.5X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 570 586 23 1755.2 0.6 1.0X
[info] java 227 232 4 4410.1 0.2 2.5X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 863 879 17 1158.4 0.9 1.0X
[info] java 227 231 3 4407.9 0.2 3.8X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1282 1305 23 780.0 1.3 1.0X
[info] java 227 232 4 4413.4 0.2 5.7X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 538 548 8 1858.6 0.5 1.0X
[info] java 221 226 3 4521.1 0.2 2.4X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 549 558 10 1819.9 0.5 1.0X
[info] java 222 229 7 4503.5 0.2 2.5X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 838 852 12 1193.0 0.8 1.0X
[info] java 222 229 5 4500.5 0.2 3.8X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 905 919 18 1104.8 0.9 1.0X
[info] java 221 228 5 4521.3 0.2 4.1X
```
#### JDK11:
```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 195 204 10 512.7 2.0 1.0X
[info] java 195 202 7 512.4 2.0 1.0X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 108 113 4 923.3 1.1 1.0X
[info] java 102 107 4 984.4 1.0 1.1X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 107 110 3 938.1 1.1 1.0X
[info] java 69 72 3 1447.1 0.7 1.5X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 98 2 1046.5 1.0 1.0X
[info] java 43 45 2 2317.1 0.4 2.2X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 155 168 8 644.2 1.6 1.0X
[info] java 158 169 8 632.8 1.6 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 85 90 4 1178.1 0.8 1.0X
[info] java 86 90 4 1167.7 0.9 1.0X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 1182.1 0.8 1.0X
[info] java 0 0 0 1432.1 0.7 1.2X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 898.7 1.1 1.0X
[info] java 1 1 0 891.5 1.1 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 495.4 2.0 1.0X
[info] java 1 1 0 495.7 2.0 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2271.6 0.4 1.0X
[info] java 0 0 0 3648.1 0.3 1.6X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1229.3 0.8 1.0X
[info] java 0 0 0 2711.3 0.4 2.2X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2677.5 0.4 1.0X
[info] java 0 0 0 3288.2 0.3 1.2X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1233.0 0.8 1.0X
[info] java 0 0 0 2766.3 0.4 2.2X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 520 536 16 1923.6 0.5 1.0X
[info] java 214 221 7 4669.5 0.2 2.4X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 593 612 17 1686.5 0.6 1.0X
[info] java 215 219 3 4643.3 0.2 2.8X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 853 870 16 1172.8 0.9 1.0X
[info] java 215 218 3 4659.7 0.2 4.0X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1350 1370 23 740.8 1.3 1.0X
[info] java 215 219 4 4656.6 0.2 6.3X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 460 468 6 2173.2 0.5 1.0X
[info] java 210 213 2 4752.7 0.2 2.2X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 535 544 8 1869.3 0.5 1.0X
[info] java 210 215 5 4761.8 0.2 2.5X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 843 853 11 1186.8 0.8 1.0X
[info] java 209 214 4 4793.4 0.2 4.0X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 891 904 15 1122.0 0.9 1.0X
[info] java 209 214 4 4777.2 0.2 4.3X
```
#### JDK16:
```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 194 199 7 515.7 1.9 1.0X
[info] java 181 186 3 551.1 1.8 1.1X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 109 115 4 915.0 1.1 1.0X
[info] java 88 92 3 1138.8 0.9 1.2X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 108 110 2 922.6 1.1 1.0X
[info] java 54 56 2 1839.2 0.5 2.0X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 97 2 1046.1 1.0 1.0X
[info] java 29 30 1 3393.4 0.3 3.2X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 156 165 5 643.0 1.6 1.0X
[info] java 150 159 5 667.1 1.5 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 85 91 6 1171.0 0.9 1.0X
[info] java 75 79 3 1340.6 0.7 1.1X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 917.0 1.1 1.0X
[info] java 0 0 0 8147.2 0.1 8.9X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 859.3 1.2 1.0X
[info] java 1 1 0 859.3 1.2 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 482.1 2.1 1.0X
[info] java 1 1 0 482.6 2.1 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2214.2 0.5 1.0X
[info] java 0 0 0 7975.8 0.1 3.6X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1231.4 0.8 1.0X
[info] java 0 0 0 8680.9 0.1 7.0X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2684.3 0.4 1.0X
[info] java 0 0 0 18527.1 0.1 6.9X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1235.4 0.8 1.0X
[info] java 0 0 0 17347.9 0.1 14.0X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 530 552 18 1887.5 0.5 1.0X
[info] java 58 64 3 17143.9 0.1 9.1X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 598 620 17 1671.1 0.6 1.0X
[info] java 58 64 3 17196.6 0.1 10.3X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 834 847 14 1199.4 0.8 1.0X
[info] java 57 63 4 17486.9 0.1 14.6X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1338 1366 22 747.3 1.3 1.0X
[info] java 58 63 3 17356.6 0.1 23.2X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 489 501 9 2045.5 0.5 1.0X
[info] java 36 38 2 27721.9 0.0 13.6X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 478 488 9 2094.0 0.5 1.0X
[info] java 36 38 2 27813.2 0.0 13.3X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 825 837 10 1211.6 0.8 1.0X
[info] java 35 38 2 28433.1 0.0 23.5X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 900 918 15 1111.6 0.9 1.0X
[info] java 36 38 2 28073.0 0.0 25.3X
```
[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas
Closes #32253 from luhenry/master.
Authored-by: Ludovic Henry <git@ludovic.dev>
Signed-off-by: Sean Owen <srowen@gmail.com>
2021-04-27 15:00:59 -04:00
|
|
|
Row(user=0, item=2, newPrediction=0.69291...)
|
2015-05-08 20:24:32 -04:00
|
|
|
>>> predictions[1]
|
[SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib
### What changes were proposed in this pull request?
Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.
The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focuses on accelerating the linear algebra routines used in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.
Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, and takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.
A table summarising which version gets loaded in which case:
```
| | BLAS.nativeBLAS | BLAS.javaBLAS |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a | 1. dev.ludovic.netlib.blas.VectorizedBLAS |
| | wrapper for com.github.fommil:all | (JDK16+, relies on the Vector API, requires |
| | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | `--add-modules=jdk.incubator.vector` on JDK16) |
| | relies on the Foreign Linker API, requires | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) |
| | `--add-modules=jdk.incubator.foreign | 3. dev.ludovic.netlib.blas.JavaBLAS |
| | -Dforeign.restricted=warn`) | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a |
| | 3. fails to load, falls back to BLAS.javaBLAS in | wrapper for com.github.fommil:core |
| | org.apache.spark.ml.linalg.BLAS | |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | 1. dev.ludovic.netlib.blas.VectorizedBLAS |
| | relies on the Foreign Linker API, requires | (JDK16+, relies on the Vector API, requires |
| | `--add-modules=jdk.incubator.foreign | `--add-modules=jdk.incubator.vector` on JDK16) |
| | -Dforeign.restricted=warn`) | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) |
| | 2. fails to load, falls back to BLAS.javaBLAS in | 3. dev.ludovic.netlib.blas.JavaBLAS |
| | org.apache.spark.ml.linalg.BLAS | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a |
| | | wrapper for com.github.fommil:core |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```
### Why are the changes needed?
Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available.
### Does this PR introduce _any_ user-facing change?
No, all changes are transparent to the user.
### How was this patch tested?
The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:
#### JDK8:
```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 223 232 8 448.0 2.2 1.0X
[info] java 221 228 7 453.0 2.2 1.0X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 122 128 4 821.2 1.2 1.0X
[info] java 122 128 4 822.3 1.2 1.0X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 109 112 2 921.4 1.1 1.0X
[info] java 70 74 3 1423.5 0.7 1.5X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 98 2 1046.1 1.0 1.0X
[info] java 47 49 2 2121.7 0.5 2.0X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 184 195 8 544.3 1.8 1.0X
[info] java 185 196 7 539.5 1.9 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 99 104 4 1011.9 1.0 1.0X
[info] java 99 104 4 1010.4 1.0 1.0X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 947.2 1.1 1.0X
[info] java 0 0 0 1584.8 0.6 1.7X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 867.4 1.2 1.0X
[info] java 1 1 0 865.0 1.2 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 485.9 2.1 1.0X
[info] java 1 1 0 486.8 2.1 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1843.0 0.5 1.0X
[info] java 0 0 0 2690.6 0.4 1.5X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1214.7 0.8 1.0X
[info] java 0 0 0 2536.8 0.4 2.1X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1895.9 0.5 1.0X
[info] java 0 0 0 2961.1 0.3 1.6X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1223.4 0.8 1.0X
[info] java 0 0 0 3091.4 0.3 2.5X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 560 575 20 1787.1 0.6 1.0X
[info] java 226 232 5 4432.4 0.2 2.5X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 570 586 23 1755.2 0.6 1.0X
[info] java 227 232 4 4410.1 0.2 2.5X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 863 879 17 1158.4 0.9 1.0X
[info] java 227 231 3 4407.9 0.2 3.8X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1282 1305 23 780.0 1.3 1.0X
[info] java 227 232 4 4413.4 0.2 5.7X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 538 548 8 1858.6 0.5 1.0X
[info] java 221 226 3 4521.1 0.2 2.4X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 549 558 10 1819.9 0.5 1.0X
[info] java 222 229 7 4503.5 0.2 2.5X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 838 852 12 1193.0 0.8 1.0X
[info] java 222 229 5 4500.5 0.2 3.8X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 905 919 18 1104.8 0.9 1.0X
[info] java 221 228 5 4521.3 0.2 4.1X
```
#### JDK11:
```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 195 204 10 512.7 2.0 1.0X
[info] java 195 202 7 512.4 2.0 1.0X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 108 113 4 923.3 1.1 1.0X
[info] java 102 107 4 984.4 1.0 1.1X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 107 110 3 938.1 1.1 1.0X
[info] java 69 72 3 1447.1 0.7 1.5X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 98 2 1046.5 1.0 1.0X
[info] java 43 45 2 2317.1 0.4 2.2X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 155 168 8 644.2 1.6 1.0X
[info] java 158 169 8 632.8 1.6 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 85 90 4 1178.1 0.8 1.0X
[info] java 86 90 4 1167.7 0.9 1.0X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 1182.1 0.8 1.0X
[info] java 0 0 0 1432.1 0.7 1.2X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 898.7 1.1 1.0X
[info] java 1 1 0 891.5 1.1 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 495.4 2.0 1.0X
[info] java 1 1 0 495.7 2.0 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2271.6 0.4 1.0X
[info] java 0 0 0 3648.1 0.3 1.6X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1229.3 0.8 1.0X
[info] java 0 0 0 2711.3 0.4 2.2X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2677.5 0.4 1.0X
[info] java 0 0 0 3288.2 0.3 1.2X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1233.0 0.8 1.0X
[info] java 0 0 0 2766.3 0.4 2.2X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 520 536 16 1923.6 0.5 1.0X
[info] java 214 221 7 4669.5 0.2 2.4X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 593 612 17 1686.5 0.6 1.0X
[info] java 215 219 3 4643.3 0.2 2.8X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 853 870 16 1172.8 0.9 1.0X
[info] java 215 218 3 4659.7 0.2 4.0X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1350 1370 23 740.8 1.3 1.0X
[info] java 215 219 4 4656.6 0.2 6.3X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 460 468 6 2173.2 0.5 1.0X
[info] java 210 213 2 4752.7 0.2 2.2X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 535 544 8 1869.3 0.5 1.0X
[info] java 210 215 5 4761.8 0.2 2.5X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 843 853 11 1186.8 0.8 1.0X
[info] java 209 214 4 4793.4 0.2 4.0X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 891 904 15 1122.0 0.9 1.0X
[info] java 209 214 4 4777.2 0.2 4.3X
```
#### JDK16:
```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 194 199 7 515.7 1.9 1.0X
[info] java 181 186 3 551.1 1.8 1.1X
[info]
[info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 109 115 4 915.0 1.1 1.0X
[info] java 88 92 3 1138.8 0.9 1.2X
[info]
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 108 110 2 922.6 1.1 1.0X
[info] java 54 56 2 1839.2 0.5 2.0X
[info]
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 96 97 2 1046.1 1.0 1.0X
[info] java 29 30 1 3393.4 0.3 3.2X
[info]
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 156 165 5 643.0 1.6 1.0X
[info] java 150 159 5 667.1 1.5 1.0X
[info]
[info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 85 91 6 1171.0 0.9 1.0X
[info] java 75 79 3 1340.6 0.7 1.1X
[info]
[info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 917.0 1.1 1.0X
[info] java 0 0 0 8147.2 0.1 8.9X
[info]
[info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 859.3 1.2 1.0X
[info] java 1 1 0 859.3 1.2 1.0X
[info]
[info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 482.1 2.1 1.0X
[info] java 1 1 0 482.6 2.1 1.0X
[info]
[info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2214.2 0.5 1.0X
[info] java 0 0 0 7975.8 0.1 3.6X
[info]
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1231.4 0.8 1.0X
[info] java 0 0 0 8680.9 0.1 7.0X
[info]
[info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 0 0 0 2684.3 0.4 1.0X
[info] java 0 0 0 18527.1 0.1 6.9X
[info]
[info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1 1 0 1235.4 0.8 1.0X
[info] java 0 0 0 17347.9 0.1 14.0X
[info]
[info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 530 552 18 1887.5 0.5 1.0X
[info] java 58 64 3 17143.9 0.1 9.1X
[info]
[info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 598 620 17 1671.1 0.6 1.0X
[info] java 58 64 3 17196.6 0.1 10.3X
[info]
[info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 834 847 14 1199.4 0.8 1.0X
[info] java 57 63 4 17486.9 0.1 14.6X
[info]
[info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 1338 1366 22 747.3 1.3 1.0X
[info] java 58 63 3 17356.6 0.1 23.2X
[info]
[info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 489 501 9 2045.5 0.5 1.0X
[info] java 36 38 2 27721.9 0.0 13.6X
[info]
[info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 478 488 9 2094.0 0.5 1.0X
[info] java 36 38 2 27813.2 0.0 13.3X
[info]
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 825 837 10 1211.6 0.8 1.0X
[info] java 35 38 2 28433.1 0.0 23.5X
[info]
[info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j 900 918 15 1111.6 0.9 1.0X
[info] java 36 38 2 28073.0 0.0 25.3X
```
[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas
Closes #32253 from luhenry/master.
Authored-by: Ludovic Henry <git@ludovic.dev>
Signed-off-by: Sean Owen <srowen@gmail.com>
2021-04-27 15:00:59 -04:00
|
|
|
Row(user=1, item=0, newPrediction=3.47356...)
|
2015-05-08 20:24:32 -04:00
|
|
|
>>> predictions[2]
|
[SPARK-33203][PYTHON][TEST] Fix tests failing with rounding errors
### What changes were proposed in this pull request?
Increase tolerance for two tests that pass in some environments and fail in others (flaky? Pass/fail is constant within the same environment).
### Why are the changes needed?
The tests `pyspark.ml.recommendation` and `pyspark.ml.tests.test_algorithms` fail with
```
File "/home/jenkins/python/pyspark/ml/tests/test_algorithms.py", line 96, in test_raw_and_probability_prediction
self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1))
AssertionError: False is not true
```
```
File "/home/jenkins/python/pyspark/ml/recommendation.py", line 256, in _main_.ALS
Failed example:
predictions[0]
Expected:
Row(user=0, item=2, newPrediction=0.6929101347923279)
Got:
Row(user=0, item=2, newPrediction=0.6929104924201965)
...
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
This patch changes a test target. I executed the tests to verify that they pass.
Closes #30104 from AlessandroPatti/apatti/rounding-errors.
Authored-by: Alessandro Patti <ale812@yahoo.it>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2020-10-21 21:14:21 -04:00
|
|
|
Row(user=2, item=0, newPrediction=-0.899198...)
|
2017-05-02 04:49:13 -04:00
|
|
|
>>> user_recs = model.recommendForAllUsers(3)
|
|
|
|
>>> user_recs.where(user_recs.user == 0)\
|
|
|
|
.select("recommendations.item", "recommendations.rating").collect()
|
2019-03-23 12:26:09 -04:00
|
|
|
[Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
|
2017-05-02 04:49:13 -04:00
|
|
|
>>> item_recs = model.recommendForAllItems(3)
|
|
|
|
>>> item_recs.where(item_recs.item == 2)\
|
|
|
|
.select("recommendations.user", "recommendations.rating").collect()
|
2019-03-23 12:26:09 -04:00
|
|
|
[Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
|
2017-10-09 04:42:33 -04:00
|
|
|
>>> user_subset = df.where(df.user == 2)
|
|
|
|
>>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
|
|
|
|
>>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
|
2019-03-23 12:26:09 -04:00
|
|
|
Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
|
2017-10-09 04:42:33 -04:00
|
|
|
>>> item_subset = df.where(df.item == 0)
|
|
|
|
>>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
|
|
|
|
>>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
|
2019-03-23 12:26:09 -04:00
|
|
|
Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
|
2016-02-20 04:07:19 -05:00
|
|
|
>>> als_path = temp_path + "/als"
|
2016-02-11 18:50:33 -05:00
|
|
|
>>> als.save(als_path)
|
|
|
|
>>> als2 = ALS.load(als_path)
|
|
|
|
>>> als.getMaxIter()
|
|
|
|
5
|
2016-02-20 04:07:19 -05:00
|
|
|
>>> model_path = temp_path + "/als_model"
|
2016-02-11 18:50:33 -05:00
|
|
|
>>> model.save(model_path)
|
|
|
|
>>> model2 = ALSModel.load(model_path)
|
|
|
|
>>> model.rank == model2.rank
|
|
|
|
True
|
|
|
|
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
|
|
|
|
True
|
|
|
|
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
|
|
|
|
True
|
2020-08-03 11:50:34 -04:00
|
|
|
>>> model.transform(test).take(1) == model2.transform(test).take(1)
|
|
|
|
True
|
2015-05-08 20:24:32 -04:00
|
|
|
"""
|
2015-05-18 15:02:18 -04:00
|
|
|
|
2015-05-08 20:24:32 -04:00
|
|
|
@keyword_only
|
[SPARK-32933][PYTHON] Use keyword-only syntax for keyword_only methods
### What changes were proposed in this pull request?
This PR adjusts signatures of methods decorated with `keyword_only` to indicate using [Python 3 keyword-only syntax](https://www.python.org/dev/peps/pep-3102/).
__Note__:
For the moment the goal is not to replace `keyword_only`. For justification see https://github.com/apache/spark/pull/29591#discussion_r489402579
### Why are the changes needed?
Right now it is not clear that `keyword_only` methods are indeed keyword only. This proposal addresses that.
In practice we could probably capture `locals` and drop `keyword_only` completely, i.e.:
```python
@keyword_only
def __init__(self, *, featuresCol="features"):
...
kwargs = self._input_kwargs
self.setParams(**kwargs)
```
could be replaced with
```python
def __init__(self, *, featuresCol="features"):
kwargs = locals()
del kwargs["self"]
...
self.setParams(**kwargs)
```
### Does this PR introduce _any_ user-facing change?
Docstrings and inspect tools will now indicate that `keyword_only` methods expect only keyword arguments.
For example, the signature of `LinearSVC.__init__` will change from
```
>>> from pyspark.ml.classification import LinearSVC
>>> ?LinearSVC.__init__
Signature:
LinearSVC.__init__(
self,
featuresCol='features',
labelCol='label',
predictionCol='prediction',
maxIter=100,
regParam=0.0,
tol=1e-06,
rawPredictionCol='rawPrediction',
fitIntercept=True,
standardization=True,
threshold=0.0,
weightCol=None,
aggregationDepth=2,
)
Docstring: __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2):
File: /path/to/python/pyspark/ml/classification.py
Type: function
```
to
```
>>> from pyspark.ml.classification import LinearSVC
>>> ?LinearSVC.__init__
Signature:
LinearSVC.__init__ (
self,
*,
featuresCol='features',
labelCol='label',
predictionCol='prediction',
maxIter=100,
regParam=0.0,
tol=1e-06,
rawPredictionCol='rawPrediction',
fitIntercept=True,
standardization=True,
threshold=0.0,
weightCol=None,
aggregationDepth=2,
blockSize=1,
)
Docstring: __init__(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1):
File: ~/Workspace/spark/python/pyspark/ml/classification.py
Type: function
```
### How was this patch tested?
Existing tests.
Closes #29799 from zero323/SPARK-32933.
Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2020-09-22 20:28:33 -04:00
|
|
|
def __init__(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
|
|
|
|
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
|
|
|
|
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
|
2016-04-30 03:41:28 -04:00
|
|
|
intermediateStorageLevel="MEMORY_AND_DISK",
|
2020-02-09 00:14:30 -05:00
|
|
|
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
|
2015-05-08 20:24:32 -04:00
|
|
|
"""
|
[SPARK-32933][PYTHON] Use keyword-only syntax for keyword_only methods
### What changes were proposed in this pull request?
This PR adjusts signatures of methods decorated with `keyword_only` to indicate using [Python 3 keyword-only syntax](https://www.python.org/dev/peps/pep-3102/).
__Note__:
For the moment the goal is not to replace `keyword_only`. For justification see https://github.com/apache/spark/pull/29591#discussion_r489402579
### Why are the changes needed?
Right now it is not clear that `keyword_only` methods are indeed keyword only. This proposal addresses that.
In practice we could probably capture `locals` and drop `keyword_only` completely, i.e.:
```python
@keyword_only
def __init__(self, *, featuresCol="features"):
...
kwargs = self._input_kwargs
self.setParams(**kwargs)
```
could be replaced with
```python
def __init__(self, *, featuresCol="features"):
kwargs = locals()
del kwargs["self"]
...
self.setParams(**kwargs)
```
### Does this PR introduce _any_ user-facing change?
Docstrings and inspect tools will now indicate that `keyword_only` methods expect only keyword arguments.
For example, the signature of `LinearSVC.__init__` will change from
```
>>> from pyspark.ml.classification import LinearSVC
>>> ?LinearSVC.__init__
Signature:
LinearSVC.__init__(
self,
featuresCol='features',
labelCol='label',
predictionCol='prediction',
maxIter=100,
regParam=0.0,
tol=1e-06,
rawPredictionCol='rawPrediction',
fitIntercept=True,
standardization=True,
threshold=0.0,
weightCol=None,
aggregationDepth=2,
)
Docstring: __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2):
File: /path/to/python/pyspark/ml/classification.py
Type: function
```
to
```
>>> from pyspark.ml.classification import LinearSVC
>>> ?LinearSVC.__init__
Signature:
LinearSVC.__init__ (
self,
*,
featuresCol='features',
labelCol='label',
predictionCol='prediction',
maxIter=100,
regParam=0.0,
tol=1e-06,
rawPredictionCol='rawPrediction',
fitIntercept=True,
standardization=True,
threshold=0.0,
weightCol=None,
aggregationDepth=2,
blockSize=1,
)
Docstring: __init__(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1):
File: ~/Workspace/spark/python/pyspark/ml/classification.py
Type: function
```
### How was this patch tested?
Existing tests.
Closes #29799 from zero323/SPARK-32933.
Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2020-09-22 20:28:33 -04:00
|
|
|
__init__(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
|
|
|
|
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
|
|
|
|
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
|
2016-04-30 03:41:28 -04:00
|
|
|
intermediateStorageLevel="MEMORY_AND_DISK", \
|
2020-02-09 00:14:30 -05:00
|
|
|
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
|
2015-05-08 20:24:32 -04:00
|
|
|
"""
|
|
|
|
super(ALS, self).__init__()
|
2015-05-18 15:02:18 -04:00
|
|
|
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
|
2017-03-03 19:43:45 -05:00
|
|
|
kwargs = self._input_kwargs
|
2015-05-08 20:24:32 -04:00
|
|
|
self.setParams(**kwargs)
|
|
|
|
|
|
|
|
@keyword_only
@since("1.4.0")
def setParams(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
              numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
              seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
              intermediateStorageLevel="MEMORY_AND_DISK",
              finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
    """
    setParams(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \
             numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
             seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
             intermediateStorageLevel="MEMORY_AND_DISK", \
             finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
    Sets params for ALS.

    All arguments are keyword-only. The values in the signature are not read
    directly; the method forwards ``self._input_kwargs`` to ``self._set`` and
    returns that call's result.
    """
    # NOTE(review): ``_input_kwargs`` is presumably populated by the
    # ``keyword_only`` decorator with the keyword arguments the caller
    # actually passed -- confirm against its implementation.
    kwargs = self._input_kwargs
    return self._set(**kwargs)
|
|
|
|
|
|
|
|
def _create_model(self, java_model):
    """
    Wraps ``java_model`` in an :py:class:`ALSModel` and returns the wrapper.
    """
    return ALSModel(java_model)
|
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setRank(self, value):
    """
    Sets the value of :py:attr:`rank`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(rank=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setNumUserBlocks(self, value):
    """
    Sets the value of :py:attr:`numUserBlocks`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(numUserBlocks=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setNumItemBlocks(self, value):
    """
    Sets the value of :py:attr:`numItemBlocks`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(numItemBlocks=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setNumBlocks(self, value):
    """
    Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.

    The two params are set by two consecutive ``self._set`` calls; the result
    of the second call is returned.
    """
    self._set(numUserBlocks=value)
    return self._set(numItemBlocks=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setImplicitPrefs(self, value):
    """
    Sets the value of :py:attr:`implicitPrefs`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(implicitPrefs=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setAlpha(self, value):
    """
    Sets the value of :py:attr:`alpha`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(alpha=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setUserCol(self, value):
    """
    Sets the value of :py:attr:`userCol`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(userCol=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setItemCol(self, value):
    """
    Sets the value of :py:attr:`itemCol`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(itemCol=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setRatingCol(self, value):
    """
    Sets the value of :py:attr:`ratingCol`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(ratingCol=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2015-09-17 11:51:19 -04:00
|
|
|
@since("1.4.0")
def setNonnegative(self, value):
    """
    Sets the value of :py:attr:`nonnegative`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(nonnegative=value)
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2016-04-30 01:01:41 -04:00
|
|
|
@since("2.0.0")
def setIntermediateStorageLevel(self, value):
    """
    Sets the value of :py:attr:`intermediateStorageLevel`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(intermediateStorageLevel=value)
|
2016-04-30 01:01:41 -04:00
|
|
|
|
|
|
|
@since("2.0.0")
def setFinalStorageLevel(self, value):
    """
    Sets the value of :py:attr:`finalStorageLevel`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(finalStorageLevel=value)
|
2016-04-30 01:01:41 -04:00
|
|
|
|
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation.
The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.
The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).
## How was this patch tested?
New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.
Author: Nick Pentreath <nickp@za.ibm.com>
Closes #12896 from MLnick/SPARK-14489-als-nan.
2017-02-28 09:17:35 -05:00
|
|
|
@since("2.2.0")
def setColdStartStrategy(self, value):
    """
    Sets the value of :py:attr:`coldStartStrategy`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(coldStartStrategy=value)
|
|
|
|
|
2019-10-27 23:36:10 -04:00
|
|
|
def setMaxIter(self, value):
    """
    Sets the value of :py:attr:`maxIter`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(maxIter=value)
|
|
|
|
|
|
|
|
def setRegParam(self, value):
    """
    Sets the value of :py:attr:`regParam`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(regParam=value)
|
|
|
|
|
|
|
|
def setPredictionCol(self, value):
    """
    Sets the value of :py:attr:`predictionCol`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(predictionCol=value)
|
|
|
|
|
|
|
|
def setCheckpointInterval(self, value):
    """
    Sets the value of :py:attr:`checkpointInterval`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(checkpointInterval=value)
|
|
|
|
|
|
|
|
def setSeed(self, value):
    """
    Sets the value of :py:attr:`seed`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(seed=value)
|
|
|
|
|
2020-02-09 00:14:30 -05:00
|
|
|
@since("3.0.0")
def setBlockSize(self, value):
    """
    Sets the value of :py:attr:`blockSize`.

    `value` is passed unchanged to ``self._set``, whose result is returned.
    """
    return self._set(blockSize=value)
|
|
|
|
|
2015-05-08 20:24:32 -04:00
|
|
|
|
2019-10-08 02:05:09 -04:00
|
|
|
class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by ALS.

    Besides the param setters, the model exposes the fitted `rank`,
    `userFactors` and `itemFactors`, and the ``recommendFor*`` methods; each
    of these delegates to the same-named method of the wrapped Java model via
    ``self._call_java``.

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setUserCol(self, value):
        """
        Sets the value of :py:attr:`userCol`.
        """
        return self._set(userCol=value)

    @since("3.0.0")
    def setItemCol(self, value):
        """
        Sets the value of :py:attr:`itemCol`.
        """
        return self._set(itemCol=value)

    @since("3.0.0")
    def setColdStartStrategy(self, value):
        """
        Sets the value of :py:attr:`coldStartStrategy`.
        """
        return self._set(coldStartStrategy=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @since("3.0.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        return self._set(blockSize=value)

    @property
    @since("1.4.0")
    def rank(self):
        """rank of the matrix factorization model"""
        return self._call_java("rank")

    @property
    @since("1.4.0")
    def userFactors(self):
        """
        a DataFrame that stores user factors in two columns: `id` and
        `features`
        """
        return self._call_java("userFactors")

    @property
    @since("1.4.0")
    def itemFactors(self):
        """
        a DataFrame that stores item factors in two columns: `id` and
        `features`
        """
        return self._call_java("itemFactors")

    def recommendForAllUsers(self, numItems):
        """
        Returns top `numItems` items recommended for each user, for all users.

        .. versionadded:: 2.2.0

        Parameters
        ----------
        numItems : int
            max number of recommendations for each user

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (userCol, recommendations), where recommendations are
            stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForAllUsers", numItems)

    def recommendForAllItems(self, numUsers):
        """
        Returns top `numUsers` users recommended for each item, for all items.

        .. versionadded:: 2.2.0

        Parameters
        ----------
        numUsers : int
            max number of recommendations for each item

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (itemCol, recommendations), where recommendations are
            stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForAllItems", numUsers)

    def recommendForUserSubset(self, dataset, numItems):
        """
        Returns top `numItems` items recommended for each user id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        .. versionadded:: 2.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            a DataFrame containing a column of user ids. The column name must match `userCol`.
        numItems : int
            max number of recommendations for each user

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (userCol, recommendations), where recommendations are
            stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForUserSubset", dataset, numItems)

    def recommendForItemSubset(self, dataset, numUsers):
        """
        Returns top `numUsers` users recommended for each item id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        .. versionadded:: 2.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            a DataFrame containing a column of item ids. The column name must match `itemCol`.
        numUsers : int
            max number of recommendations for each item

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (itemCol, recommendations), where recommendations are
            stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForItemSubset", dataset, numUsers)
|
|
|
|
|
2015-05-08 20:24:32 -04:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    import doctest
    import tempfile
    from shutil import rmtree

    import pyspark.ml.recommendation
    from pyspark.sql import SparkSession

    # Run the doctests against a copy of the module namespace so the fixtures
    # injected below (sc, spark, temp_path) do not leak into the module itself.
    globs = pyspark.ml.recommendation.__dict__.copy()
    # A local two-thread master is enough for these small doctest examples.
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.recommendation tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Release the session and the scratch directory even when
        # doctest.testmod raises; previously the session was only stopped
        # on the success path.
        try:
            spark.stop()
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp directory must not
                # mask the doctest outcome.
                pass
    if failure_count:
        sys.exit(-1)
|