8ec25cd67e
## What changes were proposed in this pull request?

Fixing typos is sometimes very hard; they are not easy to review visually. Recently, I discovered a very useful tool for it, [misspell](https://github.com/client9/misspell). This pull request fixes the minor typos detected by [misspell](https://github.com/client9/misspell), except for the false positives. If you would like me to work on other files as well, let me know.

## How was this patch tested?

### before

```
$ misspell . | grep -v '.js'
R/pkg/R/SQLContext.R:354:43: "definiton" is a misspelling of "definition"
R/pkg/R/SQLContext.R:424:43: "definiton" is a misspelling of "definition"
R/pkg/R/SQLContext.R:445:43: "definiton" is a misspelling of "definition"
R/pkg/R/SQLContext.R:495:43: "definiton" is a misspelling of "definition"
NOTICE-binary:454:16: "containd" is a misspelling of "contained"
R/pkg/R/context.R:46:43: "definiton" is a misspelling of "definition"
R/pkg/R/context.R:74:43: "definiton" is a misspelling of "definition"
R/pkg/R/DataFrame.R:591:48: "persistance" is a misspelling of "persistence"
R/pkg/R/streaming.R:166:44: "occured" is a misspelling of "occurred"
R/pkg/inst/worker/worker.R:65:22: "ouput" is a misspelling of "output"
R/pkg/tests/fulltests/test_utils.R:106:25: "environemnt" is a misspelling of "environment"
common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java:38:39: "existant" is a misspelling of "existent"
common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java:83:39: "existant" is a misspelling of "existent"
common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java:243:46: "transfered" is a misspelling of "transferred"
common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:234:19: "transfered" is a misspelling of "transferred"
common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:238:63: "transfered" is a misspelling of "transferred"
common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:244:46: "transfered" is a misspelling of "transferred"
common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java:276:39: "transfered" is a misspelling of "transferred"
common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java:27:20: "transfered" is a misspelling of "transferred"
common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala:195:15: "orgin" is a misspelling of "origin"
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala:621:39: "gauranteed" is a misspelling of "guaranteed"
core/src/main/scala/org/apache/spark/status/storeTypes.scala:113:29: "ect" is a misspelling of "etc"
core/src/main/scala/org/apache/spark/storage/DiskStore.scala:282:18: "transfered" is a misspelling of "transferred"
core/src/main/scala/org/apache/spark/util/ListenerBus.scala:64:17: "overriden" is a misspelling of "overridden"
core/src/test/scala/org/apache/spark/ShuffleSuite.scala:211:7: "substracted" is a misspelling of "subtracted"
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:1922:49: "agriculteur" is a misspelling of "agriculture"
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:2468:84: "truely" is a misspelling of "truly"
core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala:25:18: "persistance" is a misspelling of "persistence"
core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala:26:69: "persistance" is a misspelling of "persistence"
data/streaming/AFINN-111.txt:1219:0: "humerous" is a misspelling of "humorous"
dev/run-pip-tests:55:28: "enviroments" is a misspelling of "environments"
dev/run-pip-tests:91:37: "virutal" is a misspelling of "virtual"
dev/merge_spark_pr.py:377:72: "accross" is a misspelling of "across"
dev/merge_spark_pr.py:378:66: "accross" is a misspelling of "across"
dev/run-pip-tests:126:25: "enviroments" is a misspelling of "environments"
docs/configuration.md:1830:82: "overriden" is a misspelling of "overridden"
docs/structured-streaming-programming-guide.md:525:45: "processs" is a misspelling of "processes"
docs/structured-streaming-programming-guide.md:1165:61: "BETWEN" is a misspelling of "BETWEEN"
docs/sql-programming-guide.md:1891:810: "behaivor" is a misspelling of "behavior"
examples/src/main/python/sql/arrow.py:98:8: "substract" is a misspelling of "subtract"
examples/src/main/python/sql/arrow.py:103:27: "substract" is a misspelling of "subtract"
licenses/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING"
licenses/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING"
licenses/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses-binary/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING"
licenses-binary/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses-binary/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING"
licenses-binary/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt:170:0: "teh" is a misspelling of "the"
mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt:53:0: "eles" is a misspelling of "eels"
mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:99:20: "Euclidian" is a misspelling of "Euclidean"
mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:539:11: "Euclidian" is a misspelling of "Euclidean"
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala:77:36: "Teh" is a misspelling of "The"
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala:230:24: "inital" is a misspelling of "initial"
mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala:276:9: "Euclidian" is a misspelling of "Euclidean"
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala:237:26: "descripiton" is a misspelling of "descriptions"
python/pyspark/find_spark_home.py:30:13: "enviroment" is a misspelling of "environment"
python/pyspark/context.py:937:12: "supress" is a misspelling of "suppress"
python/pyspark/context.py:938:12: "supress" is a misspelling of "suppress"
python/pyspark/context.py:939:12: "supress" is a misspelling of "suppress"
python/pyspark/context.py:940:12: "supress" is a misspelling of "suppress"
python/pyspark/heapq3.py:6:63: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:7:2: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:263:29: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:263:39: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:270:49: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:270:59: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:275:2: "STICHTING" is a misspelling of "STITCHING"
python/pyspark/heapq3.py:275:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
python/pyspark/heapq3.py:277:29: "STICHTING" is a misspelling of "STITCHING"
python/pyspark/heapq3.py:277:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
python/pyspark/heapq3.py:713:8: "probabilty" is a misspelling of "probability"
python/pyspark/ml/clustering.py:1038:8: "Currenlty" is a misspelling of "Currently"
python/pyspark/ml/stat.py:339:23: "Euclidian" is a misspelling of "Euclidean"
python/pyspark/ml/regression.py:1378:20: "paramter" is a misspelling of "parameter"
python/pyspark/mllib/stat/_statistics.py:262:8: "probabilty" is a misspelling of "probability"
python/pyspark/rdd.py:1363:32: "paramter" is a misspelling of "parameter"
python/pyspark/streaming/tests.py:825:42: "retuns" is a misspelling of "returns"
python/pyspark/sql/tests.py:768:29: "initalization" is a misspelling of "initialization"
python/pyspark/sql/tests.py:3616:31: "initalize" is a misspelling of "initialize"
resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala:120:39: "arbitary" is a misspelling of "arbitrary"
resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala:26:45: "sucessfully" is a misspelling of "successfully"
resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala:358:27: "constaints" is a misspelling of "constraints"
resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala:111:24: "senstive" is a misspelling of "sensitive"
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala:1063:5: "overwirte" is a misspelling of "overwrite"
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala:1348:17: "compatability" is a misspelling of "compatibility"
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala:77:36: "paramter" is a misspelling of "parameter"
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:1374:22: "precendence" is a misspelling of "precedence"
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala:238:27: "unnecassary" is a misspelling of "unnecessary"
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala:212:17: "whn" is a misspelling of "when"
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala:147:60: "timestmap" is a misspelling of "timestamp"
sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala:150:45: "precentage" is a misspelling of "percentage"
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala:135:29: "infered" is a misspelling of "inferred"
sql/hive/src/test/resources/golden/udf_instr-1-2e76f819563dbaba4beb51e3a130b922:1:52: "occurance" is a misspelling of "occurrence"
sql/hive/src/test/resources/golden/udf_instr-2-32da357fc754badd6e3898dcc8989182:1:52: "occurance" is a misspelling of "occurrence"
sql/hive/src/test/resources/golden/udf_locate-1-6e41693c9c6dceea4d7fab4c02884e4e:1:63: "occurance" is a misspelling of "occurrence"
sql/hive/src/test/resources/golden/udf_locate-2-d9b5934457931447874d6bb7c13de478:1:63: "occurance" is a misspelling of "occurrence"
sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8:9:79: "occurence" is a misspelling of "occurrence"
sql/hive/src/test/resources/golden/udf_translate-2-f7aa38a33ca0df73b7a1e6b6da4b7fe8:13:110: "occurence" is a misspelling of "occurrence"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/annotate_stats_join.q:46:105: "distint" is a misspelling of "distinct"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/auto_sortmerge_join_11.q:29:3: "Currenly" is a misspelling of "Currently"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/avro_partitioned.q:72:15: "existant" is a misspelling of "existent"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/decimal_udf.q:25:3: "substraction" is a misspelling of "subtraction"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby2_map_multi_distinct.q:16:51: "funtion" is a misspelling of "function"
sql/hive/src/test/resources/ql/src/test/queries/clientpositive/groupby_sort_8.q:15:30: "issueing" is a misspelling of "issuing"
sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala:669:52: "wiht" is a misspelling of "with"
sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java:474:9: "Refering" is a misspelling of "Referring"
```

### after

```
$ misspell . | grep -v '.js'
common/network-common/src/main/java/org/apache/spark/network/util/AbstractFileRegion.java:27:20: "transfered" is a misspelling of "transferred"
core/src/main/scala/org/apache/spark/status/storeTypes.scala:113:29: "ect" is a misspelling of "etc"
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:1922:49: "agriculteur" is a misspelling of "agriculture"
data/streaming/AFINN-111.txt:1219:0: "humerous" is a misspelling of "humorous"
licenses/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching"
licenses/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics"
licenses/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING"
licenses/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING"
licenses/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses-binary/LICENSE-heapq.txt:5:63: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:6:2: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:262:29: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:262:39: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:269:49: "Stichting" is a misspelling of "Stitching"
licenses-binary/LICENSE-heapq.txt:269:59: "Mathematisch" is a misspelling of "Mathematics"
licenses-binary/LICENSE-heapq.txt:274:2: "STICHTING" is a misspelling of "STITCHING"
licenses-binary/LICENSE-heapq.txt:274:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
licenses-binary/LICENSE-heapq.txt:276:29: "STICHTING" is a misspelling of "STITCHING"
licenses-binary/LICENSE-heapq.txt:276:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt:170:0: "teh" is a misspelling of "the"
mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt:53:0: "eles" is a misspelling of "eels"
mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:99:20: "Euclidian" is a misspelling of "Euclidean"
mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala:539:11: "Euclidian" is a misspelling of "Euclidean"
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala:77:36: "Teh" is a misspelling of "The"
mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala:276:9: "Euclidian" is a misspelling of "Euclidean"
python/pyspark/heapq3.py:6:63: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:7:2: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:263:29: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:263:39: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:270:49: "Stichting" is a misspelling of "Stitching"
python/pyspark/heapq3.py:270:59: "Mathematisch" is a misspelling of "Mathematics"
python/pyspark/heapq3.py:275:2: "STICHTING" is a misspelling of "STITCHING"
python/pyspark/heapq3.py:275:12: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
python/pyspark/heapq3.py:277:29: "STICHTING" is a misspelling of "STITCHING"
python/pyspark/heapq3.py:277:39: "MATHEMATISCH" is a misspelling of "MATHEMATICS"
python/pyspark/ml/stat.py:339:23: "Euclidian" is a misspelling of "Euclidean"
```

Closes #22070 from seratch/fix-typo.

Authored-by: Kazuhiro Sera <seratch@gmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import warnings

from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.sql import DataFrame


__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
           'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
           'GBTRegressor', 'GBTRegressionModel',
           'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
           'GeneralizedLinearRegressionSummary', 'GeneralizedLinearRegressionTrainingSummary',
           'IsotonicRegression', 'IsotonicRegressionModel',
           'LinearRegression', 'LinearRegressionModel',
           'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
           'RandomForestRegressor', 'RandomForestRegressionModel']


@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                       HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
                       HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth, HasLoss,
                       JavaMLWritable, JavaMLReadable):
    """
    Linear regression.

    The learning objective is to minimize the specified loss function, with regularization.
    This supports two kinds of loss:

    * squaredError (a.k.a. squared loss)
    * huber (a hybrid of squared error for relatively small errors and absolute error for \
      relatively large ones, and we estimate the scale parameter from training data)

    This supports multiple types of regularization:

    * none (a.k.a. ordinary least squares)
    * L2 (ridge regression)
    * L1 (Lasso)
    * L2 + L1 (elastic net)

    Note: Fitting with huber loss only supports none and L2 regularization.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, 2.0, Vectors.dense(1.0)),
    ...     (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight")
    >>> model = lr.fit(df)
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
    True
    >>> abs(model.coefficients[0] - 1.0) < 0.001
    True
    >>> abs(model.intercept - 0.0) < 0.001
    True
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001
    True
    >>> lr.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.
    >>> lr_path = temp_path + "/lr"
    >>> lr.save(lr_path)
    >>> lr2 = LinearRegression.load(lr_path)
    >>> lr2.getMaxIter()
    5
    >>> model_path = temp_path + "/lr_model"
    >>> model.save(model_path)
    >>> model2 = LinearRegressionModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.numFeatures
    1
    >>> model.write().format("pmml").save(model_path + "_2")

    .. versionadded:: 1.4.0
    """

    solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                   "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)

    loss = Param(Params._dummy(), "loss", "The loss function to be optimized. Supported " +
                 "options: squaredError, huber.", typeConverter=TypeConverters.toString)

    epsilon = Param(Params._dummy(), "epsilon", "The shape parameter to control the amount of " +
                    "robustness. Must be > 1.0. Only valid when loss is huber",
                    typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                 loss="squaredError", epsilon=1.35):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                 loss="squaredError", epsilon=1.35)
        """
        super(LinearRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.LinearRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
                  loss="squaredError", epsilon=1.35):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
                  loss="squaredError", epsilon=1.35)
        Sets params for linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return LinearRegressionModel(java_model)

    @since("2.3.0")
    def setEpsilon(self, value):
        """
        Sets the value of :py:attr:`epsilon`.
        """
        return self._set(epsilon=value)

    @since("2.3.0")
    def getEpsilon(self):
        """
        Gets the value of epsilon or its default value.
        """
        return self.getOrDefault(self.epsilon)

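# A hedged usage sketch (not part of the original source): assumes an active
# SparkSession `spark` and a DataFrame `train_df` with "features"/"label" columns.
# Fitting with the huber loss tolerates label outliers better than squaredError:
#
#     hlr = LinearRegression(loss="huber", epsilon=1.35)
#     hmodel = hlr.fit(train_df)
#     print(hmodel.coefficients, hmodel.intercept, hmodel.scale)
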
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
    """
    Model fitted by :class:`LinearRegression`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("2.3.0")
    def scale(self):
        r"""
        The value by which \|y - X'w\| is scaled down when loss is "huber", otherwise 1.0.
        """
        return self._call_java("scale")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. residuals, mse, r-squared) of model on
        training set. An exception is thrown if
        `trainingSummary is None`.
        """
        if self.hasSummary:
            java_lrt_summary = self._call_java("summary")
            return LinearRegressionTrainingSummary(java_lrt_summary)
        else:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)

    @property
    @since("2.0.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        return self._call_java("hasSummary")

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
          Test dataset to evaluate model on, where dataset is an
          instance of :py:class:`pyspark.sql.DataFrame`
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        java_lr_summary = self._call_java("evaluate", dataset)
        return LinearRegressionSummary(java_lr_summary)

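# Hedged sketch (illustrative, not from the original source): evaluating a fitted
# model on held-out data, assuming `model` is a LinearRegressionModel and `test_df`
# shares the training schema:
#
#     test_summary = model.evaluate(test_df)
#     print(test_summary.rootMeanSquaredError, test_summary.r2)
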
class LinearRegressionSummary(JavaWrapper):
    """
    .. note:: Experimental

    Linear regression results evaluated on a dataset.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        Dataframe outputted by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self):
        """
        Field in "predictions" which gives the predicted value of
        the label at each instance.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.0.0")
    def labelCol(self):
        """
        Field in "predictions" which gives the true label of each
        instance.
        """
        return self._call_java("labelCol")

    @property
    @since("2.0.0")
    def featuresCol(self):
        """
        Field in "predictions" which gives the features of each instance
        as a vector.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.0.0")
    def explainedVariance(self):
        r"""
        Returns the explained variance regression score.
        explainedVariance = 1 - variance(y - \hat{y}) / variance(y)

        .. seealso:: `Wikipedia explain variation \
        <http://en.wikipedia.org/wiki/Explained_variation>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("explainedVariance")

    @property
    @since("2.0.0")
    def meanAbsoluteError(self):
        """
        Returns the mean absolute error, which is a risk function
        corresponding to the expected value of the absolute error
        loss or l1-norm loss.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("meanAbsoluteError")

    @property
    @since("2.0.0")
    def meanSquaredError(self):
        """
        Returns the mean squared error, which is a risk function
        corresponding to the expected value of the squared error
        loss or quadratic loss.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("meanSquaredError")

    @property
    @since("2.0.0")
    def rootMeanSquaredError(self):
        """
        Returns the root mean squared error, which is defined as the
        square root of the mean squared error.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("rootMeanSquaredError")

    @property
    @since("2.0.0")
    def r2(self):
        """
        Returns R^2, the coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination \
        <http://en.wikipedia.org/wiki/Coefficient_of_determination>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("r2")

    @property
    @since("2.4.0")
    def r2adj(self):
        """
        Returns Adjusted R^2, the adjusted coefficient of determination.

        .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \
        <https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LinearRegression.weightCol`. This will change in later Spark versions.
        """
        return self._call_java("r2adj")

    @property
    @since("2.0.0")
    def residuals(self):
        """
        Residuals (label - predicted value)
        """
        return self._call_java("residuals")

    @property
    @since("2.0.0")
    def numInstances(self):
        """
        Number of instances in DataFrame predictions
        """
        return self._call_java("numInstances")

    @property
    @since("2.2.0")
    def degreesOfFreedom(self):
        """
        Degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def devianceResiduals(self):
        """
        The weighted residuals, the usual residuals rescaled by the
        square root of the instance weights.
        """
        return self._call_java("devianceResiduals")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self):
        """
        Standard error of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self):
        """
        T-statistic of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self):
        """
        Two-sided p-value of estimated coefficients and intercept.
        This value is only available when using the "normal" solver.

        If :py:attr:`LinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("pValues")

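# Hedged note (editorial, not from the original source): adjusted R^2 penalizes
# model size; for n instances and k regressors with an intercept it is
# 1 - (1 - R^2) * (n - 1) / (n - k - 1), so with the "normal" solver it can be
# sanity-checked from `r2`, `numInstances`, and `degreesOfFreedom`.
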
@inherit_doc
class LinearRegressionTrainingSummary(LinearRegressionSummary):
    """
    .. note:: Experimental

    Linear regression training results. Currently, the training summary ignores the
    training weights except for the objective trace.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def objectiveHistory(self):
        """
        Objective function (scaled loss + regularization) at each
        iteration.
        This value is only available when using the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("objectiveHistory")

    @property
    @since("2.0.0")
    def totalIterations(self):
        """
        Number of training iterations until termination.
        This value is only available when using the "l-bfgs" solver.

        .. seealso:: :py:attr:`LinearRegression.solver`
        """
        return self._call_java("totalIterations")


@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                         HasWeightCol, JavaMLWritable, JavaMLReadable):
    """
    Currently implemented using parallelized pool adjacent violators algorithm.
    Only univariate (single feature) algorithm supported.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> ir = IsotonicRegression()
    >>> model = ir.fit(df)
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.boundaries
    DenseVector([0.0, 1.0])
    >>> ir_path = temp_path + "/ir"
    >>> ir.save(ir_path)
    >>> ir2 = IsotonicRegression.load(ir_path)
    >>> ir2.getIsotonic()
    True
    >>> model_path = temp_path + "/ir_model"
    >>> model.save(model_path)
    >>> model2 = IsotonicRegressionModel.load(model_path)
    >>> model.boundaries == model2.boundaries
    True
    >>> model.predictions == model2.predictions
    True

    .. versionadded:: 1.6.0
    """

    isotonic = \
        Param(Params._dummy(), "isotonic",
              "whether the output sequence should be isotonic/increasing (true) or " +
              "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean)
    featureIndex = \
        Param(Params._dummy(), "featureIndex",
              "The index of the feature if featuresCol is a vector column, no effect otherwise.",
              typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0)
        """
        super(IsotonicRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.IsotonicRegression", self.uid)
        self._setDefault(isotonic=True, featureIndex=0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  weightCol=None, isotonic=True, featureIndex=0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  weightCol=None, isotonic=True, featureIndex=0)
        Set the params for IsotonicRegression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return IsotonicRegressionModel(java_model)

    def setIsotonic(self, value):
        """
        Sets the value of :py:attr:`isotonic`.
        """
        return self._set(isotonic=value)

    def getIsotonic(self):
        """
        Gets the value of isotonic or its default value.
        """
        return self.getOrDefault(self.isotonic)

    def setFeatureIndex(self, value):
        """
        Sets the value of :py:attr:`featureIndex`.
        """
        return self._set(featureIndex=value)

    def getFeatureIndex(self):
        """
        Gets the value of featureIndex or its default value.
        """
        return self.getOrDefault(self.featureIndex)

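# Hedged sketch (illustrative, not from the original source): fitting a decreasing
# (antitonic) curve on one component of a multi-dimensional "features" vector,
# assuming `spark` and a DataFrame `df2`:
#
#     ir = IsotonicRegression(isotonic=False, featureIndex=1)
#     ir_model = ir.fit(df2)
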
class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by :class:`IsotonicRegression`.

    .. versionadded:: 1.6.0
    """

    @property
    @since("1.6.0")
    def boundaries(self):
        """
        Boundaries in increasing order for which predictions are known.
        """
        return self._call_java("boundaries")

    @property
    @since("1.6.0")
    def predictions(self):
        """
        Predictions associated with the boundaries at the same index, monotone because of isotonic
        regression.
        """
        return self._call_java("predictions")


class TreeEnsembleParams(DecisionTreeParams):
    """
    Mixin for Decision Tree-based ensemble algorithms parameters.
    """

    subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
                            "used for learning each decision tree, in range (0, 1].",
                            typeConverter=TypeConverters.toFloat)

    supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]

    featureSubsetStrategy = \
        Param(Params._dummy(), "featureSubsetStrategy",
              "The number of features to consider for splits at each tree node. Supported " +
              "options: 'auto' (choose automatically for task: If numTrees == 1, set to " +
              "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " +
              "'onethird' for regression), 'all' (use all features), 'onethird' (use " +
              "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " +
              "log2(number of features)), 'n' (when n is in the range (0, 1.0], use " +
              "n * number of features. When n is in the range (1, number of features), use" +
              " n features). default = 'auto'", typeConverter=TypeConverters.toString)

    def __init__(self):
        super(TreeEnsembleParams, self).__init__()

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("1.4.0")
    def getSubsamplingRate(self):
        """
        Gets the value of subsamplingRate or its default value.
        """
        return self.getOrDefault(self.subsamplingRate)

    @since("1.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.

        .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0.
        """
        return self._set(featureSubsetStrategy=value)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self):
        """
        Gets the value of featureSubsetStrategy or its default value.
        """
        return self.getOrDefault(self.featureSubsetStrategy)

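# Hedged note (editorial): besides the named options above, featureSubsetStrategy
# accepts a numeric string, e.g. "0.5" (half the features per split) or "10"
# (exactly 10 features); smaller subsets decorrelate the trees in an ensemble at
# some cost in per-tree accuracy.
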
class TreeRegressorParams(Params):
    """
    Private class to track supported impurity measures.
    """

    supportedImpurities = ["variance"]
    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(TreeRegressorParams, self).__init__()

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)


class RandomForestParams(TreeEnsembleParams):
    """
    Private class to track supported random forest parameters.
    """

    numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
                     typeConverter=TypeConverters.toInt)

    def __init__(self):
        super(RandomForestParams, self).__init__()

    @since("1.4.0")
    def setNumTrees(self, value):
        """
        Sets the value of :py:attr:`numTrees`.
        """
        return self._set(numTrees=value)

    @since("1.4.0")
    def getNumTrees(self):
        """
        Gets the value of numTrees or its default value.
        """
        return self.getOrDefault(self.numTrees)


class GBTParams(TreeEnsembleParams):
    """
    Private class to track supported GBT params.
    """
    supportedLossTypes = ["squared", "absolute"]


@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
                            HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
    >>> model = dt.fit(df)
    >>> model.depth
    1
    >>> model.numNodes
    3
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtr_path = temp_path + "/dtr"
    >>> dt.save(dtr_path)
    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtr_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeRegressionModel.load(model_path)
    >>> model.numNodes == model2.numNodes
    True
    >>> model.depth == model2.depth
    True
    >>> model.transform(test1).head().variance
    0.0

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
                 seed=None, varianceCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", seed=None, varianceCol=None)
        """
        super(DecisionTreeRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", seed=None, varianceCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", seed=None, varianceCol=None)
        Sets params for the DecisionTreeRegressor.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return DecisionTreeRegressionModel(java_model)


@inherit_doc
class DecisionTreeModel(JavaModel, JavaPredictionModel):
    """
    Abstraction for Decision Tree models.

    .. versionadded:: 1.5.0
    """

    @property
    @since("1.5.0")
    def numNodes(self):
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property
    @since("1.5.0")
    def depth(self):
        """Return depth of the decision tree."""
        return self._call_java("depth")

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full description of model."""
        return self._call_java("toDebugString")

    def __repr__(self):
        return self._call_java("toString")


@inherit_doc
class TreeEnsembleModel(JavaModel):
    """
    (private abstraction)

    Represents a tree ensemble model.
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]

    @property
    @since("2.0.0")
    def getNumTrees(self):
        """Number of trees in ensemble."""
        return self._call_java("getNumTrees")

    @property
    @since("1.5.0")
    def treeWeights(self):
        """Return the weights for each tree"""
        return list(self._call_java("javaTreeWeights"))

    @property
    @since("2.0.0")
    def totalNumNodes(self):
        """Total number of nodes, summed over all trees in the ensemble."""
        return self._call_java("totalNumNodes")

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full description of model."""
        return self._call_java("toDebugString")

    def __repr__(self):
        return self._call_java("toString")


@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by :class:`DecisionTreeRegressor`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
            where gain is scaled by the number of instances passing through node
          - Normalize importances for tree to sum to 1.

        .. note:: Feature importance for single decision trees can have high variance due to
            correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
            to determine feature importance instead.
        """
        return self._call_java("featureImportances")

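# Hedged reading sketch (illustrative, not from the original source):
# featureImportances is an ML Vector indexed by feature position, so the most
# informative feature of a fitted `model` can be recovered with:
#
#     imp = model.featureImportances
#     best = max(range(len(imp)), key=lambda j: imp[j])
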
@inherit_doc
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
                            RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
                            JavaMLWritable, JavaMLReadable):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
    >>> model = rf.fit(df)
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> allclose(model.treeWeights, [1.0, 1.0])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.numFeatures
    1
    >>> model.trees
    [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
    >>> model.getNumTrees
    2
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    0.5
    >>> rfr_path = temp_path + "/rfr"
    >>> rf.save(rfr_path)
    >>> rf2 = RandomForestRegressor.load(rfr_path)
    >>> rf2.getNumTrees()
    2
    >>> model_path = temp_path + "/rfr_model"
    >>> model.save(model_path)
    >>> model2 = RandomForestRegressionModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                 featureSubsetStrategy="auto"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                 featureSubsetStrategy="auto")
        """
        super(RandomForestRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", subsamplingRate=1.0, numTrees=20,
                         featureSubsetStrategy="auto")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                  featureSubsetStrategy="auto"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                  featureSubsetStrategy="auto")
        Sets params for random forest regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return RandomForestRegressionModel(java_model)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)

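# Hedged sketch (illustrative, not from the original source): a typical forest
# configuration, assuming a training DataFrame `train_df`; row subsampling plus
# per-split feature sampling gives each tree a different view of the data:
#
#     rf = RandomForestRegressor(numTrees=100, subsamplingRate=0.8,
#                                featureSubsetStrategy="onethird", seed=7)
#     rf_model = rf.fit(train_df)
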
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
|
JavaMLReadable):
|
|
"""
|
|
Model fitted by :class:`RandomForestRegressor`.
|
|
|
|
.. versionadded:: 1.4.0
|
|
"""
|
|
|
|
@property
|
|
@since("2.0.0")
|
|
def trees(self):
|
|
"""Trees in this ensemble. Warning: These have null parent Estimators."""
|
|
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
|
|
|
|
@property
|
|
@since("2.0.0")
|
|
def featureImportances(self):
|
|
"""
|
|
Estimate of the importance of each feature.
|
|
|
|
Each feature's importance is the average of its importance across all trees in the ensemble
|
|
The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
|
|
(Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
|
|
and follows the implementation from scikit-learn.
|
|
|
|
.. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
|
|
"""
|
|
return self._call_java("featureImportances")
|
|
|
|
|
|
@inherit_doc
|
|
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
|
GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
|
|
JavaMLReadable, TreeRegressorParams):
|
|
"""
|
|
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
|
learning algorithm for regression.
|
|
It supports both continuous and categorical features.
|
|
|
|
>>> from numpy import allclose
|
|
>>> from pyspark.ml.linalg import Vectors
|
|
>>> df = spark.createDataFrame([
|
|
... (1.0, Vectors.dense(1.0)),
|
|
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
|
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
|
|
>>> print(gbt.getImpurity())
|
|
variance
|
|
>>> print(gbt.getFeatureSubsetStrategy())
|
|
all
|
|
>>> model = gbt.fit(df)
|
|
>>> model.featureImportances
|
|
SparseVector(1, {0: 1.0})
|
|
>>> model.numFeatures
|
|
1
|
|
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
|
|
True
|
|
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
|
>>> model.transform(test0).head().prediction
|
|
0.0
|
|
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
|
>>> model.transform(test1).head().prediction
|
|
1.0
|
|
>>> gbtr_path = temp_path + "gbtr"
|
|
>>> gbt.save(gbtr_path)
|
|
>>> gbt2 = GBTRegressor.load(gbtr_path)
|
|
>>> gbt2.getMaxDepth()
|
|
2
|
|
>>> model_path = temp_path + "gbtr_model"
|
|
>>> model.save(model_path)
|
|
>>> model2 = GBTRegressionModel.load(model_path)
|
|
>>> model.featureImportances == model2.featureImportances
|
|
True
|
|
>>> model.treeWeights == model2.treeWeights
|
|
True
|
|
>>> model.trees
|
|
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
|
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
|
|
... ["label", "features"])
|
|
>>> model.evaluateEachIteration(validation, "squared")
|
|
[0.0, 0.0, 0.0, 0.0, 0.0]
|
|
|
|
.. versionadded:: 1.4.0
|
|
"""
|
|
|
|
lossType = Param(Params._dummy(), "lossType",
|
|
"Loss function which GBT tries to minimize (case-insensitive). " +
|
|
"Supported options: " + ", ".join(GBTParams.supportedLossTypes),
|
|
typeConverter=TypeConverters.toString)
|
|
|
|
stepSize = Param(Params._dummy(), "stepSize",
|
|
"Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
|
|
"the contribution of each estimator.",
|
|
typeConverter=TypeConverters.toFloat)
|
|
|
|
@keyword_only
|
|
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
|
impurity="variance", featureSubsetStrategy="all"):
|
|
"""
|
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
|
impurity="variance", featureSubsetStrategy="all")
|
|
"""
|
|
super(GBTRegressor, self).__init__()
|
|
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
|
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
|
|
impurity="variance", featureSubsetStrategy="all")
|
|
kwargs = self._input_kwargs
|
|
self.setParams(**kwargs)
|
|
|
|
@keyword_only
|
|
@since("1.4.0")
|
|
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
|
impuriy="variance", featureSubsetStrategy="all"):
|
|
"""
|
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
|
impurity="variance", featureSubsetStrategy="all")
|
|
Sets params for Gradient Boosted Tree Regression.
|
|
"""
|
|
kwargs = self._input_kwargs
|
|
return self._set(**kwargs)
|
|
|
|
def _create_model(self, java_model):
|
|
return GBTRegressionModel(java_model)
|
|
|
|
@since("1.4.0")
|
|
def setLossType(self, value):
|
|
"""
|
|
Sets the value of :py:attr:`lossType`.
|
|
"""
|
|
return self._set(lossType=value)
|
|
|
|
@since("1.4.0")
|
|
def getLossType(self):
|
|
"""
|
|
Gets the value of lossType or its default value.
|
|
"""
|
|
return self.getOrDefault(self.lossType)
|
|
|
|
@since("2.4.0")
|
|
def setFeatureSubsetStrategy(self, value):
|
|
"""
|
|
Sets the value of :py:attr:`featureSubsetStrategy`.
|
|
"""
|
|
return self._set(featureSubsetStrategy=value)


class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by :class:`GBTRegressor`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the
        ensemble. The importance vector is normalized to sum to 1. This method is suggested by
        Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning,
        2nd Edition." 2001.) and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
        """
        return self._call_java("featureImportances")

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
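
    # Illustrative inspection of the ensemble (editor's sketch; assumes a fitted
    # `model`): each element is a DecisionTreeRegressionModel, so per-tree
    # attributes such as depth and numNodes can be examined individually.
    #
    #     depths = [t.depth for t in model.trees]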

    @since("2.4.0")
    def evaluateEachIteration(self, dataset, loss):
        """
        Method to compute error or loss for every iteration of gradient boosting.

        :param dataset:
            Test dataset to evaluate model on, where dataset is an
            instance of :py:class:`pyspark.sql.DataFrame`
        :param loss:
            The loss function used to compute error.
            Supported options: squared, absolute
        """
        return self._call_java("evaluateEachIteration", dataset, loss)


@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth,
                            JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Accelerated Failure Time (AFT) Model Survival Regression

    Fit a parametric AFT survival regression model based on the Weibull distribution
    of the survival time.

    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0), 1.0),
    ...     (1e-40, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    >>> aftsr = AFTSurvivalRegression()
    >>> model = aftsr.fit(df)
    >>> model.predict(Vectors.dense(6.3))
    1.0
    >>> model.predictQuantiles(Vectors.dense(6.3))
    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
    >>> model.transform(df).show()
    +-------+---------+------+----------+
    |  label| features|censor|prediction|
    +-------+---------+------+----------+
    |    1.0|    [1.0]|   1.0|       1.0|
    |1.0E-40|(1,[],[])|   0.0|       1.0|
    +-------+---------+------+----------+
    ...
    >>> aftsr_path = temp_path + "/aftsr"
    >>> aftsr.save(aftsr_path)
    >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path)
    >>> aftsr2.getMaxIter()
    100
    >>> model_path = temp_path + "/aftsr_model"
    >>> model.save(model_path)
    >>> model2 = AFTSurvivalRegressionModel.load(model_path)
    >>> model.coefficients == model2.coefficients
    True
    >>> model.intercept == model2.intercept
    True
    >>> model.scale == model2.scale
    True

    .. versionadded:: 1.6.0
    """

    censorCol = Param(Params._dummy(), "censorCol",
                      "censor column name. The value of this column could be 0 or 1. " +
                      "If the value is 1, it means the event has occurred i.e. " +
                      "uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
    quantileProbabilities = \
        Param(Params._dummy(), "quantileProbabilities",
              "quantile probabilities array. Values of the quantile probabilities array " +
              "should be in the range (0, 1) and the array should be non-empty.",
              typeConverter=TypeConverters.toListFloat)
    quantilesCol = Param(Params._dummy(), "quantilesCol",
                         "quantiles column name. This column will output quantiles of " +
                         "corresponding quantileProbabilities if it is set.",
                         typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                 quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
                 quantilesCol=None, aggregationDepth=2):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                 quantilesCol=None, aggregationDepth=2)
        """
        super(AFTSurvivalRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
        self._setDefault(censorCol="censor",
                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
                         maxIter=100, tol=1E-6)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
                  quantilesCol=None, aggregationDepth=2):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                  quantilesCol=None, aggregationDepth=2)
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return AFTSurvivalRegressionModel(java_model)

    @since("1.6.0")
    def setCensorCol(self, value):
        """
        Sets the value of :py:attr:`censorCol`.
        """
        return self._set(censorCol=value)

    @since("1.6.0")
    def getCensorCol(self):
        """
        Gets the value of censorCol or its default value.
        """
        return self.getOrDefault(self.censorCol)

    @since("1.6.0")
    def setQuantileProbabilities(self, value):
        """
        Sets the value of :py:attr:`quantileProbabilities`.
        """
        return self._set(quantileProbabilities=value)

    @since("1.6.0")
    def getQuantileProbabilities(self):
        """
        Gets the value of quantileProbabilities or its default value.
        """
        return self.getOrDefault(self.quantileProbabilities)

    @since("1.6.0")
    def setQuantilesCol(self, value):
        """
        Sets the value of :py:attr:`quantilesCol`.
        """
        return self._set(quantilesCol=value)

    @since("1.6.0")
    def getQuantilesCol(self):
        """
        Gets the value of quantilesCol or its default value.
        """
        return self.getOrDefault(self.quantilesCol)
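
    # Illustrative configuration (editor's sketch; `df` is a hypothetical training
    # DataFrame with a 0/1 "censor" column): setting quantilesCol makes transform()
    # append a vector of predicted survival times at the requested probabilities.
    #
    #     aft = AFTSurvivalRegression(quantileProbabilities=[0.25, 0.5, 0.75],
    #                                 quantilesCol="quantiles")
    #     model = aft.fit(df)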


class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Model fitted by :class:`AFTSurvivalRegression`.

    .. versionadded:: 1.6.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("1.6.0")
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("1.6.0")
    def scale(self):
        """
        Model scale parameter.
        """
        return self._call_java("scale")

    @since("2.0.0")
    def predictQuantiles(self, features):
        """
        Predicted quantiles.
        """
        return self._call_java("predictQuantiles", features)

    @since("2.0.0")
    def predict(self, features):
        """
        Predicted value.
        """
        return self._call_java("predict", features)


@inherit_doc
class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
                                  HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
                                  HasSolver, JavaMLWritable, JavaMLReadable):
    """
    .. note:: Experimental

    Generalized Linear Regression.

    Fit a Generalized Linear Model specified by giving a symbolic description of the linear
    predictor (link function) and a description of the error distribution (family). It supports
    "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
    each family are listed below. The first link function of each family is the default one.

    * "gaussian" -> "identity", "log", "inverse"

    * "binomial" -> "logit", "probit", "cloglog"

    * "poisson" -> "log", "identity", "sqrt"

    * "gamma" -> "inverse", "identity", "log"

    * "tweedie" -> power link function specified through "linkPower". \
        The default link power in the tweedie family is 1 - variancePower.

    .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(0.0, 0.0)),
    ...     (1.0, Vectors.dense(1.0, 2.0)),
    ...     (2.0, Vectors.dense(0.0, 0.0)),
    ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
    >>> model = glr.fit(df)
    >>> transformed = model.transform(df)
    >>> abs(transformed.head().prediction - 1.5) < 0.001
    True
    >>> abs(transformed.head().p - 1.5) < 0.001
    True
    >>> model.coefficients
    DenseVector([1.5..., -1.0...])
    >>> model.numFeatures
    2
    >>> abs(model.intercept - 1.5) < 0.001
    True
    >>> glr_path = temp_path + "/glr"
    >>> glr.save(glr_path)
    >>> glr2 = GeneralizedLinearRegression.load(glr_path)
    >>> glr.getFamily() == glr2.getFamily()
    True
    >>> model_path = temp_path + "/glr_model"
    >>> model.save(model_path)
    >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
    >>> model.intercept == model2.intercept
    True
    >>> model.coefficients[0] == model2.coefficients[0]
    True

    .. versionadded:: 2.0.0
    """

    family = Param(Params._dummy(), "family", "The name of family which is a description of " +
                   "the error distribution to be used in the model. Supported options: " +
                   "gaussian (default), binomial, poisson, gamma and tweedie.",
                   typeConverter=TypeConverters.toString)
    link = Param(Params._dummy(), "link", "The name of link function which provides the " +
                 "relationship between the linear predictor and the mean of the distribution " +
                 "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
                 "and sqrt.", typeConverter=TypeConverters.toString)
    linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
                              "predictor) column name", typeConverter=TypeConverters.toString)
    variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
                          "of the Tweedie distribution which characterizes the relationship " +
                          "between the variance and mean of the distribution. Only applicable " +
                          "for the Tweedie family. Supported values: 0 and [1, Inf).",
                          typeConverter=TypeConverters.toFloat)
    linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
                      "Only applicable to the Tweedie family.",
                      typeConverter=TypeConverters.toFloat)
    solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                   "options: irls.", typeConverter=TypeConverters.toString)
    offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " +
                      "or empty, we treat all instance offsets as 0.0",
                      typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
                 variancePower=0.0, linkPower=None, offsetCol=None):
        """
        __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
                 variancePower=0.0, linkPower=None, offsetCol=None)
        """
        super(GeneralizedLinearRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
                         variancePower=0.0)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.0.0")
    def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
                  variancePower=0.0, linkPower=None, offsetCol=None):
        """
        setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
                  variancePower=0.0, linkPower=None, offsetCol=None)
        Sets params for generalized linear regression.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        return GeneralizedLinearRegressionModel(java_model)

    @since("2.0.0")
    def setFamily(self, value):
        """
        Sets the value of :py:attr:`family`.
        """
        return self._set(family=value)

    @since("2.0.0")
    def getFamily(self):
        """
        Gets the value of family or its default value.
        """
        return self.getOrDefault(self.family)

    @since("2.0.0")
    def setLinkPredictionCol(self, value):
        """
        Sets the value of :py:attr:`linkPredictionCol`.
        """
        return self._set(linkPredictionCol=value)

    @since("2.0.0")
    def getLinkPredictionCol(self):
        """
        Gets the value of linkPredictionCol or its default value.
        """
        return self.getOrDefault(self.linkPredictionCol)

    @since("2.0.0")
    def setLink(self, value):
        """
        Sets the value of :py:attr:`link`.
        """
        return self._set(link=value)

    @since("2.0.0")
    def getLink(self):
        """
        Gets the value of link or its default value.
        """
        return self.getOrDefault(self.link)

    @since("2.2.0")
    def setVariancePower(self, value):
        """
        Sets the value of :py:attr:`variancePower`.
        """
        return self._set(variancePower=value)

    @since("2.2.0")
    def getVariancePower(self):
        """
        Gets the value of variancePower or its default value.
        """
        return self.getOrDefault(self.variancePower)

    @since("2.2.0")
    def setLinkPower(self, value):
        """
        Sets the value of :py:attr:`linkPower`.
        """
        return self._set(linkPower=value)

    @since("2.2.0")
    def getLinkPower(self):
        """
        Gets the value of linkPower or its default value.
        """
        return self.getOrDefault(self.linkPower)

    @since("2.3.0")
    def setOffsetCol(self, value):
        """
        Sets the value of :py:attr:`offsetCol`.
        """
        return self._set(offsetCol=value)

    @since("2.3.0")
    def getOffsetCol(self):
        """
        Gets the value of offsetCol or its default value.
        """
        return self.getOrDefault(self.offsetCol)
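
    # Illustrative Tweedie fit (editor's sketch; `df` is a hypothetical training
    # DataFrame): for family="tweedie" the link is chosen via linkPower instead of
    # link, and it defaults to 1 - variancePower when linkPower is unset.
    #
    #     glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.5,
    #                                       linkPower=-0.5)
    #     model = glr.fit(df)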


class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
                                       JavaMLReadable):
    """
    .. note:: Experimental

    Model fitted by :class:`GeneralizedLinearRegression`.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("2.0.0")
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. residuals, deviance, pValues) of model on
        training set. An exception is thrown if
        `trainingSummary is None`.
        """
        if self.hasSummary:
            java_glrt_summary = self._call_java("summary")
            return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary)
        else:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)

    @property
    @since("2.0.0")
    def hasSummary(self):
        """
        Indicates whether a training summary exists for this model
        instance.
        """
        return self._call_java("hasSummary")

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
            Test dataset to evaluate model on, where dataset is an
            instance of :py:class:`pyspark.sql.DataFrame`
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        java_glr_summary = self._call_java("evaluate", dataset)
        return GeneralizedLinearRegressionSummary(java_glr_summary)
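
    # Illustrative held-out evaluation (editor's sketch; assumes a fitted `model`
    # and a DataFrame `test_df`): the returned summary exposes fit statistics
    # such as deviance and aic computed on the test data.
    #
    #     summary = model.evaluate(test_df)
    #     test_deviance = summary.deviance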


class GeneralizedLinearRegressionSummary(JavaWrapper):
    """
    .. note:: Experimental

    Generalized linear regression results evaluated on a dataset.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def predictions(self):
        """
        Predictions output by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.0.0")
    def predictionCol(self):
        """
        Field in :py:attr:`predictions` which gives the predicted value of each instance.
        This is set to a new column name if the original model's `predictionCol` is not set.
        """
        return self._call_java("predictionCol")

    @property
    @since("2.2.0")
    def numInstances(self):
        """
        Number of instances in DataFrame predictions.
        """
        return self._call_java("numInstances")

    @property
    @since("2.0.0")
    def rank(self):
        """
        The numeric rank of the fitted linear model.
        """
        return self._call_java("rank")

    @property
    @since("2.0.0")
    def degreesOfFreedom(self):
        """
        Degrees of freedom.
        """
        return self._call_java("degreesOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedom(self):
        """
        The residual degrees of freedom.
        """
        return self._call_java("residualDegreeOfFreedom")

    @property
    @since("2.0.0")
    def residualDegreeOfFreedomNull(self):
        """
        The residual degrees of freedom for the null model.
        """
        return self._call_java("residualDegreeOfFreedomNull")

    @since("2.0.0")
    def residuals(self, residualsType="deviance"):
        """
        Get the residuals of the fitted model by type.

        :param residualsType: The type of residuals which should be returned.
                              Supported options: deviance (default), pearson, working,
                              and response.
        """
        return self._call_java("residuals", residualsType)

    @property
    @since("2.0.0")
    def nullDeviance(self):
        """
        The deviance for the null model.
        """
        return self._call_java("nullDeviance")

    @property
    @since("2.0.0")
    def deviance(self):
        """
        The deviance for the fitted model.
        """
        return self._call_java("deviance")

    @property
    @since("2.0.0")
    def dispersion(self):
        """
        The dispersion of the fitted model.
        It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise
        estimated by the residual Pearson's Chi-Squared statistic (which is defined as
        sum of the squares of the Pearson residuals) divided by the residual degrees of freedom.
        """
        return self._call_java("dispersion")

    @property
    @since("2.0.0")
    def aic(self):
        """
        Akaike's "An Information Criterion" (AIC) for the fitted model.
        """
        return self._call_java("aic")


@inherit_doc
class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary):
    """
    .. note:: Experimental

    Generalized linear regression training results.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def numIterations(self):
        """
        Number of training iterations.
        """
        return self._call_java("numIterations")

    @property
    @since("2.0.0")
    def solver(self):
        """
        The numeric solver used for training.
        """
        return self._call_java("solver")

    @property
    @since("2.0.0")
    def coefficientStandardErrors(self):
        """
        Standard error of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("coefficientStandardErrors")

    @property
    @since("2.0.0")
    def tValues(self):
        """
        T-statistic of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("tValues")

    @property
    @since("2.0.0")
    def pValues(self):
        """
        Two-sided p-value of estimated coefficients and intercept.

        If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True,
        then the last element returned corresponds to the intercept.
        """
        return self._call_java("pValues")

    def __repr__(self):
        return self._call_java("toString")


if __name__ == "__main__":
    import doctest
    import sys
    import tempfile

    import pyspark.ml.regression
    from pyspark.sql import SparkSession

    globs = pyspark.ml.regression.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.regression tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        sys.exit(-1)