spark-instrumented-optimizer/python/pyspark/ml/recommendation.py

650 lines
23 KiB
Python
Raw Normal View History

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from pyspark import since, keyword_only
from pyspark.ml.param.shared import HasPredictionCol, HasBlockSize, HasMaxIter, HasRegParam, \
HasCheckpointInterval, HasSeed
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.common import inherit_doc
from pyspark.ml.param import Params, TypeConverters, Param
from pyspark.ml.util import JavaMLWritable, JavaMLReadable
__all__ = ['ALS', 'ALSModel']
@inherit_doc
class _ALSModelParams(HasPredictionCol, HasBlockSize):
    """
    Parameter mixin shared by :py:class:`ALS` and :py:class:`ALSModel`.

    Declares the ``userCol``, ``itemCol`` and ``coldStartStrategy`` params,
    provides their getters, and registers a default for ``blockSize``.

    .. versionadded:: 3.0.0
    """

    userCol = Param(
        Params._dummy(),
        "userCol",
        "column name for user ids. Ids must be within "
        "the integer value range.",
        typeConverter=TypeConverters.toString,
    )
    itemCol = Param(
        Params._dummy(),
        "itemCol",
        "column name for item ids. Ids must be within "
        "the integer value range.",
        typeConverter=TypeConverters.toString,
    )
    coldStartStrategy = Param(
        Params._dummy(),
        "coldStartStrategy",
        "strategy for dealing with "
        "unknown or new users/items at prediction time. This may be useful "
        "in cross-validation or production scenarios, for handling "
        "user/item ids the model has not seen in the training data. "
        "Supported values: 'nan', 'drop'.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args):
        # Forward positional arguments unchanged to the parent initializers,
        # then register the blockSize default on this instance.
        super(_ALSModelParams, self).__init__(*args)
        self._setDefault(blockSize=4096)

    @since("1.4.0")
    def getUserCol(self):
        """
        Return the current ``userCol`` value, or its default if unset.
        """
        return self.getOrDefault(self.userCol)

    @since("1.4.0")
    def getItemCol(self):
        """
        Return the current ``itemCol`` value, or its default if unset.
        """
        return self.getOrDefault(self.itemCol)

    @since("2.2.0")
    def getColdStartStrategy(self):
        """
        Return the current ``coldStartStrategy`` value, or its default if unset.
        """
        return self.getOrDefault(self.coldStartStrategy)
@inherit_doc
class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed):
    """
    Parameter mixin for the :py:class:`ALS` estimator.

    Extends :py:class:`_ALSModelParams` with the training-time params
    declared below, their getters, and the estimator's default values.

    .. versionadded:: 3.0.0
    """

    rank = Param(
        Params._dummy(),
        "rank",
        "rank of the factorization",
        typeConverter=TypeConverters.toInt,
    )
    numUserBlocks = Param(
        Params._dummy(),
        "numUserBlocks",
        "number of user blocks",
        typeConverter=TypeConverters.toInt,
    )
    numItemBlocks = Param(
        Params._dummy(),
        "numItemBlocks",
        "number of item blocks",
        typeConverter=TypeConverters.toInt,
    )
    implicitPrefs = Param(
        Params._dummy(),
        "implicitPrefs",
        "whether to use implicit preference",
        typeConverter=TypeConverters.toBoolean,
    )
    alpha = Param(
        Params._dummy(),
        "alpha",
        "alpha for implicit preference",
        typeConverter=TypeConverters.toFloat,
    )
    ratingCol = Param(
        Params._dummy(),
        "ratingCol",
        "column name for ratings",
        typeConverter=TypeConverters.toString,
    )
    nonnegative = Param(
        Params._dummy(),
        "nonnegative",
        "whether to use nonnegative constraint for least squares",
        typeConverter=TypeConverters.toBoolean,
    )
    intermediateStorageLevel = Param(
        Params._dummy(),
        "intermediateStorageLevel",
        "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
        typeConverter=TypeConverters.toString,
    )
    finalStorageLevel = Param(
        Params._dummy(),
        "finalStorageLevel",
        "StorageLevel for ALS model factors.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args):
        # Let the parent mixins initialize first, then register the
        # estimator-level defaults in a single call.
        super(_ALSParams, self).__init__(*args)
        self._setDefault(
            rank=10,
            maxIter=10,
            regParam=0.1,
            numUserBlocks=10,
            numItemBlocks=10,
            implicitPrefs=False,
            alpha=1.0,
            userCol="user",
            itemCol="item",
            ratingCol="rating",
            nonnegative=False,
            checkpointInterval=10,
            intermediateStorageLevel="MEMORY_AND_DISK",
            finalStorageLevel="MEMORY_AND_DISK",
            coldStartStrategy="nan",
        )

    @since("1.4.0")
    def getRank(self):
        """
        Return the current ``rank`` value, or its default if unset.
        """
        return self.getOrDefault(self.rank)

    @since("1.4.0")
    def getNumUserBlocks(self):
        """
        Return the current ``numUserBlocks`` value, or its default if unset.
        """
        return self.getOrDefault(self.numUserBlocks)

    @since("1.4.0")
    def getNumItemBlocks(self):
        """
        Return the current ``numItemBlocks`` value, or its default if unset.
        """
        return self.getOrDefault(self.numItemBlocks)

    @since("1.4.0")
    def getImplicitPrefs(self):
        """
        Return the current ``implicitPrefs`` value, or its default if unset.
        """
        return self.getOrDefault(self.implicitPrefs)

    @since("1.4.0")
    def getAlpha(self):
        """
        Return the current ``alpha`` value, or its default if unset.
        """
        return self.getOrDefault(self.alpha)

    @since("1.4.0")
    def getRatingCol(self):
        """
        Return the current ``ratingCol`` value, or its default if unset.
        """
        return self.getOrDefault(self.ratingCol)

    @since("1.4.0")
    def getNonnegative(self):
        """
        Return the current ``nonnegative`` value, or its default if unset.
        """
        return self.getOrDefault(self.nonnegative)

    @since("2.0.0")
    def getIntermediateStorageLevel(self):
        """
        Return the current ``intermediateStorageLevel`` value, or its default if unset.
        """
        return self.getOrDefault(self.intermediateStorageLevel)

    @since("2.0.0")
    def getFinalStorageLevel(self):
        """
        Return the current ``finalStorageLevel`` value, or its default if unset.
        """
        return self.getOrDefault(self.finalStorageLevel)
@inherit_doc
class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
"""
Alternating Least Squares (ALS) matrix factorization.
ALS attempts to estimate the ratings matrix `R` as the product of
two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
these approximations are called 'factor' matrices. The general
approach is iterative. During each iteration, one of the factor
matrices is held constant, while the other is solved for using least
squares. The newly-solved factor matrix is then held constant while
solving for the other factor matrix.
This is a blocked implementation of the ALS factorization algorithm
that groups the two sets of factors (referred to as "users" and
"products") into blocks and reduces communication by only sending
one copy of each user vector to each product block on each
iteration, and only for the product blocks that need that user's
feature vector. This is achieved by pre-computing some information
about the ratings matrix to determine the "out-links" of each user
(which blocks of products it will contribute to) and "in-link"
information for each product (which of the feature vectors it
receives from each user block it will depend on). This allows us to
send only an array of feature vectors between each user block and
product block, and have the product block find the users' ratings
and update the products based on these messages.
For implicit preference data, the algorithm used is based on
`"Collaborative Filtering for Implicit Feedback Datasets",
<https://doi.org/10.1109/ICDM.2008.22>`_, adapted for the blocked
approach used here.
Essentially instead of finding the low-rank approximations to the
rating matrix `R`, this finds the approximations for a preference
matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
The ratings then act as 'confidence' values related to strength of
indicated user preferences rather than explicit ratings given to
items.
.. versionadded:: 1.4.0
[SPARK-28927][ML] Rethrow block mismatch exception in ALS when input data is nondeterministic ### What changes were proposed in this pull request? Fitting ALS model can be failed due to nondeterministic input data. Currently the failure is thrown by an ArrayIndexOutOfBoundsException which is not explainable for end users what is wrong in fitting. This patch catches this exception and rethrows a more explainable one, when the input data is nondeterministic. Because we may not exactly know the output deterministic level of RDDs produced by user code, this patch also adds a note to Scala/Python/R ALS document about the training data deterministic level. ### Why are the changes needed? ArrayIndexOutOfBoundsException was observed during fitting ALS model. It was caused by mismatching between in/out user/item blocks during computing ratings. If the training RDD output is nondeterministic, when fetch failure is happened, rerun part of training RDD can produce inconsistent user/item blocks. This patch is needed to notify users ALS fitting on nondeterministic input. ### Does this PR introduce any user-facing change? Yes. When fitting ALS model on nondeterministic input data, previously if rerun happens, users would see ArrayIndexOutOfBoundsException caused by mismatch between In/Out user/item blocks. After this patch, a SparkException with more clear message will be thrown, and original ArrayIndexOutOfBoundsException is wrapped. ### How was this patch tested? Tested on development cluster. Closes #25789 from viirya/als-indeterminate-input. Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com> Co-authored-by: Liang-Chi Hsieh <liangchi@uber.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
2019-09-18 10:22:13 -04:00
Notes
-----
The input rating dataframe to the ALS implementation should be deterministic.
Nondeterministic data can cause failure during fitting ALS model.
For example, an order-sensitive operation like sampling after a repartition makes
dataframe output nondeterministic, like `df.repartition(2).sample(False, 0.5, 1618)`.
Checkpointing sampled dataframe or adding a sort before sampling can help make the
dataframe deterministic.
Examples
--------
>>> df = spark.createDataFrame(
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
... ["user", "item", "rating"])
>>> als = ALS(rank=10, seed=0)
>>> als.setMaxIter(5)
ALS...
>>> als.getMaxIter()
5
>>> als.setRegParam(0.1)
ALS...
>>> als.getRegParam()
0.1
>>> als.clear(als.regParam)
>>> model = als.fit(df)
>>> model.getBlockSize()
4096
>>> model.getUserCol()
'user'
>>> model.setUserCol("user")
ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
ALS...
>>> model.rank
10
>>> model.userFactors.orderBy("id").collect()
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
[SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib ### What changes were proposed in this pull request? Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package. The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focus on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation. Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, as well as takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available. A table summarising which version gets loaded in which case: ``` | | BLAS.nativeBLAS | BLAS.javaBLAS | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | with -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | wrapper for com.github.fommil:all | (JDK16+, relies on the Vector API, requires | | | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | `--add-modules=jdk.incubator.vector` on JDK16) | | | relies on the Foreign Linker API, requires | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | `--add-modules=jdk.incubator.foreign | 3. dev.ludovic.netlib.blas.JavaBLAS | | | -Dforeign.restricted=warn`) | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | 3. 
fails to load, falls back to BLAS.javaBLAS in | wrapper for com.github.fommil:core | | | org.apache.spark.ml.linalg.BLAS | | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | relies on the Foreign Linker API, requires | (JDK16+, relies on the Vector API, requires | | | `--add-modules=jdk.incubator.foreign | `--add-modules=jdk.incubator.vector` on JDK16) | | | -Dforeign.restricted=warn`) | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | 2. fails to load, falls back to BLAS.javaBLAS in | 3. dev.ludovic.netlib.blas.JavaBLAS | | | org.apache.spark.ml.linalg.BLAS | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | | wrapper for com.github.fommil:core | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | ``` ### Why are the changes needed? Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available. ### Does this PR introduce _any_ user-facing change? No, all changes are transparent to the user. ### How was this patch tested? The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite. 
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`: #### JDK8: ``` [info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 223 232 8 448.0 2.2 1.0X [info] java 221 228 7 453.0 2.2 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 122 128 4 821.2 1.2 1.0X [info] java 122 128 4 822.3 1.2 1.0X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 112 2 921.4 1.1 1.0X [info] java 70 74 3 1423.5 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 1046.1 1.0 1.0X [info] java 47 49 2 2121.7 0.5 2.0X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 184 195 8 544.3 1.8 1.0X [info] java 185 196 7 539.5 1.9 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 99 104 4 1011.9 1.0 1.0X [info] java 99 104 4 1010.4 1.0 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 947.2 1.1 1.0X [info] java 0 0 0 1584.8 0.6 1.7X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 867.4 1.2 1.0X [info] java 1 1 0 865.0 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 485.9 2.1 1.0X [info] java 1 1 0 486.8 2.1 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1843.0 0.5 1.0X [info] java 0 0 0 2690.6 0.4 1.5X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1214.7 0.8 1.0X [info] java 0 0 0 2536.8 0.4 2.1X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1895.9 0.5 1.0X [info] java 0 0 0 2961.1 0.3 1.6X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1223.4 0.8 1.0X [info] java 0 0 0 3091.4 0.3 2.5X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 560 575 20 1787.1 0.6 1.0X [info] java 226 232 5 4432.4 0.2 2.5X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 570 586 23 1755.2 0.6 1.0X [info] java 227 232 4 4410.1 0.2 2.5X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 863 879 17 1158.4 0.9 1.0X [info] java 227 231 3 4407.9 0.2 3.8X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1282 1305 23 780.0 1.3 1.0X [info] java 227 232 4 4413.4 0.2 5.7X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 538 548 8 1858.6 0.5 1.0X [info] java 221 226 3 4521.1 0.2 2.4X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 549 558 10 1819.9 0.5 1.0X [info] java 222 229 7 4503.5 0.2 2.5X [info] 
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 838 852 12 1193.0 0.8 1.0X [info] java 222 229 5 4500.5 0.2 3.8X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 905 919 18 1104.8 0.9 1.0X [info] java 221 228 5 4521.3 0.2 4.1X ``` #### JDK11: ``` [info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 195 204 10 512.7 2.0 1.0X [info] java 195 202 7 512.4 2.0 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 113 4 923.3 1.1 1.0X [info] java 102 107 4 984.4 1.0 1.1X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 107 110 3 938.1 1.1 1.0X [info] java 69 72 3 1447.1 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 
1046.5 1.0 1.0X [info] java 43 45 2 2317.1 0.4 2.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 155 168 8 644.2 1.6 1.0X [info] java 158 169 8 632.8 1.6 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 90 4 1178.1 0.8 1.0X [info] java 86 90 4 1167.7 0.9 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 1182.1 0.8 1.0X [info] java 0 0 0 1432.1 0.7 1.2X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 898.7 1.1 1.0X [info] java 1 1 0 891.5 1.1 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 495.4 2.0 1.0X [info] java 1 1 0 495.7 2.0 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2271.6 0.4 1.0X [info] java 0 0 0 3648.1 0.3 1.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] 
f2j 1 1 0 1229.3 0.8 1.0X [info] java 0 0 0 2711.3 0.4 2.2X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2677.5 0.4 1.0X [info] java 0 0 0 3288.2 0.3 1.2X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1233.0 0.8 1.0X [info] java 0 0 0 2766.3 0.4 2.2X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 520 536 16 1923.6 0.5 1.0X [info] java 214 221 7 4669.5 0.2 2.4X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 593 612 17 1686.5 0.6 1.0X [info] java 215 219 3 4643.3 0.2 2.8X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 853 870 16 1172.8 0.9 1.0X [info] java 215 218 3 4659.7 0.2 4.0X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1350 1370 23 740.8 1.3 1.0X [info] java 215 219 4 4656.6 0.2 6.3X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 460 468 6 2173.2 0.5 1.0X [info] java 210 213 2 4752.7 0.2 2.2X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 535 544 8 1869.3 0.5 1.0X [info] java 210 215 5 4761.8 0.2 2.5X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 843 853 11 1186.8 0.8 1.0X [info] java 209 214 4 4793.4 0.2 4.0X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 891 904 15 1122.0 0.9 1.0X [info] java 209 214 4 4777.2 0.2 4.3X ``` #### JDK16: ``` [info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 194 199 7 515.7 1.9 1.0X [info] java 181 186 3 551.1 1.8 1.1X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 115 4 915.0 1.1 1.0X [info] java 88 92 3 1138.8 0.9 1.2X [info] [info] ddot: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 110 2 922.6 1.1 1.0X [info] java 54 56 2 1839.2 0.5 2.0X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 97 2 1046.1 1.0 1.0X [info] java 29 30 1 3393.4 0.3 3.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 156 165 5 643.0 1.6 1.0X [info] java 150 159 5 667.1 1.5 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 91 6 1171.0 0.9 1.0X [info] java 75 79 3 1340.6 0.7 1.1X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 917.0 1.1 1.0X [info] java 0 0 0 8147.2 0.1 8.9X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 859.3 1.2 1.0X [info] java 1 1 0 859.3 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 482.1 2.1 1.0X [info] java 1 1 0 482.6 2.1 1.0X [info] [info] 
dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2214.2 0.5 1.0X [info] java 0 0 0 7975.8 0.1 3.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1231.4 0.8 1.0X [info] java 0 0 0 8680.9 0.1 7.0X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2684.3 0.4 1.0X [info] java 0 0 0 18527.1 0.1 6.9X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1235.4 0.8 1.0X [info] java 0 0 0 17347.9 0.1 14.0X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 530 552 18 1887.5 0.5 1.0X [info] java 58 64 3 17143.9 0.1 9.1X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 598 620 17 1671.1 0.6 1.0X [info] java 58 64 3 17196.6 0.1 10.3X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 834 847 14 1199.4 0.8 1.0X [info] 
java 57 63 4 17486.9 0.1 14.6X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1338 1366 22 747.3 1.3 1.0X [info] java 58 63 3 17356.6 0.1 23.2X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 489 501 9 2045.5 0.5 1.0X [info] java 36 38 2 27721.9 0.0 13.6X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 478 488 9 2094.0 0.5 1.0X [info] java 36 38 2 27813.2 0.0 13.3X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 825 837 10 1211.6 0.8 1.0X [info] java 35 38 2 28433.1 0.0 23.5X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 900 918 15 1111.6 0.9 1.0X [info] java 36 38 2 28073.0 0.0 25.3X ``` [2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas Closes #32253 from luhenry/master. Authored-by: Ludovic Henry <git@ludovic.dev> Signed-off-by: Sean Owen <srowen@gmail.com>
2021-04-27 15:00:59 -04:00
Row(user=0, item=2, newPrediction=0.69291...)
>>> predictions[1]
[SPARK-35150][ML] Accelerate fallback BLAS with dev.ludovic.netlib ### What changes were proposed in this pull request? Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package. The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focuses on accelerating the linear algebra routines in use in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation. Under the hood, it reimplements the necessary algorithms in pure autovectorization-friendly Java 8, and also takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available. A table summarising which version gets loaded in which case: ``` | | BLAS.nativeBLAS | BLAS.javaBLAS | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | with -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | wrapper for com.github.fommil:all | (JDK16+, relies on the Vector API, requires | | | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | `--add-modules=jdk.incubator.vector` on JDK16) | | | relies on the Foreign Linker API, requires | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | `--add-modules=jdk.incubator.foreign | 3. dev.ludovic.netlib.blas.JavaBLAS | | | -Dforeign.restricted=warn`) | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | 3. 
fails to load, falls back to BLAS.javaBLAS in | wrapper for com.github.fommil:core | | | org.apache.spark.ml.linalg.BLAS | | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | | without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+, | 1. dev.ludovic.netlib.blas.VectorizedBLAS | | | relies on the Foreign Linker API, requires | (JDK16+, relies on the Vector API, requires | | | `--add-modules=jdk.incubator.foreign | `--add-modules=jdk.incubator.vector` on JDK16) | | | -Dforeign.restricted=warn`) | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+) | | | 2. fails to load, falls back to BLAS.javaBLAS in | 3. dev.ludovic.netlib.blas.JavaBLAS | | | org.apache.spark.ml.linalg.BLAS | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a | | | | wrapper for com.github.fommil:core | | --------------------- | -------------------------------------------------- | -------------------------------------------------- | ``` ### Why are the changes needed? Accelerates linear algebra operations when the pure-java fallback method is in use. Transparently falls back to native implementation (OpenBLAS, MKL) when available. ### Does this PR introduce _any_ user-facing change? No, all changes are transparent to the user. ### How was this patch tested? The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite. 
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`: #### JDK8: ``` [info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 223 232 8 448.0 2.2 1.0X [info] java 221 228 7 453.0 2.2 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 122 128 4 821.2 1.2 1.0X [info] java 122 128 4 822.3 1.2 1.0X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 112 2 921.4 1.1 1.0X [info] java 70 74 3 1423.5 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 1046.1 1.0 1.0X [info] java 47 49 2 2121.7 0.5 2.0X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 184 195 8 544.3 1.8 1.0X [info] java 185 196 7 539.5 1.9 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 99 104 4 1011.9 1.0 1.0X [info] java 99 104 4 1010.4 1.0 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 947.2 1.1 1.0X [info] java 0 0 0 1584.8 0.6 1.7X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 867.4 1.2 1.0X [info] java 1 1 0 865.0 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 485.9 2.1 1.0X [info] java 1 1 0 486.8 2.1 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1843.0 0.5 1.0X [info] java 0 0 0 2690.6 0.4 1.5X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1214.7 0.8 1.0X [info] java 0 0 0 2536.8 0.4 2.1X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1895.9 0.5 1.0X [info] java 0 0 0 2961.1 0.3 1.6X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1223.4 0.8 1.0X [info] java 0 0 0 3091.4 0.3 2.5X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 560 575 20 1787.1 0.6 1.0X [info] java 226 232 5 4432.4 0.2 2.5X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 570 586 23 1755.2 0.6 1.0X [info] java 227 232 4 4410.1 0.2 2.5X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 863 879 17 1158.4 0.9 1.0X [info] java 227 231 3 4407.9 0.2 3.8X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1282 1305 23 780.0 1.3 1.0X [info] java 227 232 4 4413.4 0.2 5.7X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 538 548 8 1858.6 0.5 1.0X [info] java 221 226 3 4521.1 0.2 2.4X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 549 558 10 1819.9 0.5 1.0X [info] java 222 229 7 4503.5 0.2 2.5X [info] 
[info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 838 852 12 1193.0 0.8 1.0X [info] java 222 229 5 4500.5 0.2 3.8X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 905 919 18 1104.8 0.9 1.0X [info] java 221 228 5 4521.3 0.2 4.1X ``` #### JDK11: ``` [info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 195 204 10 512.7 2.0 1.0X [info] java 195 202 7 512.4 2.0 1.0X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 113 4 923.3 1.1 1.0X [info] java 102 107 4 984.4 1.0 1.1X [info] [info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 107 110 3 938.1 1.1 1.0X [info] java 69 72 3 1447.1 0.7 1.5X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 98 2 
1046.5 1.0 1.0X [info] java 43 45 2 2317.1 0.4 2.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 155 168 8 644.2 1.6 1.0X [info] java 158 169 8 632.8 1.6 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 90 4 1178.1 0.8 1.0X [info] java 86 90 4 1167.7 0.9 1.0X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 1182.1 0.8 1.0X [info] java 0 0 0 1432.1 0.7 1.2X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 898.7 1.1 1.0X [info] java 1 1 0 891.5 1.1 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 495.4 2.0 1.0X [info] java 1 1 0 495.7 2.0 1.0X [info] [info] dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2271.6 0.4 1.0X [info] java 0 0 0 3648.1 0.3 1.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] 
f2j 1 1 0 1229.3 0.8 1.0X [info] java 0 0 0 2711.3 0.4 2.2X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2677.5 0.4 1.0X [info] java 0 0 0 3288.2 0.3 1.2X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1233.0 0.8 1.0X [info] java 0 0 0 2766.3 0.4 2.2X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 520 536 16 1923.6 0.5 1.0X [info] java 214 221 7 4669.5 0.2 2.4X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 593 612 17 1686.5 0.6 1.0X [info] java 215 219 3 4643.3 0.2 2.8X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 853 870 16 1172.8 0.9 1.0X [info] java 215 218 3 4659.7 0.2 4.0X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1350 1370 23 740.8 1.3 1.0X [info] java 215 219 4 4656.6 0.2 6.3X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] 
------------------------------------------------------------------------------------------------------------------------ [info] f2j 460 468 6 2173.2 0.5 1.0X [info] java 210 213 2 4752.7 0.2 2.2X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 535 544 8 1869.3 0.5 1.0X [info] java 210 215 5 4761.8 0.2 2.5X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 843 853 11 1186.8 0.8 1.0X [info] java 209 214 4 4793.4 0.2 4.0X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 891 904 15 1122.0 0.9 1.0X [info] java 209 214 4 4777.2 0.2 4.3X ``` #### JDK16: ``` [info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic [info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz [info] [info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS [info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS [info] [info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 194 199 7 515.7 1.9 1.0X [info] java 181 186 3 551.1 1.8 1.1X [info] [info] saxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 109 115 4 915.0 1.1 1.0X [info] java 88 92 3 1138.8 0.9 1.2X [info] [info] ddot: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 108 110 2 922.6 1.1 1.0X [info] java 54 56 2 1839.2 0.5 2.0X [info] [info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 96 97 2 1046.1 1.0 1.0X [info] java 29 30 1 3393.4 0.3 3.2X [info] [info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 156 165 5 643.0 1.6 1.0X [info] java 150 159 5 667.1 1.5 1.0X [info] [info] sscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 85 91 6 1171.0 0.9 1.0X [info] java 75 79 3 1340.6 0.7 1.1X [info] [info] dspmv[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 917.0 1.1 1.0X [info] java 0 0 0 8147.2 0.1 8.9X [info] [info] dspr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 859.3 1.2 1.0X [info] java 1 1 0 859.3 1.2 1.0X [info] [info] dsyr[U]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 482.1 2.1 1.0X [info] java 1 1 0 482.6 2.1 1.0X [info] [info] 
dgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2214.2 0.5 1.0X [info] java 0 0 0 7975.8 0.1 3.6X [info] [info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1231.4 0.8 1.0X [info] java 0 0 0 8680.9 0.1 7.0X [info] [info] sgemv[N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 0 0 0 2684.3 0.4 1.0X [info] java 0 0 0 18527.1 0.1 6.9X [info] [info] sgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1 1 0 1235.4 0.8 1.0X [info] java 0 0 0 17347.9 0.1 14.0X [info] [info] dgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 530 552 18 1887.5 0.5 1.0X [info] java 58 64 3 17143.9 0.1 9.1X [info] [info] dgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 598 620 17 1671.1 0.6 1.0X [info] java 58 64 3 17196.6 0.1 10.3X [info] [info] dgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 834 847 14 1199.4 0.8 1.0X [info] 
java 57 63 4 17486.9 0.1 14.6X [info] [info] dgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 1338 1366 22 747.3 1.3 1.0X [info] java 58 63 3 17356.6 0.1 23.2X [info] [info] sgemm[N,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 489 501 9 2045.5 0.5 1.0X [info] java 36 38 2 27721.9 0.0 13.6X [info] [info] sgemm[N,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 478 488 9 2094.0 0.5 1.0X [info] java 36 38 2 27813.2 0.0 13.3X [info] [info] sgemm[T,N]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 825 837 10 1211.6 0.8 1.0X [info] java 35 38 2 28433.1 0.0 23.5X [info] [info] sgemm[T,T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative [info] ------------------------------------------------------------------------------------------------------------------------ [info] f2j 900 918 15 1111.6 0.9 1.0X [info] java 36 38 2 28073.0 0.0 25.3X ``` [2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas Closes #32253 from luhenry/master. Authored-by: Ludovic Henry <git@ludovic.dev> Signed-off-by: Sean Owen <srowen@gmail.com>
2021-04-27 15:00:59 -04:00
Row(user=1, item=0, newPrediction=3.47356...)
>>> predictions[2]
Row(user=2, item=0, newPrediction=-0.899198...)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0)\
.select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2)\
.select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
>>> user_subset = df.where(df.user == 2)
>>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
>>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
>>> item_subset = df.where(df.item == 0)
>>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
>>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
>>> als.getMaxIter()
5
>>> model_path = temp_path + "/als_model"
>>> model.save(model_path)
>>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank
True
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True
>>> model.transform(test).take(1) == model2.transform(test).take(1)
True
"""
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python This PR makes pipeline stages in Python copyable and hence simplifies some implementations. It also includes the following changes: 1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively. 2. Accept a list of param maps in `fit`. 3. Use parent uid and name to identify param. jkbradley Author: Xiangrui Meng <meng@databricks.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #6088 from mengxr/SPARK-7380 and squashes the following commits: 413c463 [Xiangrui Meng] remove unnecessary doc 4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380 611c719 [Xiangrui Meng] fix python style 68862b8 [Xiangrui Meng] update _java_obj initialization 927ad19 [Xiangrui Meng] fix ml/tests.py 0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer 9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params 7e0d27f [Xiangrui Meng] merge master 46840fb [Xiangrui Meng] update wrappers b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap 46cb6ed [Xiangrui Meng] merge master a163413 [Xiangrui Meng] fix style 1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380 9630eae [Xiangrui Meng] fix Identifiable._randomUID 13bd70a [Xiangrui Meng] update ml/tests.py 64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl 02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python 66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui 7431272 [Joseph K. Bradley] Rebased with master
2015-05-18 15:02:18 -04:00
@keyword_only
[SPARK-32933][PYTHON] Use keyword-only syntax for keyword_only methods ### What changes were proposed in this pull request? This PR adjusts signatures of methods decorated with `keyword_only` to indicate using [Python 3 keyword-only syntax](https://www.python.org/dev/peps/pep-3102/). __Note__: For the moment the goal is not to replace `keyword_only`. For justification see https://github.com/apache/spark/pull/29591#discussion_r489402579 ### Why are the changes needed? Right now it is not clear that `keyword_only` methods are indeed keyword only. This proposal addresses that. In practice we could probably capture `locals` and drop `keyword_only` completely, i.e.: ```python @keyword_only def __init__(self, *, featuresCol="features"): ... kwargs = self._input_kwargs self.setParams(**kwargs) ``` could be replaced with ```python def __init__(self, *, featuresCol="features"): kwargs = locals() del kwargs["self"] ... self.setParams(**kwargs) ``` ### Does this PR introduce _any_ user-facing change? Docstrings and inspect tools will now indicate that `keyword_only` methods expect only keyword arguments. 
For example with ` LinearSVC` will change from ``` >>> from pyspark.ml.classification import LinearSVC >>> ?LinearSVC.__init__ Signature: LinearSVC.__init__( self, featuresCol='features', labelCol='label', predictionCol='prediction', maxIter=100, regParam=0.0, tol=1e-06, rawPredictionCol='rawPrediction', fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, ) Docstring: __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2): File: /path/to/python/pyspark/ml/classification.py Type: function ``` to ``` >>> from pyspark.ml.classification import LinearSVC >>> ?LinearSVC.__init__ Signature: LinearSVC.__init__ ( self, *, featuresCol='features', labelCol='label', predictionCol='prediction', maxIter=100, regParam=0.0, tol=1e-06, rawPredictionCol='rawPrediction', fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1, ) Docstring: __init__(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1): File: ~/Workspace/spark/python/pyspark/ml/classification.py Type: function ``` ### How was this patch tested? Existing tests. Closes #29799 from zero323/SPARK-32933. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2020-09-22 20:28:33 -04:00
def __init__(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
             numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
             seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
             intermediateStorageLevel="MEMORY_AND_DISK",
             finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
    """
    __init__(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \
             numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
             seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
             intermediateStorageLevel="MEMORY_AND_DISK", \
             finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)

    Builds the estimator and forwards the captured keyword arguments to
    :py:meth:`setParams`. All arguments are keyword-only.
    """
    super(ALS, self).__init__()
    # NOTE(review): presumably instantiates the JVM-side ALS estimator that this
    # Python object wraps -- confirm against JavaWrapper._new_java_obj.
    self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
    # NOTE(review): _input_kwargs is expected to be filled in by the keyword_only
    # decorator with the keyword arguments of this call -- confirm the decorator
    # is still applied above this def.
    kwargs = self._input_kwargs
    self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
[SPARK-32933][PYTHON] Use keyword-only syntax for keyword_only methods ### What changes were proposed in this pull request? This PR adjusts signatures of methods decorated with `keyword_only` to indicate using [Python 3 keyword-only syntax](https://www.python.org/dev/peps/pep-3102/). __Note__: For the moment the goal is not to replace `keyword_only`. For justification see https://github.com/apache/spark/pull/29591#discussion_r489402579 ### Why are the changes needed? Right now it is not clear that `keyword_only` methods are indeed keyword only. This proposal addresses that. In practice we could probably capture `locals` and drop `keyword_only` completely, i.e.: ```python @keyword_only def __init__(self, *, featuresCol="features"): ... kwargs = self._input_kwargs self.setParams(**kwargs) ``` could be replaced with ```python def __init__(self, *, featuresCol="features"): kwargs = locals() del kwargs["self"] ... self.setParams(**kwargs) ``` ### Does this PR introduce _any_ user-facing change? Docstrings and inspect tools will now indicate that `keyword_only` methods expect only keyword arguments. 
For example with ` LinearSVC` will change from ``` >>> from pyspark.ml.classification import LinearSVC >>> ?LinearSVC.__init__ Signature: LinearSVC.__init__( self, featuresCol='features', labelCol='label', predictionCol='prediction', maxIter=100, regParam=0.0, tol=1e-06, rawPredictionCol='rawPrediction', fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, ) Docstring: __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2): File: /path/to/python/pyspark/ml/classification.py Type: function ``` to ``` >>> from pyspark.ml.classification import LinearSVC >>> ?LinearSVC.__init__ Signature: LinearSVC.__init__ ( self, *, featuresCol='features', labelCol='label', predictionCol='prediction', maxIter=100, regParam=0.0, tol=1e-06, rawPredictionCol='rawPrediction', fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1, ) Docstring: __init__(self, \*, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, aggregationDepth=2, blockSize=1): File: ~/Workspace/spark/python/pyspark/ml/classification.py Type: function ``` ### How was this patch tested? Existing tests. Closes #29799 from zero323/SPARK-32933. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2020-09-22 20:28:33 -04:00
def setParams(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
              numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
              seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
              intermediateStorageLevel="MEMORY_AND_DISK",
              finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
    """
    setParams(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \
             numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
             seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
             intermediateStorageLevel="MEMORY_AND_DISK", \
             finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
    Sets params for ALS.

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance, so calls
    can be chained -- confirm against ``Params._set``)
    """
    # The named parameters are not read directly: the values applied are the
    # ones recorded in ``_input_kwargs`` for this call.
    kwargs = self._input_kwargs
    return self._set(**kwargs)
def _create_model(self, java_model):
    """
    Wraps `java_model` in an :py:class:`ALSModel`.

    Parameters
    ----------
    java_model
        object handed to the :py:class:`ALSModel` constructor (presumably the
        fitted JVM model -- confirm against the caller)

    Returns
    -------
    :py:class:`ALSModel`
    """
    return ALSModel(java_model)
@since("1.4.0")
def setRank(self, value):
    """
    Sets the value of :py:attr:`rank`.

    Parameters
    ----------
    value
        new value for :py:attr:`rank` (``__init__`` defaults it to 10)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(rank=value)
@since("1.4.0")
def setNumUserBlocks(self, value):
    """
    Sets the value of :py:attr:`numUserBlocks`.

    Parameters
    ----------
    value
        new value for :py:attr:`numUserBlocks` (``__init__`` defaults it to 10)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(numUserBlocks=value)
@since("1.4.0")
def setNumItemBlocks(self, value):
    """
    Sets the value of :py:attr:`numItemBlocks`.

    Parameters
    ----------
    value
        new value for :py:attr:`numItemBlocks` (``__init__`` defaults it to 10)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(numItemBlocks=value)
@since("1.4.0")
def setNumBlocks(self, value):
    """
    Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value.

    Parameters
    ----------
    value
        value applied to both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks`

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    # One call carries both params, so they are applied together.
    return self._set(numUserBlocks=value, numItemBlocks=value)
@since("1.4.0")
def setImplicitPrefs(self, value):
    """
    Sets the value of :py:attr:`implicitPrefs`.

    Parameters
    ----------
    value
        new value for :py:attr:`implicitPrefs` (``__init__`` defaults it to False)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(implicitPrefs=value)
@since("1.4.0")
def setAlpha(self, value):
    """
    Sets the value of :py:attr:`alpha`.

    Parameters
    ----------
    value
        new value for :py:attr:`alpha` (``__init__`` defaults it to 1.0)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(alpha=value)
@since("1.4.0")
def setUserCol(self, value):
    """
    Sets the value of :py:attr:`userCol`.

    Parameters
    ----------
    value
        new value for :py:attr:`userCol` (``__init__`` defaults it to "user")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(userCol=value)
@since("1.4.0")
def setItemCol(self, value):
    """
    Sets the value of :py:attr:`itemCol`.

    Parameters
    ----------
    value
        new value for :py:attr:`itemCol` (``__init__`` defaults it to "item")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(itemCol=value)
@since("1.4.0")
def setRatingCol(self, value):
    """
    Sets the value of :py:attr:`ratingCol`.

    Parameters
    ----------
    value
        new value for :py:attr:`ratingCol` (``__init__`` defaults it to "rating")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(ratingCol=value)
@since("1.4.0")
def setNonnegative(self, value):
    """
    Sets the value of :py:attr:`nonnegative`.

    Parameters
    ----------
    value
        new value for :py:attr:`nonnegative` (``__init__`` defaults it to False)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(nonnegative=value)
@since("2.0.0")
def setIntermediateStorageLevel(self, value):
    """
    Sets the value of :py:attr:`intermediateStorageLevel`.

    Parameters
    ----------
    value
        new value for :py:attr:`intermediateStorageLevel`
        (``__init__`` defaults it to "MEMORY_AND_DISK")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(intermediateStorageLevel=value)
@since("2.0.0")
def setFinalStorageLevel(self, value):
    """
    Sets the value of :py:attr:`finalStorageLevel`.

    Parameters
    ----------
    value
        new value for :py:attr:`finalStorageLevel`
        (``__init__`` defaults it to "MEMORY_AND_DISK")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(finalStorageLevel=value)
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation. The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible. The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)). ## How was this patch tested? New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics. Author: Nick Pentreath <nickp@za.ibm.com> Closes #12896 from MLnick/SPARK-14489-als-nan.
2017-02-28 09:17:35 -05:00
@since("2.2.0")
def setColdStartStrategy(self, value):
    """
    Sets the value of :py:attr:`coldStartStrategy`.

    Parameters
    ----------
    value
        new value for :py:attr:`coldStartStrategy` (``__init__`` defaults it to "nan")

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(coldStartStrategy=value)
def setMaxIter(self, value):
    """
    Sets the value of :py:attr:`maxIter`.

    Parameters
    ----------
    value
        new value for :py:attr:`maxIter` (``__init__`` defaults it to 10)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(maxIter=value)
def setRegParam(self, value):
    """
    Sets the value of :py:attr:`regParam`.

    Parameters
    ----------
    value
        new value for :py:attr:`regParam` (``__init__`` defaults it to 0.1)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(regParam=value)
def setPredictionCol(self, value):
    """
    Sets the value of :py:attr:`predictionCol`.

    Parameters
    ----------
    value
        new value for :py:attr:`predictionCol`

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(predictionCol=value)
def setCheckpointInterval(self, value):
    """
    Sets the value of :py:attr:`checkpointInterval`.

    Parameters
    ----------
    value
        new value for :py:attr:`checkpointInterval` (``__init__`` defaults it to 10)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(checkpointInterval=value)
def setSeed(self, value):
    """
    Sets the value of :py:attr:`seed`.

    Parameters
    ----------
    value
        new value for :py:attr:`seed` (``__init__`` defaults it to None)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(seed=value)
@since("3.0.0")
def setBlockSize(self, value):
    """
    Sets the value of :py:attr:`blockSize`.

    Parameters
    ----------
    value
        new value for :py:attr:`blockSize` (``__init__`` defaults it to 4096)

    Returns
    -------
    the value returned by ``self._set`` (presumably this instance -- confirm
    against ``Params._set``)
    """
    return self._set(blockSize=value)
class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by ALS.

    The param setters delegate to ``self._set``; the properties and the
    ``recommendFor*`` methods delegate to ``self._call_java`` with the name of
    the method to invoke (presumably on the wrapped JVM model -- confirm
    against ``JavaWrapper._call_java``).

    .. versionadded:: 1.4.0
    """

    @since("3.0.0")
    def setUserCol(self, value):
        """
        Sets the value of :py:attr:`userCol`.
        """
        return self._set(userCol=value)

    @since("3.0.0")
    def setItemCol(self, value):
        """
        Sets the value of :py:attr:`itemCol`.
        """
        return self._set(itemCol=value)

    @since("3.0.0")
    def setColdStartStrategy(self, value):
        """
        Sets the value of :py:attr:`coldStartStrategy`.
        """
        return self._set(coldStartStrategy=value)

    @since("3.0.0")
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @since("3.0.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        return self._set(blockSize=value)

    @property
    @since("1.4.0")
    def rank(self):
        """rank of the matrix factorization model"""
        return self._call_java("rank")

    @property
    @since("1.4.0")
    def userFactors(self):
        """
        a DataFrame that stores user factors in two columns: `id` and
        `features`
        """
        return self._call_java("userFactors")

    @property
    @since("1.4.0")
    def itemFactors(self):
        """
        a DataFrame that stores item factors in two columns: `id` and
        `features`
        """
        return self._call_java("itemFactors")

    def recommendForAllUsers(self, numItems):
        """
        Returns top `numItems` items recommended for each user, for all users.

        .. versionadded:: 2.2.0

        Parameters
        ----------
        numItems : int
            max number of recommendations for each user

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (userCol, recommendations), where recommendations are
            stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForAllUsers", numItems)

    def recommendForAllItems(self, numUsers):
        """
        Returns top `numUsers` users recommended for each item, for all items.

        .. versionadded:: 2.2.0

        Parameters
        ----------
        numUsers : int
            max number of recommendations for each item

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (itemCol, recommendations), where recommendations are
            stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForAllItems", numUsers)

    def recommendForUserSubset(self, dataset, numItems):
        """
        Returns top `numItems` items recommended for each user id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        .. versionadded:: 2.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            a DataFrame containing a column of user ids. The column name must match `userCol`.
        numItems : int
            max number of recommendations for each user

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (userCol, recommendations), where recommendations are
            stored as an array of (itemCol, rating) Rows.
        """
        return self._call_java("recommendForUserSubset", dataset, numItems)

    def recommendForItemSubset(self, dataset, numUsers):
        """
        Returns top `numUsers` users recommended for each item id in the input data set. Note that
        if there are duplicate ids in the input dataset, only one set of recommendations per unique
        id will be returned.

        .. versionadded:: 2.3.0

        Parameters
        ----------
        dataset : :py:class:`pyspark.sql.DataFrame`
            a DataFrame containing a column of item ids. The column name must match `itemCol`.
        numUsers : int
            max number of recommendations for each item

        Returns
        -------
        :py:class:`pyspark.sql.DataFrame`
            a DataFrame of (itemCol, recommendations), where recommendations are
            stored as an array of (userCol, rating) Rows.
        """
        return self._call_java("recommendForItemSubset", dataset, numUsers)
if __name__ == "__main__":
    # Doctest runner for this module: starts a local SparkSession, exposes it
    # (plus a scratch directory) to the doctests, and exits non-zero on failure.
    import doctest
    import pyspark.ml.recommendation
    from pyspark.sql import SparkSession
    # Work on a copy of the module namespace so the fixtures injected below
    # are visible to the doctests without being written into the module.
    globs = pyspark.ml.recommendation.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.recommendation tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    import tempfile
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Release the session and the scratch directory even when
        # doctest.testmod itself raises, not only on the success path.
        from shutil import rmtree
        try:
            spark.stop()
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp directory must not mask
                # the doctest outcome.
                pass
    if failure_count:
        sys.exit(-1)