5b77ebb57b
### What changes were proposed in this pull request?

Following https://github.com/apache/spark/pull/30810, I've continued looking for ways to accelerate the usage of BLAS in Spark. With this PR, I integrate work done in the [`dev.ludovic.netlib`](https://github.com/luhenry/netlib/) Maven package.

The `dev.ludovic.netlib` library wraps the original `com.github.fommil.netlib` library and focuses on accelerating the linear algebra routines used in Spark. When running the `org.apache.spark.ml.linalg.BLASBenchmark` benchmarking suite, I get the results at [1] on an Intel machine. Moreover, this library is thoroughly tested to return the exact same results as the reference implementation.

Under the hood, it reimplements the necessary algorithms in pure, autovectorization-friendly Java 8, and takes advantage of the Vector API and Foreign Linker API introduced in JDK 16 when available.

A table summarising which version gets loaded in which case:

```
|                       | BLAS.nativeBLAS                                    | BLAS.javaBLAS                                      |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| with -Pnetlib-lgpl    | 1. dev.ludovic.netlib.blas.NetlibNativeBLAS, a     | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |    wrapper for com.github.fommil:all               |    (JDK16+, relies on the Vector API, requires     |
|                       | 2. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    |    `--add-modules=jdk.incubator.vector` on JDK16)  |
|                       |    relies on the Foreign Linker API, requires      | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       |    `--add-modules=jdk.incubator.foreign            | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |    -Dforeign.restricted=warn`)                     | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       | 3. fails to load, falls back to BLAS.javaBLAS in   |    wrapper for com.github.fommil:core              |
|                       |    org.apache.spark.ml.linalg.BLAS                 |                                                    |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
| without -Pnetlib-lgpl | 1. dev.ludovic.netlib.blas.ForeignBLAS (JDK16+,    | 1. dev.ludovic.netlib.blas.VectorizedBLAS          |
|                       |    relies on the Foreign Linker API, requires      |    (JDK16+, relies on the Vector API, requires     |
|                       |    `--add-modules=jdk.incubator.foreign            |    `--add-modules=jdk.incubator.vector` on JDK16)  |
|                       |    -Dforeign.restricted=warn`)                     | 2. dev.ludovic.netlib.blas.Java11BLAS (JDK11+)     |
|                       | 2. fails to load, falls back to BLAS.javaBLAS in   | 3. dev.ludovic.netlib.blas.JavaBLAS                |
|                       |    org.apache.spark.ml.linalg.BLAS                 | 4. dev.ludovic.netlib.blas.NetlibF2jBLAS, a        |
|                       |                                                    |    wrapper for com.github.fommil:core              |
| --------------------- | -------------------------------------------------- | -------------------------------------------------- |
```

### Why are the changes needed?

Accelerates linear algebra operations when the pure-Java fallback method is in use. Transparently falls back to a native implementation (OpenBLAS, MKL) when available.

### Does this PR introduce _any_ user-facing change?

No, all changes are transparent to the user.

### How was this patch tested?

The `dev.ludovic.netlib` library has its own test suite [2]. It has also been validated by running the Spark test suite and benchmarking suite.
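The loading order in the table above works as a first-success fallback chain. Each candidate is tried in turn, and the first one that loads is used. The Python sketch below illustrates only that selection logic. The helper names (`first_available`, `unavailable`) and the simulated availability are hypothetical. The real selection happens in Java inside `dev.ludovic.netlib` and `org.apache.spark.ml.linalg.BLAS`, and this sketch does not reproduce their API.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical loader: returns an implementation object, or raises when the
# backend is unavailable (wrong JDK, missing --add-modules flag, no native lib).
Loader = Callable[[], object]


def first_available(candidates: Sequence[Tuple[str, Loader]]) -> Tuple[str, object]:
    """Return (name, impl) for the first candidate whose loader succeeds."""
    errors: List[str] = []
    for name, load in candidates:
        try:
            return name, load()
        except Exception as exc:  # any failure means: try the next candidate
            errors.append(f"{name}: {exc}")
    raise RuntimeError("no implementation could be loaded: " + "; ".join(errors))


def unavailable(reason: str) -> Loader:
    """Build a loader that always fails, to simulate a missing backend."""
    def load() -> object:
        raise ImportError(reason)
    return load


# Simulated javaBLAS chain on a JDK 11 runtime (class names taken from the
# table; which ones load is faked for illustration).
java_chain = [
    ("VectorizedBLAS", unavailable("Vector API requires JDK16+")),
    ("Java11BLAS", lambda: "java11-impl"),
    ("JavaBLAS", lambda: "java8-impl"),
    ("NetlibF2jBLAS", lambda: "f2j-impl"),
]
name, impl = first_available(java_chain)
print(name)  # Java11BLAS

# nativeBLAS without -Pnetlib-lgpl: if ForeignBLAS fails to load, fall back
# to whatever javaBLAS resolved to.
try:
    native_name, native_impl = first_available(
        [("ForeignBLAS", unavailable("Foreign Linker API not enabled"))])
except RuntimeError:
    native_name, native_impl = name, impl
print(native_name)  # Java11BLAS
```

Keeping the chain as data makes the per-profile differences (with or without `-Pnetlib-lgpl`) a matter of which list is passed in.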
[1] Results for `org.apache.spark.ml.linalg.BLASBenchmark`:

#### JDK8:

```
[info] OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java8BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                223            232           8        448.0           2.2       1.0X
[info] java                                               221            228           7        453.0           2.2       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                122            128           4        821.2           1.2       1.0X
[info] java                                               122            128           4        822.3           1.2       1.0X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                109            112           2        921.4           1.1       1.0X
[info] java                                                70             74           3       1423.5           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 96             98           2       1046.1           1.0       1.0X
[info] java                                                47             49           2       2121.7           0.5       2.0X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                184            195           8        544.3           1.8       1.0X
[info] java                                               185            196           7        539.5           1.9       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 99            104           4       1011.9           1.0       1.0X
[info] java                                                99            104           4       1010.4           1.0       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        947.2           1.1       1.0X
[info] java                                                 0              0           0       1584.8           0.6       1.7X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        867.4           1.2       1.0X
[info] java                                                 1              1           0        865.0           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        485.9           2.1       1.0X
[info] java                                                 1              1           0        486.8           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1843.0           0.5       1.0X
[info] java                                                 0              0           0       2690.6           0.4       1.5X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1214.7           0.8       1.0X
[info] java                                                 0              0           0       2536.8           0.4       2.1X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1895.9           0.5       1.0X
[info] java                                                 0              0           0       2961.1           0.3       1.6X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1223.4           0.8       1.0X
[info] java                                                 0              0           0       3091.4           0.3       2.5X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                560            575          20       1787.1           0.6       1.0X
[info] java                                               226            232           5       4432.4           0.2       2.5X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                570            586          23       1755.2           0.6       1.0X
[info] java                                               227            232           4       4410.1           0.2       2.5X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                863            879          17       1158.4           0.9       1.0X
[info] java                                               227            231           3       4407.9           0.2       3.8X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                               1282           1305          23        780.0           1.3       1.0X
[info] java                                               227            232           4       4413.4           0.2       5.7X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                538            548           8       1858.6           0.5       1.0X
[info] java                                               221            226           3       4521.1           0.2       2.4X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                549            558          10       1819.9           0.5       1.0X
[info] java                                               222            229           7       4503.5           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                838            852          12       1193.0           0.8       1.0X
[info] java                                               222            229           5       4500.5           0.2       3.8X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                905            919          18       1104.8           0.9       1.0X
[info] java                                               221            228           5       4521.3           0.2       4.1X
```

#### JDK11:

```
[info] OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.Java11BLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                195            204          10        512.7           2.0       1.0X
[info] java                                               195            202           7        512.4           2.0       1.0X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                108            113           4        923.3           1.1       1.0X
[info] java                                               102            107           4        984.4           1.0       1.1X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                107            110           3        938.1           1.1       1.0X
[info] java                                                69             72           3       1447.1           0.7       1.5X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 96             98           2       1046.5           1.0       1.0X
[info] java                                                43             45           2       2317.1           0.4       2.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                155            168           8        644.2           1.6       1.0X
[info] java                                               158            169           8        632.8           1.6       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 85             90           4       1178.1           0.8       1.0X
[info] java                                                86             90           4       1167.7           0.9       1.0X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  0              0           0       1182.1           0.8       1.0X
[info] java                                                 0              0           0       1432.1           0.7       1.2X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        898.7           1.1       1.0X
[info] java                                                 1              1           0        891.5           1.1       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        495.4           2.0       1.0X
[info] java                                                 1              1           0        495.7           2.0       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  0              0           0       2271.6           0.4       1.0X
[info] java                                                 0              0           0       3648.1           0.3       1.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1229.3           0.8       1.0X
[info] java                                                 0              0           0       2711.3           0.4       2.2X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  0              0           0       2677.5           0.4       1.0X
[info] java                                                 0              0           0       3288.2           0.3       1.2X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1233.0           0.8       1.0X
[info] java                                                 0              0           0       2766.3           0.4       2.2X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                520            536          16       1923.6           0.5       1.0X
[info] java                                               214            221           7       4669.5           0.2       2.4X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                593            612          17       1686.5           0.6       1.0X
[info] java                                               215            219           3       4643.3           0.2       2.8X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                853            870          16       1172.8           0.9       1.0X
[info] java                                               215            218           3       4659.7           0.2       4.0X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                               1350           1370          23        740.8           1.3       1.0X
[info] java                                               215            219           4       4656.6           0.2       6.3X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                460            468           6       2173.2           0.5       1.0X
[info] java                                               210            213           2       4752.7           0.2       2.2X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                535            544           8       1869.3           0.5       1.0X
[info] java                                               210            215           5       4761.8           0.2       2.5X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                843            853          11       1186.8           0.8       1.0X
[info] java                                               209            214           4       4793.4           0.2       4.0X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                891            904          15       1122.0           0.9       1.0X
[info] java                                               209            214           4       4777.2           0.2       4.3X
```

#### JDK16:

```
[info] OpenJDK 64-Bit Server VM 16+36 on Linux 5.8.0-50-generic
[info] Intel(R) Xeon(R) E-2276G CPU 3.80GHz
[info]
[info] f2jBLAS = dev.ludovic.netlib.blas.NetlibF2jBLAS
[info] javaBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info] nativeBLAS = dev.ludovic.netlib.blas.VectorizedBLAS
[info]
[info] daxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                194            199           7        515.7           1.9       1.0X
[info] java                                               181            186           3        551.1           1.8       1.1X
[info]
[info] saxpy:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                109            115           4        915.0           1.1       1.0X
[info] java                                                88             92           3       1138.8           0.9       1.2X
[info]
[info] ddot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                108            110           2        922.6           1.1       1.0X
[info] java                                                54             56           2       1839.2           0.5       2.0X
[info]
[info] sdot:                                     Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 96             97           2       1046.1           1.0       1.0X
[info] java                                                29             30           1       3393.4           0.3       3.2X
[info]
[info] dscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                156            165           5        643.0           1.6       1.0X
[info] java                                               150            159           5        667.1           1.5       1.0X
[info]
[info] sscal:                                    Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                 85             91           6       1171.0           0.9       1.0X
[info] java                                                75             79           3       1340.6           0.7       1.1X
[info]
[info] dspmv[U]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        917.0           1.1       1.0X
[info] java                                                 0              0           0       8147.2           0.1       8.9X
[info]
[info] dspr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        859.3           1.2       1.0X
[info] java                                                 1              1           0        859.3           1.2       1.0X
[info]
[info] dsyr[U]:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0        482.1           2.1       1.0X
[info] java                                                 1              1           0        482.6           2.1       1.0X
[info]
[info] dgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  0              0           0       2214.2           0.5       1.0X
[info] java                                                 0              0           0       7975.8           0.1       3.6X
[info]
[info] dgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1231.4           0.8       1.0X
[info] java                                                 0              0           0       8680.9           0.1       7.0X
[info]
[info] sgemv[N]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  0              0           0       2684.3           0.4       1.0X
[info] java                                                 0              0           0      18527.1           0.1       6.9X
[info]
[info] sgemv[T]:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                  1              1           0       1235.4           0.8       1.0X
[info] java                                                 0              0           0      17347.9           0.1      14.0X
[info]
[info] dgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                530            552          18       1887.5           0.5       1.0X
[info] java                                                58             64           3      17143.9           0.1       9.1X
[info]
[info] dgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                598            620          17       1671.1           0.6       1.0X
[info] java                                                58             64           3      17196.6           0.1      10.3X
[info]
[info] dgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                834            847          14       1199.4           0.8       1.0X
[info] java                                                57             63           4      17486.9           0.1      14.6X
[info]
[info] dgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                               1338           1366          22        747.3           1.3       1.0X
[info] java                                                58             63           3      17356.6           0.1      23.2X
[info]
[info] sgemm[N,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                489            501           9       2045.5           0.5       1.0X
[info] java                                                36             38           2      27721.9           0.0      13.6X
[info]
[info] sgemm[N,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                478            488           9       2094.0           0.5       1.0X
[info] java                                                36             38           2      27813.2           0.0      13.3X
[info]
[info] sgemm[T,N]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                825            837          10       1211.6           0.8       1.0X
[info] java                                                35             38           2      28433.1           0.0      23.5X
[info]
[info] sgemm[T,T]:                               Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] f2j                                                900            918          15       1111.6           0.9       1.0X
[info] java                                                36             38           2      28073.0           0.0      25.3X
```

[2] https://github.com/luhenry/netlib/tree/master/blas/src/test/java/dev/ludovic/netlib/blas

Closes #32253 from luhenry/master.

Authored-by: Ludovic Henry <git@ludovic.dev>
Signed-off-by: Sean Owen <srowen@gmail.com>
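When reading the `BLASBenchmark` output above, note that every `f2j` row shows `1.0X`, so `Relative` is measured against the `f2j` baseline of each group. The sketch below assumes that `Relative` is the ratio of best times and that `Per Row(ns)` equals `1000 / Rate(M/s)`. Both assumptions are consistent with the printed numbers, but neither is taken from the benchmark's source, and the helper names are hypothetical. Because best times are printed rounded to whole milliseconds, recomputed speedups can differ slightly from the printed ones.

```python
def relative_speedup(baseline_best_ms: float, candidate_best_ms: float) -> float:
    """Speedup of a candidate over the baseline (the f2j row), from best times."""
    return baseline_best_ms / candidate_best_ms


def per_row_ns(rate_m_per_s: float) -> float:
    """Nanoseconds per processed value, given a rate in millions of values/s."""
    return 1000.0 / rate_m_per_s


# JDK16, dgemm[T,T]: f2j best 1338 ms, java (VectorizedBLAS) best 58 ms.
# The log prints 23.2X; recomputing from rounded milliseconds gives ~23.1.
print(round(relative_speedup(1338, 58), 1))  # 23.1

# JDK8, daxpy f2j: 448.0 M/s -> 2.2 ns per row, matching the Per Row column.
print(round(per_row_ns(448.0), 1))  # 2.2
```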
650 lines
23 KiB
Python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark import since, keyword_only
from pyspark.ml.param.shared import HasPredictionCol, HasBlockSize, HasMaxIter, HasRegParam, \
    HasCheckpointInterval, HasSeed
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.common import inherit_doc
from pyspark.ml.param import Params, TypeConverters, Param
from pyspark.ml.util import JavaMLWritable, JavaMLReadable


__all__ = ['ALS', 'ALSModel']


@inherit_doc
class _ALSModelParams(HasPredictionCol, HasBlockSize):
    """
    Params for :py:class:`ALS` and :py:class:`ALSModel`.

    .. versionadded:: 3.0.0
    """

    userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " +
                    "the integer value range.", typeConverter=TypeConverters.toString)
    itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " +
                    "the integer value range.", typeConverter=TypeConverters.toString)
    coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
                              "unknown or new users/items at prediction time. This may be useful " +
                              "in cross-validation or production scenarios, for handling " +
                              "user/item ids the model has not seen in the training data. " +
                              "Supported values: 'nan', 'drop'.",
                              typeConverter=TypeConverters.toString)

    def __init__(self, *args):
        super(_ALSModelParams, self).__init__(*args)
        self._setDefault(blockSize=4096)

    @since("1.4.0")
    def getUserCol(self):
        """
        Gets the value of userCol or its default value.
        """
        return self.getOrDefault(self.userCol)

    @since("1.4.0")
    def getItemCol(self):
        """
        Gets the value of itemCol or its default value.
        """
        return self.getOrDefault(self.itemCol)

    @since("2.2.0")
    def getColdStartStrategy(self):
        """
        Gets the value of coldStartStrategy or its default value.
        """
        return self.getOrDefault(self.coldStartStrategy)


@inherit_doc
class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval, HasSeed):
    """
    Params for :py:class:`ALS`.

    .. versionadded:: 3.0.0
    """

    rank = Param(Params._dummy(), "rank", "rank of the factorization",
                 typeConverter=TypeConverters.toInt)
    numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks",
                          typeConverter=TypeConverters.toInt)
    numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks",
                          typeConverter=TypeConverters.toInt)
    implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference",
                          typeConverter=TypeConverters.toBoolean)
    alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference",
                  typeConverter=TypeConverters.toFloat)

    ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings",
                      typeConverter=TypeConverters.toString)
    nonnegative = Param(Params._dummy(), "nonnegative",
                        "whether to use nonnegative constraint for least squares",
                        typeConverter=TypeConverters.toBoolean)
    intermediateStorageLevel = Param(Params._dummy(), "intermediateStorageLevel",
                                     "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
                                     typeConverter=TypeConverters.toString)
    finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
                              "StorageLevel for ALS model factors.",
                              typeConverter=TypeConverters.toString)

    def __init__(self, *args):
        super(_ALSParams, self).__init__(*args)
        self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                         implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
                         ratingCol="rating", nonnegative=False, checkpointInterval=10,
                         intermediateStorageLevel="MEMORY_AND_DISK",
                         finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")

    @since("1.4.0")
    def getRank(self):
        """
        Gets the value of rank or its default value.
        """
        return self.getOrDefault(self.rank)

    @since("1.4.0")
    def getNumUserBlocks(self):
        """
        Gets the value of numUserBlocks or its default value.
        """
        return self.getOrDefault(self.numUserBlocks)

    @since("1.4.0")
    def getNumItemBlocks(self):
        """
        Gets the value of numItemBlocks or its default value.
        """
        return self.getOrDefault(self.numItemBlocks)

    @since("1.4.0")
    def getImplicitPrefs(self):
        """
        Gets the value of implicitPrefs or its default value.
        """
        return self.getOrDefault(self.implicitPrefs)

    @since("1.4.0")
    def getAlpha(self):
        """
        Gets the value of alpha or its default value.
        """
        return self.getOrDefault(self.alpha)

    @since("1.4.0")
    def getRatingCol(self):
        """
        Gets the value of ratingCol or its default value.
        """
        return self.getOrDefault(self.ratingCol)

    @since("1.4.0")
    def getNonnegative(self):
        """
        Gets the value of nonnegative or its default value.
        """
        return self.getOrDefault(self.nonnegative)

    @since("2.0.0")
    def getIntermediateStorageLevel(self):
        """
        Gets the value of intermediateStorageLevel or its default value.
        """
        return self.getOrDefault(self.intermediateStorageLevel)

    @since("2.0.0")
    def getFinalStorageLevel(self):
        """
        Gets the value of finalStorageLevel or its default value.
        """
        return self.getOrDefault(self.finalStorageLevel)


@inherit_doc
class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
|
|
"""
|
|
Alternating Least Squares (ALS) matrix factorization.
|
|
|
|
ALS attempts to estimate the ratings matrix `R` as the product of
|
|
two lower-rank matrices, `X` and `Y`, i.e. `X * Yt = R`. Typically
|
|
these approximations are called 'factor' matrices. The general
|
|
approach is iterative. During each iteration, one of the factor
|
|
matrices is held constant, while the other is solved for using least
|
|
squares. The newly-solved factor matrix is then held constant while
|
|
solving for the other factor matrix.
|
|
|
|
This is a blocked implementation of the ALS factorization algorithm
|
|
that groups the two sets of factors (referred to as "users" and
|
|
"products") into blocks and reduces communication by only sending
|
|
one copy of each user vector to each product block on each
|
|
iteration, and only for the product blocks that need that user's
|
|
feature vector. This is achieved by pre-computing some information
|
|
about the ratings matrix to determine the "out-links" of each user
|
|
(which blocks of products it will contribute to) and "in-link"
|
|
information for each product (which of the feature vectors it
|
|
receives from each user block it will depend on). This allows us to
|
|
send only an array of feature vectors between each user block and
|
|
product block, and have the product block find the users' ratings
|
|
and update the products based on these messages.
|
|
|
|
For implicit preference data, the algorithm used is based on
|
|
`"Collaborative Filtering for Implicit Feedback Datasets",
|
|
<https://doi.org/10.1109/ICDM.2008.22>`_, adapted for the blocked
|
|
approach used here.
|
|
|
|
Essentially instead of finding the low-rank approximations to the
|
|
rating matrix `R`, this finds the approximations for a preference
|
|
matrix `P` where the elements of `P` are 1 if r > 0 and 0 if r <= 0.
|
|
The ratings then act as 'confidence' values related to strength of
|
|
indicated user preferences rather than explicit ratings given to
|
|
items.
|
|
|
|
.. versionadded:: 1.4.0
|
|
|
|
Notes
|
|
-----
|
|
The input rating dataframe to the ALS implementation should be deterministic.
|
|
Nondeterministic data can cause failure during fitting ALS model.
|
|
For example, an order-sensitive operation like sampling after a repartition makes
|
|
dataframe output nondeterministic, like `df.repartition(2).sample(False, 0.5, 1618)`.
|
|
Checkpointing sampled dataframe or adding a sort before sampling can help make the
|
|
dataframe deterministic.
|
|
|
|
Examples
--------
>>> df = spark.createDataFrame(
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
... ["user", "item", "rating"])
>>> als = ALS(rank=10, seed=0)
>>> als.setMaxIter(5)
ALS...
>>> als.getMaxIter()
5
>>> als.setRegParam(0.1)
ALS...
>>> als.getRegParam()
0.1
>>> als.clear(als.regParam)
>>> model = als.fit(df)
>>> model.getBlockSize()
4096
>>> model.getUserCol()
'user'
>>> model.setUserCol("user")
ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
ALSModel...
>>> model.rank
10
>>> model.userFactors.orderBy("id").collect()
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
Row(user=0, item=2, newPrediction=0.69291...)
>>> predictions[1]
Row(user=1, item=0, newPrediction=3.47356...)
>>> predictions[2]
Row(user=2, item=0, newPrediction=-0.899198...)
>>> user_recs = model.recommendForAllUsers(3)
>>> user_recs.where(user_recs.user == 0)\
.select("recommendations.item", "recommendations.rating").collect()
[Row(item=[0, 1, 2], rating=[3.910..., 1.997..., 0.692...])]
>>> item_recs = model.recommendForAllItems(3)
>>> item_recs.where(item_recs.item == 2)\
.select("recommendations.user", "recommendations.rating").collect()
[Row(user=[2, 1, 0], rating=[4.892..., 3.991..., 0.692...])]
>>> user_subset = df.where(df.user == 2)
>>> user_subset_recs = model.recommendForUserSubset(user_subset, 3)
>>> user_subset_recs.select("recommendations.item", "recommendations.rating").first()
Row(item=[2, 1, 0], rating=[4.892..., 1.076..., -0.899...])
>>> item_subset = df.where(df.item == 0)
>>> item_subset_recs = model.recommendForItemSubset(item_subset, 3)
>>> item_subset_recs.select("recommendations.user", "recommendations.rating").first()
Row(user=[0, 1, 2], rating=[3.910..., 3.473..., -0.899...])
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
>>> als2.getMaxIter()
5
>>> model_path = temp_path + "/als_model"
>>> model.save(model_path)
>>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank
True
>>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect())
True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True
>>> model.transform(test).take(1) == model2.transform(test).take(1)
True
"""

@keyword_only
def __init__(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
"""
__init__(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
"""
super(ALS, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("1.4.0")
def setParams(self, *, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096):
"""
setParams(self, \\*, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, \
numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", \
seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
Sets params for ALS.
"""
kwargs = self._input_kwargs
return self._set(**kwargs)

def _create_model(self, java_model):
return ALSModel(java_model)

@since("1.4.0")
def setRank(self, value):
"""
Sets the value of :py:attr:`rank`.
"""
return self._set(rank=value)

@since("1.4.0")
def setNumUserBlocks(self, value):
"""
Sets the value of :py:attr:`numUserBlocks`.
"""
return self._set(numUserBlocks=value)

@since("1.4.0")
def setNumItemBlocks(self, value):
"""
Sets the value of :py:attr:`numItemBlocks`.
"""
return self._set(numItemBlocks=value)

@since("1.4.0")
def setNumBlocks(self, value):
"""
Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specified value.
"""
self._set(numUserBlocks=value)
return self._set(numItemBlocks=value)

@since("1.4.0")
def setImplicitPrefs(self, value):
"""
Sets the value of :py:attr:`implicitPrefs`.
"""
return self._set(implicitPrefs=value)

@since("1.4.0")
def setAlpha(self, value):
"""
Sets the value of :py:attr:`alpha`.
"""
return self._set(alpha=value)

@since("1.4.0")
def setUserCol(self, value):
"""
Sets the value of :py:attr:`userCol`.
"""
return self._set(userCol=value)

@since("1.4.0")
def setItemCol(self, value):
"""
Sets the value of :py:attr:`itemCol`.
"""
return self._set(itemCol=value)

@since("1.4.0")
def setRatingCol(self, value):
"""
Sets the value of :py:attr:`ratingCol`.
"""
return self._set(ratingCol=value)

@since("1.4.0")
def setNonnegative(self, value):
"""
Sets the value of :py:attr:`nonnegative`.
"""
return self._set(nonnegative=value)

@since("2.0.0")
def setIntermediateStorageLevel(self, value):
"""
Sets the value of :py:attr:`intermediateStorageLevel`.
"""
return self._set(intermediateStorageLevel=value)

@since("2.0.0")
def setFinalStorageLevel(self, value):
"""
Sets the value of :py:attr:`finalStorageLevel`.
"""
return self._set(finalStorageLevel=value)

@since("2.2.0")
def setColdStartStrategy(self, value):
"""
Sets the value of :py:attr:`coldStartStrategy`.
"""
return self._set(coldStartStrategy=value)

def setMaxIter(self, value):
"""
Sets the value of :py:attr:`maxIter`.
"""
return self._set(maxIter=value)

def setRegParam(self, value):
"""
Sets the value of :py:attr:`regParam`.
"""
return self._set(regParam=value)

def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

def setCheckpointInterval(self, value):
"""
Sets the value of :py:attr:`checkpointInterval`.
"""
return self._set(checkpointInterval=value)

def setSeed(self, value):
"""
Sets the value of :py:attr:`seed`.
"""
return self._set(seed=value)

@since("3.0.0")
def setBlockSize(self, value):
"""
Sets the value of :py:attr:`blockSize`.
"""
return self._set(blockSize=value)


class ALSModel(JavaModel, _ALSModelParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by ALS.

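Conceptually, the fitted model scores a (user, item) pair by the dot product of the matching rows of `userFactors` and `itemFactors`. The plain-Python sketch below illustrates that, together with how the `coldStartStrategy` values "nan" and "drop" treat ids that were not seen during fitting. The helper names and factor values are hypothetical; this is not the model's actual code path.

```python
import math


def predict(user_factors, item_factors, user, item):
    # Hypothetical helper: NaN when either id has no learned factors,
    # mirroring coldStartStrategy="nan".
    if user not in user_factors or item not in item_factors:
        return math.nan
    return sum(u * v for u, v in zip(user_factors[user], item_factors[item]))


def transform(user_factors, item_factors, pairs, cold_start_strategy="nan"):
    rows = [(u, i, predict(user_factors, item_factors, u, i)) for u, i in pairs]
    if cold_start_strategy == "drop":
        # "drop" removes rows whose prediction is NaN.
        rows = [row for row in rows if not math.isnan(row[2])]
    return rows


# Made-up rank-2 factors, for illustration only.
user_factors = {0: [1.0, 0.5], 1: [0.5, 2.0]}
item_factors = {0: [2.0, 0.0], 1: [1.0, 1.0]}
pairs = [(0, 0), (1, 1), (2, 0)]  # user 2 was never seen during fitting
print(transform(user_factors, item_factors, pairs))
# [(0, 0, 2.0), (1, 1, 2.5), (2, 0, nan)]
print(transform(user_factors, item_factors, pairs, cold_start_strategy="drop"))
# [(0, 0, 2.0), (1, 1, 2.5)]
```
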
.. versionadded:: 1.4.0
"""

@since("3.0.0")
def setUserCol(self, value):
"""
Sets the value of :py:attr:`userCol`.
"""
return self._set(userCol=value)

@since("3.0.0")
def setItemCol(self, value):
"""
Sets the value of :py:attr:`itemCol`.
"""
return self._set(itemCol=value)

@since("3.0.0")
def setColdStartStrategy(self, value):
"""
Sets the value of :py:attr:`coldStartStrategy`.
"""
return self._set(coldStartStrategy=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@since("3.0.0")
def setBlockSize(self, value):
"""
Sets the value of :py:attr:`blockSize`.
"""
return self._set(blockSize=value)

@property
@since("1.4.0")
def rank(self):
"""rank of the matrix factorization model"""
return self._call_java("rank")

@property
@since("1.4.0")
def userFactors(self):
"""
a DataFrame that stores user factors in two columns: `id` and
`features`
"""
return self._call_java("userFactors")

@property
@since("1.4.0")
def itemFactors(self):
"""
a DataFrame that stores item factors in two columns: `id` and
`features`
"""
return self._call_java("itemFactors")

def recommendForAllUsers(self, numItems):
"""
Returns top `numItems` items recommended for each user, for all users.

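A minimal plain-Python sketch of the idea, assuming items are ranked by the same user-item dot product used for predictions; Spark itself computes this blockwise over DataFrames. The helper `recommend_for_all_users` and the factor values are hypothetical.

```python
import heapq


def recommend_for_all_users(user_factors, item_factors, num_items):
    # Hypothetical helper: score every item for every user and keep the
    # top `num_items` (item, rating) pairs, highest rating first.
    recs = {}
    for user, u_vec in user_factors.items():
        scored = [
            (item, sum(a * b for a, b in zip(u_vec, i_vec)))
            for item, i_vec in item_factors.items()
        ]
        recs[user] = heapq.nlargest(num_items, scored, key=lambda pair: pair[1])
    return recs


user_factors = {0: [1.0, 0.0], 1: [0.0, 1.0]}
item_factors = {10: [3.0, 1.0], 11: [1.0, 2.0], 12: [2.0, 4.0]}
print(recommend_for_all_users(user_factors, item_factors, num_items=2))
# {0: [(10, 3.0), (12, 2.0)], 1: [(12, 4.0), (11, 2.0)]}
```

Using `heapq.nlargest` keeps only `num_items` candidates per user rather than fully sorting every score list.
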
.. versionadded:: 2.2.0

Parameters
----------
numItems : int
max number of recommendations for each user

Returns
-------
:py:class:`pyspark.sql.DataFrame`
a DataFrame of (userCol, recommendations), where recommendations are
stored as an array of (itemCol, rating) Rows.
"""
return self._call_java("recommendForAllUsers", numItems)

def recommendForAllItems(self, numUsers):
"""
Returns top `numUsers` users recommended for each item, for all items.

.. versionadded:: 2.2.0

Parameters
----------
numUsers : int
max number of recommendations for each item

Returns
-------
:py:class:`pyspark.sql.DataFrame`
a DataFrame of (itemCol, recommendations), where recommendations are
stored as an array of (userCol, rating) Rows.
"""
return self._call_java("recommendForAllItems", numUsers)

def recommendForUserSubset(self, dataset, numItems):
"""
Returns top `numItems` items recommended for each user id in the input dataset. Note that
if there are duplicate ids in the input dataset, only one set of recommendations per unique
id will be returned.

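A plain-Python sketch of the de-duplication described above: ids are reduced to their unique values before top-k selection, so each user appears once in the output. The helper name and data are hypothetical, ranking by dot product is an assumption, and so is skipping ids that have no learned factors.

```python
import heapq


def recommend_for_user_subset(user_ids, user_factors, item_factors, num_items):
    # Hypothetical helper: duplicate ids collapse to a single entry.
    recs = {}
    for user in sorted(set(user_ids)):
        if user not in user_factors:
            continue  # assumption: ids without learned factors get no output row
        u_vec = user_factors[user]
        scored = [
            (item, sum(a * b for a, b in zip(u_vec, i_vec)))
            for item, i_vec in item_factors.items()
        ]
        recs[user] = heapq.nlargest(num_items, scored, key=lambda pair: pair[1])
    return recs


user_factors = {0: [1.0, 0.0], 2: [0.0, 1.0]}
item_factors = {10: [3.0, 1.0], 11: [1.0, 2.0]}
# User 2 is listed twice but gets a single set of recommendations.
print(recommend_for_user_subset([2, 2, 0], user_factors, item_factors, num_items=1))
# {0: [(10, 3.0)], 2: [(11, 2.0)]}
```
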
.. versionadded:: 2.3.0

Parameters
----------
dataset : :py:class:`pyspark.sql.DataFrame`
a DataFrame containing a column of user ids. The column name must match `userCol`.
numItems : int
max number of recommendations for each user

Returns
-------
:py:class:`pyspark.sql.DataFrame`
a DataFrame of (userCol, recommendations), where recommendations are
stored as an array of (itemCol, rating) Rows.
"""
return self._call_java("recommendForUserSubset", dataset, numItems)

def recommendForItemSubset(self, dataset, numUsers):
"""
Returns top `numUsers` users recommended for each item id in the input dataset. Note that
if there are duplicate ids in the input dataset, only one set of recommendations per unique
id will be returned.

.. versionadded:: 2.3.0

Parameters
----------
dataset : :py:class:`pyspark.sql.DataFrame`
a DataFrame containing a column of item ids. The column name must match `itemCol`.
numUsers : int
max number of recommendations for each item

Returns
-------
:py:class:`pyspark.sql.DataFrame`
a DataFrame of (itemCol, recommendations), where recommendations are
stored as an array of (userCol, rating) Rows.
"""
return self._call_java("recommendForItemSubset", dataset, numUsers)


if __name__ == "__main__":
import doctest
import pyspark.ml.recommendation
from pyspark.sql import SparkSession
globs = pyspark.ml.recommendation.__dict__.copy()
# A local session with two worker threads is enough
# for these small test examples:
spark = SparkSession.builder\
.master("local[2]")\
.appName("ml.recommendation tests")\
.getOrCreate()
sc = spark.sparkContext
globs['sc'] = sc
globs['spark'] = spark
import tempfile
temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
finally:
from shutil import rmtree
try:
rmtree(temp_path)
except OSError:
pass
if failure_count:
sys.exit(-1)