2015-01-28 20:14:23 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2016-04-15 15:58:38 -04:00
|
|
|
import operator
|
2018-03-08 06:38:34 -05:00
|
|
|
import sys
|
2017-09-12 13:02:27 -04:00
|
|
|
from multiprocessing.pool import ThreadPool
|
2015-11-02 19:12:04 -05:00
|
|
|
|
2016-04-20 13:32:01 -04:00
|
|
|
from pyspark import since, keyword_only
|
2016-04-15 15:58:38 -04:00
|
|
|
from pyspark.ml import Estimator, Model
|
2015-05-13 18:13:09 -04:00
|
|
|
from pyspark.ml.param.shared import *
|
2019-08-15 11:21:26 -04:00
|
|
|
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, \
|
|
|
|
DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, \
|
|
|
|
TreeEnsembleModel
|
2016-04-15 15:58:38 -04:00
|
|
|
from pyspark.ml.util import *
|
2016-04-18 14:52:29 -04:00
|
|
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
|
2016-04-15 15:58:38 -04:00
|
|
|
from pyspark.ml.wrapper import JavaWrapper
|
2017-07-17 13:07:32 -04:00
|
|
|
from pyspark.ml.common import inherit_doc, _java2py, _py2java
|
2019-03-02 10:09:28 -05:00
|
|
|
from pyspark.ml.linalg import Vectors
|
2016-04-06 15:07:47 -04:00
|
|
|
from pyspark.sql import DataFrame
|
2016-04-15 15:58:38 -04:00
|
|
|
from pyspark.sql.functions import udf, when
|
|
|
|
from pyspark.sql.types import ArrayType, DoubleType
|
|
|
|
from pyspark.storagelevel import StorageLevel
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2017-01-27 19:03:53 -05:00
|
|
|
__all__ = ['LinearSVC', 'LinearSVCModel',
|
|
|
|
'LogisticRegression', 'LogisticRegressionModel',
|
2016-04-06 15:07:47 -04:00
|
|
|
'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
|
|
|
|
'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
|
2016-03-02 00:26:47 -05:00
|
|
|
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
|
|
|
|
'GBTClassifier', 'GBTClassificationModel',
|
|
|
|
'RandomForestClassifier', 'RandomForestClassificationModel',
|
|
|
|
'NaiveBayes', 'NaiveBayesModel',
|
2016-04-15 15:58:38 -04:00
|
|
|
'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel',
|
|
|
|
'OneVsRest', 'OneVsRestModel']
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
|
2016-08-22 06:21:22 -04:00
|
|
|
@inherit_doc
class JavaClassificationModel(JavaPredictionModel):
    """
    (Private) Java Model produced by a ``Classifier``.
    Classes are indexed {0, 1, ..., numClasses - 1}.
    To be mixed in with :class:`pyspark.ml.JavaModel`
    """

    @property
    @since("2.1.0")
    def numClasses(self):
        """
        Number of classes (values which the label can take).
        """
        # Delegates to the wrapped JVM model object via Py4J.
        return self._call_java("numClasses")
|
|
|
|
|
|
|
|
|
2017-01-27 19:03:53 -05:00
|
|
|
@inherit_doc
class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization,
                HasWeightCol, HasAggregationDepth, HasThreshold, JavaMLWritable, JavaMLReadable):
    """
    `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_

    This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
    Only supports L2 regularization currently.

    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> df = sc.parallelize([
    ...     Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    ...     Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
    >>> svm = LinearSVC(maxIter=5, regParam=0.01)
    >>> model = svm.fit(df)
    >>> model.coefficients
    DenseVector([0.0, -0.2792, -0.1833])
    >>> model.intercept
    1.0206118982229047
    >>> model.numClasses
    2
    >>> model.numFeatures
    3
    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
    >>> result = model.transform(test0).head()
    >>> result.prediction
    1.0
    >>> result.rawPrediction
    DenseVector([-1.4831, 1.4831])
    >>> svm_path = temp_path + "/svm"
    >>> svm.save(svm_path)
    >>> svm2 = LinearSVC.load(svm_path)
    >>> svm2.getMaxIter()
    5
    >>> model_path = temp_path + "/svm_model"
    >>> model.save(model_path)
    >>> model2 = LinearSVCModel.load(model_path)
    >>> model.coefficients[0] == model2.coefficients[0]
    True
    >>> model.intercept == model2.intercept
    True

    .. versionadded:: 2.2.0
    """

    # Decision threshold applied to the raw (margin) prediction. Any real
    # number is allowed here, unlike LogisticRegression's [0, 1] threshold.
    threshold = Param(Params._dummy(), "threshold",
                      "The threshold in binary classification applied to the linear model"
                      " prediction. This threshold can be any real number, where Inf will make"
                      " all predictions 0.0 and -Inf will make all predictions 1.0.",
                      typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
                 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
                 aggregationDepth=2):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
                 fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
                 aggregationDepth=2):
        """
        super(LinearSVC, self).__init__()
        # Create the JVM-side estimator that this Python class wraps (via Py4J).
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.LinearSVC", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
                         standardization=True, threshold=0.0, aggregationDepth=2)
        # @keyword_only stashed only the caller's explicit kwargs in
        # _input_kwargs, so defaults above are not overridden spuriously.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.2.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
                  fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
                  aggregationDepth=2):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
                  fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
                  aggregationDepth=2):
        Sets params for Linear SVM Classifier.
        """
        # Apply only the parameters the caller passed explicitly.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Called by JavaEstimator._fit to wrap the fitted JVM model.
        return LinearSVCModel(java_model)
|
|
|
|
|
|
|
|
|
|
|
|
class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by LinearSVC.

    .. versionadded:: 2.2.0
    """

    @property
    @since("2.2.0")
    def coefficients(self):
        """
        Model coefficients of Linear SVM Classifier.
        """
        # Fetch the coefficient vector from the wrapped JVM model.
        java_value = self._call_java("coefficients")
        return java_value

    @property
    @since("2.2.0")
    def intercept(self):
        """
        Model intercept of Linear SVM Classifier.
        """
        # Fetch the scalar intercept from the wrapped JVM model.
        java_value = self._call_java("intercept")
        return java_value
|
|
|
|
|
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
@inherit_doc
|
|
|
|
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
2015-09-11 11:50:35 -04:00
|
|
|
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
|
2015-11-18 16:32:06 -05:00
|
|
|
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
|
2016-08-25 05:26:33 -04:00
|
|
|
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
|
2015-01-28 20:14:23 -05:00
|
|
|
"""
|
|
|
|
Logistic regression.
|
2016-09-27 03:00:21 -04:00
|
|
|
This class supports multinomial logistic (softmax) and binomial logistic regression.
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
>>> from pyspark.sql import Row
|
2016-05-17 15:51:07 -04:00
|
|
|
>>> from pyspark.ml.linalg import Vectors
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> bdf = sc.parallelize([
|
2017-04-26 09:34:18 -04:00
|
|
|
... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
|
|
|
|
... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
|
|
|
|
... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
|
|
|
|
... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
|
|
|
|
>>> blor = LogisticRegression(regParam=0.01, weightCol="weight")
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel = blor.fit(bdf)
|
|
|
|
>>> blorModel.coefficients
|
2017-04-26 09:34:18 -04:00
|
|
|
DenseVector([-1.080..., -0.646...])
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel.intercept
|
2017-04-26 09:34:18 -04:00
|
|
|
3.112...
|
|
|
|
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
|
|
|
|
>>> mdf = spark.read.format("libsvm").load(data_path)
|
|
|
|
>>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> mlorModel = mlor.fit(mdf)
|
2017-04-25 13:10:41 -04:00
|
|
|
>>> mlorModel.coefficientMatrix
|
2017-04-26 09:34:18 -04:00
|
|
|
SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> mlorModel.interceptVector
|
2017-04-26 09:34:18 -04:00
|
|
|
DenseVector([0.04..., -0.42..., 0.37...])
|
|
|
|
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> result = blorModel.transform(test0).head()
|
2015-08-03 01:19:27 -04:00
|
|
|
>>> result.prediction
|
2017-04-26 09:34:18 -04:00
|
|
|
1.0
|
2015-08-03 01:19:27 -04:00
|
|
|
>>> result.probability
|
2017-04-26 09:34:18 -04:00
|
|
|
DenseVector([0.02..., 0.97...])
|
2015-08-03 01:19:27 -04:00
|
|
|
>>> result.rawPrediction
|
2017-04-26 09:34:18 -04:00
|
|
|
DenseVector([-3.54..., 3.54...])
|
|
|
|
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel.transform(test1).head().prediction
|
2015-01-28 20:14:23 -05:00
|
|
|
1.0
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blor.setParams("vector")
|
2015-02-15 23:29:26 -05:00
|
|
|
Traceback (most recent call last):
|
|
|
|
...
|
|
|
|
TypeError: Method setParams forces keyword arguments.
|
2016-03-16 17:21:42 -04:00
|
|
|
>>> lr_path = temp_path + "/lr"
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blor.save(lr_path)
|
2016-03-16 17:21:42 -04:00
|
|
|
>>> lr2 = LogisticRegression.load(lr_path)
|
2017-04-26 09:34:18 -04:00
|
|
|
>>> lr2.getRegParam()
|
|
|
|
0.01
|
2016-03-16 17:21:42 -04:00
|
|
|
>>> model_path = temp_path + "/lr_model"
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel.save(model_path)
|
2016-03-16 17:21:42 -04:00
|
|
|
>>> model2 = LogisticRegressionModel.load(model_path)
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel.coefficients[0] == model2.coefficients[0]
|
2016-03-16 17:21:42 -04:00
|
|
|
True
|
2016-09-27 03:00:21 -04:00
|
|
|
>>> blorModel.intercept == model2.intercept
|
2016-03-16 17:21:42 -04:00
|
|
|
True
|
2018-06-28 15:40:39 -04:00
|
|
|
>>> model2
|
|
|
|
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
|
2015-11-09 16:16:04 -05:00
|
|
|
|
|
|
|
.. versionadded:: 1.3.0
|
2015-01-28 20:14:23 -05:00
|
|
|
"""
|
2015-05-18 15:02:18 -04:00
|
|
|
|
2015-08-12 17:27:13 -04:00
|
|
|
threshold = Param(Params._dummy(), "threshold",
|
|
|
|
"Threshold in binary classification prediction, in range [0, 1]." +
|
2016-06-22 05:54:49 -04:00
|
|
|
" If threshold and thresholds are both set, they must match." +
|
|
|
|
"e.g. if threshold is p, then thresholds must be equal to [1-p, p].",
|
2016-03-23 14:20:44 -04:00
|
|
|
typeConverter=TypeConverters.toFloat)
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2016-09-27 03:00:21 -04:00
|
|
|
    # Label-distribution family; "auto" defers the binomial-vs-multinomial
    # choice to the JVM-side implementation.
    family = Param(Params._dummy(), "family",
                   "The name of family which is a description of the label distribution to " +
                   "be used in the model. Supported options: auto, binomial, multinomial",
                   typeConverter=TypeConverters.toString)
|
|
|
|
|
2017-08-02 06:10:26 -04:00
|
|
|
lowerBoundsOnCoefficients = Param(Params._dummy(), "lowerBoundsOnCoefficients",
|
|
|
|
"The lower bounds on coefficients if fitting under bound "
|
|
|
|
"constrained optimization. The bound matrix must be "
|
|
|
|
"compatible with the shape "
|
|
|
|
"(1, number of features) for binomial regression, or "
|
|
|
|
"(number of classes, number of features) "
|
|
|
|
"for multinomial regression.",
|
|
|
|
typeConverter=TypeConverters.toMatrix)
|
|
|
|
|
|
|
|
upperBoundsOnCoefficients = Param(Params._dummy(), "upperBoundsOnCoefficients",
|
|
|
|
"The upper bounds on coefficients if fitting under bound "
|
|
|
|
"constrained optimization. The bound matrix must be "
|
|
|
|
"compatible with the shape "
|
|
|
|
"(1, number of features) for binomial regression, or "
|
|
|
|
"(number of classes, number of features) "
|
|
|
|
"for multinomial regression.",
|
|
|
|
typeConverter=TypeConverters.toMatrix)
|
|
|
|
|
|
|
|
lowerBoundsOnIntercepts = Param(Params._dummy(), "lowerBoundsOnIntercepts",
|
|
|
|
"The lower bounds on intercepts if fitting under bound "
|
|
|
|
"constrained optimization. The bounds vector size must be"
|
|
|
|
"equal with 1 for binomial regression, or the number of"
|
|
|
|
"lasses for multinomial regression.",
|
|
|
|
typeConverter=TypeConverters.toVector)
|
|
|
|
|
|
|
|
upperBoundsOnIntercepts = Param(Params._dummy(), "upperBoundsOnIntercepts",
|
|
|
|
"The upper bounds on intercepts if fitting under bound "
|
|
|
|
"constrained optimization. The bound vector size must be "
|
|
|
|
"equal with 1 for binomial regression, or the number of "
|
|
|
|
"classes for multinomial regression.",
|
|
|
|
typeConverter=TypeConverters.toVector)
|
|
|
|
|
2015-02-15 23:29:26 -05:00
|
|
|
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 threshold=0.5, thresholds=None, probabilityCol="probability",
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                 aggregationDepth=2, family="auto",
                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 threshold=0.5, thresholds=None, probabilityCol="probability", \
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
                 aggregationDepth=2, family="auto", \
                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        super(LogisticRegression, self).__init__()
        # Create the JVM-side estimator this Python class wraps (via Py4J).
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto")
        # @keyword_only stashed only the caller's explicit kwargs in
        # _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
        # Fail fast if threshold and thresholds were set inconsistently.
        self._checkThresholdConsistency()
|
2015-02-15 23:29:26 -05:00
|
|
|
|
|
|
|
    @keyword_only
    @since("1.3.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  threshold=0.5, thresholds=None, probabilityCol="probability",
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
                  aggregationDepth=2, family="auto",
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None,
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  threshold=0.5, thresholds=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
                  aggregationDepth=2, family="auto", \
                  lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
                  lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None):
        Sets params for logistic regression.
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        # Apply only the parameters the caller passed explicitly, then
        # re-validate the threshold/thresholds pair.
        kwargs = self._input_kwargs
        self._set(**kwargs)
        self._checkThresholdConsistency()
        return self
|
2015-02-15 23:29:26 -05:00
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
def _create_model(self, java_model):
|
|
|
|
return LogisticRegressionModel(java_model)
|
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.4.0")
|
2015-05-13 18:13:09 -04:00
|
|
|
def setThreshold(self, value):
|
|
|
|
"""
|
2015-08-12 17:27:13 -04:00
|
|
|
Sets the value of :py:attr:`threshold`.
|
|
|
|
Clears value of :py:attr:`thresholds` if it has been set.
|
|
|
|
"""
|
2016-04-15 15:14:41 -04:00
|
|
|
self._set(threshold=value)
|
|
|
|
self._clear(self.thresholds)
|
2015-08-12 17:27:13 -04:00
|
|
|
return self
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.4.0")
|
2015-08-12 17:27:13 -04:00
|
|
|
def getThreshold(self):
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
"""
|
2016-06-22 05:54:49 -04:00
|
|
|
Get threshold for binary classification.
|
|
|
|
|
|
|
|
If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),
|
|
|
|
this returns the equivalent threshold:
|
|
|
|
:math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.
|
|
|
|
Otherwise, returns :py:attr:`threshold` if set or its default value if unset.
|
2015-08-12 17:27:13 -04:00
|
|
|
"""
|
|
|
|
self._checkThresholdConsistency()
|
|
|
|
if self.isSet(self.thresholds):
|
|
|
|
ts = self.getOrDefault(self.thresholds)
|
|
|
|
if len(ts) != 2:
|
|
|
|
raise ValueError("Logistic Regression getThreshold only applies to" +
|
|
|
|
" binary classification, but thresholds has length != 2." +
|
|
|
|
" thresholds: " + ",".join(ts))
|
|
|
|
return 1.0/(1.0 + ts[0]/ts[1])
|
|
|
|
else:
|
|
|
|
return self.getOrDefault(self.threshold)
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.5.0")
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
def setThresholds(self, value):
    """
    Sets the value of :py:attr:`thresholds`.
    Clears value of :py:attr:`threshold` if it has been set.

    :param value: per-class threshold values.
    :return: self, to allow call chaining.
    """
    # `threshold` and `thresholds` express the same cutoff two ways;
    # recording the new thresholds therefore drops any stale binary
    # threshold so the two params cannot disagree.
    self._set(thresholds=value)
    self._clear(self.threshold)
    return self
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.5.0")
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
def getThresholds(self):
    """
    If :py:attr:`thresholds` is set, return its value.
    Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
    classification: (1-threshold, threshold).
    If neither are set, throw an error.
    """
    self._checkThresholdConsistency()
    # Prefer an explicitly set `thresholds`; also fall through to
    # getOrDefault when neither param is set so it can raise the
    # standard missing-param error.
    if self.isSet(self.thresholds) or not self.isSet(self.threshold):
        return self.getOrDefault(self.thresholds)
    # Only the binary `threshold` is set: translate it into the
    # equivalent two-element thresholds list.
    cutoff = self.getOrDefault(self.threshold)
    return [1.0 - cutoff, cutoff]
2015-08-12 17:27:13 -04:00
|
|
|
def _checkThresholdConsistency(self):
    """
    Validate that :py:attr:`threshold` and :py:attr:`thresholds`, when both
    set, describe the same binary-classification cutoff.

    ``threshold`` t is equivalent to ``thresholds`` [1-t, t], so the check
    requires ``thresholds`` to have length 2 and the threshold implied by
    ``thresholds`` to match ``threshold`` within 1e-5.

    :raises ValueError: if ``thresholds`` does not have length 2, or the
        two params are inconsistent.
    """
    # Nothing to check unless both params were explicitly set.
    if self.isSet(self.threshold) and self.isSet(self.thresholds):
        ts = self.getOrDefault(self.thresholds)
        if len(ts) != 2:
            raise ValueError("Logistic Regression getThreshold only applies to" +
                             " binary classification, but thresholds has length != 2." +
                             " thresholds: {0}".format(str(ts)))
        # Algebraically equal to 1.0/(1.0 + ts[0]/ts[1]) but does not
        # raise ZeroDivisionError when ts[1] == 0 (then the implied
        # threshold is simply 0.0).
        t = ts[1] / (ts[0] + ts[1])
        t2 = self.getOrDefault(self.threshold)
        if abs(t2 - t) >= 1E-5:
            raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                             " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
|
2016-09-27 03:00:21 -04:00
|
|
|
@since("2.1.0")
|
|
|
|
def setFamily(self, value):
    """
    Sets the value of :py:attr:`family`.

    :return: self, to allow call chaining.
    """
    # Delegate to the shared Params machinery.
    return self._set(family=value)
|
|
|
@since("2.1.0")
|
|
|
|
def getFamily(self):
    """
    Gets the value of :py:attr:`family` or its default value.
    """
    return self.getOrDefault(self.family)
|
2017-08-02 06:10:26 -04:00
|
|
|
@since("2.3.0")
|
|
|
|
def setLowerBoundsOnCoefficients(self, value):
    """
    Sets the value of :py:attr:`lowerBoundsOnCoefficients`.

    :return: self, to allow call chaining.
    """
    return self._set(lowerBoundsOnCoefficients=value)
|
|
|
@since("2.3.0")
|
|
|
|
def getLowerBoundsOnCoefficients(self):
    """
    Gets the value of :py:attr:`lowerBoundsOnCoefficients`, or its
    default if unset.
    """
    return self.getOrDefault(self.lowerBoundsOnCoefficients)
|
|
|
|
@since("2.3.0")
|
|
|
|
def setUpperBoundsOnCoefficients(self, value):
    """
    Sets the value of :py:attr:`upperBoundsOnCoefficients`.

    :return: self, to allow call chaining.
    """
    return self._set(upperBoundsOnCoefficients=value)
|
|
|
@since("2.3.0")
|
|
|
|
def getUpperBoundsOnCoefficients(self):
    """
    Gets the value of :py:attr:`upperBoundsOnCoefficients`, or its
    default if unset.
    """
    return self.getOrDefault(self.upperBoundsOnCoefficients)
|
|
|
|
@since("2.3.0")
|
|
|
|
def setLowerBoundsOnIntercepts(self, value):
    """
    Sets the value of :py:attr:`lowerBoundsOnIntercepts`.

    :return: self, to allow call chaining.
    """
    return self._set(lowerBoundsOnIntercepts=value)
|
|
|
@since("2.3.0")
|
|
|
|
def getLowerBoundsOnIntercepts(self):
    """
    Gets the value of :py:attr:`lowerBoundsOnIntercepts`, or its
    default if unset.
    """
    return self.getOrDefault(self.lowerBoundsOnIntercepts)
|
|
|
|
@since("2.3.0")
|
|
|
|
def setUpperBoundsOnIntercepts(self, value):
    """
    Sets the value of :py:attr:`upperBoundsOnIntercepts`.

    :return: self, to allow call chaining.
    """
    return self._set(upperBoundsOnIntercepts=value)
|
|
|
@since("2.3.0")
|
|
|
|
def getUpperBoundsOnIntercepts(self):
    """
    Gets the value of :py:attr:`upperBoundsOnIntercepts`, or its
    default if unset.
    """
    return self.getOrDefault(self.upperBoundsOnIntercepts)
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2019-02-01 18:29:58 -05:00
|
|
|
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable,
                              HasTrainingSummary):
    """
    Model fitted by LogisticRegression.

    .. versionadded:: 1.3.0
    """

    @property
    @since("2.0.0")
    def coefficients(self):
        """
        Model coefficients of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept of binomial logistic regression.
        An exception is thrown in the case of multinomial logistic regression.
        """
        return self._call_java("intercept")

    @property
    @since("2.1.0")
    def coefficientMatrix(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficientMatrix")

    @property
    @since("2.1.0")
    def interceptVector(self):
        """
        Model intercept.
        """
        return self._call_java("interceptVector")

    @property
    @since("2.0.0")
    def summary(self):
        """
        Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
        trained on the training set. An exception is thrown if `trainingSummary is None`.
        """
        if not self.hasSummary:
            raise RuntimeError("No training summary available for this %s" %
                               self.__class__.__name__)
        # Binary models get the richer binary-specific summary wrapper.
        if self.numClasses <= 2:
            return BinaryLogisticRegressionTrainingSummary(
                super(LogisticRegressionModel, self).summary)
        return LogisticRegressionTrainingSummary(
            super(LogisticRegressionModel, self).summary)

    @since("2.0.0")
    def evaluate(self, dataset):
        """
        Evaluates the model on a test dataset.

        :param dataset:
          Test dataset to evaluate model on, where dataset is an
          instance of :py:class:`pyspark.sql.DataFrame`
        """
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        # Delegates evaluation to the JVM model and wraps the Java summary.
        return BinaryLogisticRegressionSummary(self._call_java("evaluate", dataset))

    def __repr__(self):
        # Mirror the underlying Java model's toString for debugging.
        return self._call_java("toString")
2016-04-06 15:07:47 -04:00
|
|
|
|
2016-04-13 17:08:57 -04:00
|
|
|
class LogisticRegressionSummary(JavaWrapper):
|
2016-04-06 15:07:47 -04:00
|
|
|
"""
|
|
|
|
Abstraction for Logistic Regression Results for a given model.
|
|
|
|
|
|
|
|
.. versionadded:: 2.0.0
|
|
|
|
"""
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.0.0")
|
|
|
|
def predictions(self):
|
|
|
|
"""
|
|
|
|
Dataframe outputted by the model's `transform` method.
|
|
|
|
"""
|
|
|
|
return self._call_java("predictions")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.0.0")
|
|
|
|
def probabilityCol(self):
|
|
|
|
"""
|
2016-04-08 23:15:44 -04:00
|
|
|
Field in "predictions" which gives the probability
|
2016-04-06 15:07:47 -04:00
|
|
|
of each class as a vector.
|
|
|
|
"""
|
|
|
|
return self._call_java("probabilityCol")
|
|
|
|
|
2017-09-14 01:53:28 -04:00
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def predictionCol(self):
|
|
|
|
"""
|
|
|
|
Field in "predictions" which gives the prediction of each class.
|
|
|
|
"""
|
|
|
|
return self._call_java("predictionCol")
|
|
|
|
|
2016-04-06 15:07:47 -04:00
|
|
|
@property
|
|
|
|
@since("2.0.0")
|
|
|
|
def labelCol(self):
|
|
|
|
"""
|
|
|
|
Field in "predictions" which gives the true label of each
|
|
|
|
instance.
|
|
|
|
"""
|
|
|
|
return self._call_java("labelCol")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.0.0")
|
|
|
|
def featuresCol(self):
|
|
|
|
"""
|
|
|
|
Field in "predictions" which gives the features of each instance
|
|
|
|
as a vector.
|
|
|
|
"""
|
|
|
|
return self._call_java("featuresCol")
|
|
|
|
|
2017-09-14 01:53:28 -04:00
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def labels(self):
|
|
|
|
"""
|
|
|
|
Returns the sequence of labels in ascending order. This order matches the order used
|
|
|
|
in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
|
|
|
|
|
|
|
|
Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the
|
|
|
|
training set is missing a label, then all of the arrays over labels
|
|
|
|
(e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
|
|
|
|
expected numClasses.
|
|
|
|
"""
|
|
|
|
return self._call_java("labels")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def truePositiveRateByLabel(self):
|
|
|
|
"""
|
|
|
|
Returns true positive rate for each label (category).
|
|
|
|
"""
|
|
|
|
return self._call_java("truePositiveRateByLabel")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def falsePositiveRateByLabel(self):
|
|
|
|
"""
|
|
|
|
Returns false positive rate for each label (category).
|
|
|
|
"""
|
|
|
|
return self._call_java("falsePositiveRateByLabel")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def precisionByLabel(self):
|
|
|
|
"""
|
|
|
|
Returns precision for each label (category).
|
|
|
|
"""
|
|
|
|
return self._call_java("precisionByLabel")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def recallByLabel(self):
|
|
|
|
"""
|
|
|
|
Returns recall for each label (category).
|
|
|
|
"""
|
|
|
|
return self._call_java("recallByLabel")
|
|
|
|
|
|
|
|
@since("2.3.0")
|
|
|
|
def fMeasureByLabel(self, beta=1.0):
|
|
|
|
"""
|
|
|
|
Returns f-measure for each label (category).
|
|
|
|
"""
|
|
|
|
return self._call_java("fMeasureByLabel", beta)
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def accuracy(self):
|
|
|
|
"""
|
|
|
|
Returns accuracy.
|
|
|
|
(equals to the total number of correctly classified instances
|
|
|
|
out of the total number of instances.)
|
|
|
|
"""
|
|
|
|
return self._call_java("accuracy")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def weightedTruePositiveRate(self):
|
|
|
|
"""
|
|
|
|
Returns weighted true positive rate.
|
|
|
|
(equals to precision, recall and f-measure)
|
|
|
|
"""
|
|
|
|
return self._call_java("weightedTruePositiveRate")
|
|
|
|
|
|
|
|
@property
|
|
|
|
@since("2.3.0")
|
|
|
|
def weightedFalsePositiveRate(self):
|
|
|
|
"""
|
|
|
|
Returns weighted false positive rate.
|
|
|
|
"""
|
|
|
|
return self._call_java("weightedFalsePositiveRate")
|
|
|
|
|
|
|
|
    @property
    @since("2.3.0")
    def weightedRecall(self):
        """
        Returns weighted averaged recall.
        (equals to :py:attr:`weightedTruePositiveRate`, since recall and
        true positive rate are the same quantity)
        """
        return self._call_java("weightedRecall")
|
|
|
|
|
|
|
|
    @property
    @since("2.3.0")
    def weightedPrecision(self):
        """
        Returns weighted averaged precision.

        Per-label precisions averaged with the class frequencies as
        weights.
        """
        return self._call_java("weightedPrecision")
|
|
|
|
|
|
|
|
@since("2.3.0")
|
|
|
|
def weightedFMeasure(self, beta=1.0):
|
|
|
|
"""
|
|
|
|
Returns weighted averaged f-measure.
|
|
|
|
"""
|
|
|
|
return self._call_java("weightedFMeasure", beta)
|
|
|
|
|
2016-04-06 15:07:47 -04:00
|
|
|
|
|
|
|
@inherit_doc
class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
    """
    Abstraction for multinomial Logistic Regression Training results.
    Currently, the training summary ignores the training weights except
    for the objective trace.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def objectiveHistory(self):
        """
        Objective function (scaled loss + regularization) at each
        iteration.
        """
        # One entry per optimizer iteration, fetched from the Java summary.
        history = self._call_java("objectiveHistory")
        return history

    @property
    @since("2.0.0")
    def totalIterations(self):
        """
        Number of training iterations until termination.
        """
        iteration_count = self._call_java("totalIterations")
        return iteration_count
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
    """
    Binary Logistic regression results for a given model.

    All metrics here are computed on the Java side and surfaced through
    thin :py:meth:`_call_java` delegations.

    .. versionadded:: 2.0.0
    """

    @property
    @since("2.0.0")
    def roc(self):
        """
        Returns the receiver operating characteristic (ROC) curve,
        which is a Dataframe having two fields (FPR, TPR) with
        (0.0, 0.0) prepended and (1.0, 1.0) appended to it.

        .. seealso:: `Wikipedia reference \
        <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("roc")

    @property
    @since("2.0.0")
    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("areaUnderROC")

    @property
    @since("2.0.0")
    def pr(self):
        """
        Returns the precision-recall curve, which is a Dataframe
        containing two fields recall, precision with (0.0, 1.0) prepended
        to it.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("pr")

    @property
    @since("2.0.0")
    def fMeasureByThreshold(self):
        """
        Returns a dataframe with two fields (threshold, F-Measure) curve
        with beta = 1.0.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("fMeasureByThreshold")

    @property
    @since("2.0.0")
    def precisionByThreshold(self):
        """
        Returns a dataframe with two fields (threshold, precision) curve.
        Every possible probability obtained in transforming the dataset
        are used as thresholds used in calculating the precision.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("precisionByThreshold")

    @property
    @since("2.0.0")
    def recallByThreshold(self):
        """
        Returns a dataframe with two fields (threshold, recall) curve.
        Every possible probability obtained in transforming the dataset
        are used as thresholds used in calculating the recall.

        .. note:: This ignores instance weights (setting all to 1.0) from
            `LogisticRegression.weightCol`. This will change in later Spark
            versions.
        """
        return self._call_java("recallByThreshold")
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
                                              LogisticRegressionTrainingSummary):
    """
    Binary Logistic regression training results for a given model.

    Combines, purely via multiple inheritance, the binary evaluation
    metrics of :py:class:`BinaryLogisticRegressionSummary` with the
    training-time attributes of
    :py:class:`LogisticRegressionTrainingSummary`; it adds no members of
    its own.

    .. versionadded:: 2.0.0
    """
    pass
|
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
class TreeClassifierParams(object):
    """
    Private class to track supported impurity measures.

    Mixed into the tree-based classifiers to register the shared
    ``impurity`` Param.

    .. versionadded:: 1.4.0
    """
    # Valid values for the impurity Param (case-insensitive on the Java side).
    supportedImpurities = ["entropy", "gini"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)
|
|
|
|
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasWeightCol,
                             HasPredictionCol, HasProbabilityCol, HasRawPredictionCol,
                             DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval,
                             HasSeed, JavaMLWritable, JavaMLReadable):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
    >>> model = dt.fit(td)
    >>> model.numNodes
    3
    >>> model.depth
    1
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> model.numFeatures
    1
    >>> model.numClasses
    2
    >>> print(model.toDebugString)
    DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.probability
    DenseVector([1.0, 0.0])
    >>> result.rawPrediction
    DenseVector([1.0, 0.0])
    >>> result.leafId
    0.0
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> dtc_path = temp_path + "/dtc"
    >>> dt.save(dtc_path)
    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtc_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True
    >>> df3 = spark.createDataFrame([
    ...     (1.0, 0.2, Vectors.dense(1.0)),
    ...     (1.0, 0.8, Vectors.dense(1.0)),
    ...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
    >>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model3 = si3.fit(df3)
    >>> td3 = si_model3.transform(df3)
    >>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
    >>> model3 = dt3.fit(td3)
    >>> print(model3.toDebugString)
    DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 seed=None, weightCol=None, leafCol=""):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 seed=None, weightCol=None, leafCol="")
        """
        super(DecisionTreeClassifier, self).__init__()
        # Create the Java-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini", leafCol="")
        # @keyword_only stashed the caller's explicit kwargs in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="gini", seed=None, weightCol=None, leafCol=""):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  seed=None, weightCol=None, leafCol="")
        Sets params for the DecisionTreeClassifier.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Called by JavaEstimator.fit() to wrap the fitted Java model.
        return DecisionTreeClassificationModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2015-07-07 11:58:08 -04:00
|
|
|
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
                                      JavaMLReadable):
    """
    Model fitted by DecisionTreeClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
            where gain is scaled by the number of instances passing through node
          - Normalize importances for tree to sum to 1.

        .. note:: Feature importance for single decision trees can have high variance due to
            correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
            to determine feature importance instead.
        """
        return self._call_java("featureImportances")
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
|
|
|
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
                             HasRawPredictionCol, HasProbabilityCol,
                             RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
                             JavaMLWritable, JavaMLReadable):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> import numpy
    >>> from numpy import allclose
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = spark.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
    ...     leafCol="leafId")
    >>> model = rf.fit(td)
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
    True
    >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> numpy.argmax(result.probability)
    0
    >>> numpy.argmax(result.rawPrediction)
    0
    >>> result.leafId
    DenseVector([0.0, 0.0, 0.0])
    >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> model.trees
    [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
    >>> rfc_path = temp_path + "/rfc"
    >>> rf.save(rfc_path)
    >>> rf2 = RandomForestClassifier.load(rfc_path)
    >>> rf2.getNumTrees()
    3
    >>> model_path = temp_path + "/rfc_model"
    >>> model.save(model_path)
    >>> model2 = RandomForestClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
                 leafCol=""):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
                 leafCol="")
        """
        super(RandomForestClassifier, self).__init__()
        # Create the Java-side estimator this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini", numTrees=20, featureSubsetStrategy="auto",
                         subsamplingRate=1.0, leafCol="")
        # @keyword_only stashed the caller's explicit kwargs in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                  impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
                  leafCol=""):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                  impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
                  leafCol="")
        Sets params for the RandomForestClassifier.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Called by JavaEstimator.fit() to wrap the fitted Java model.
        return RandomForestClassificationModel(java_model)

    def setMaxDepth(self, value):
        """
        Sets the value of :py:attr:`maxDepth`.
        """
        return self._set(maxDepth=value)

    def setMaxBins(self, value):
        """
        Sets the value of :py:attr:`maxBins`.
        """
        return self._set(maxBins=value)

    def setMinInstancesPerNode(self, value):
        """
        Sets the value of :py:attr:`minInstancesPerNode`.
        """
        return self._set(minInstancesPerNode=value)

    def setMinInfoGain(self, value):
        """
        Sets the value of :py:attr:`minInfoGain`.
        """
        return self._set(minInfoGain=value)

    def setMaxMemoryInMB(self, value):
        """
        Sets the value of :py:attr:`maxMemoryInMB`.
        """
        return self._set(maxMemoryInMB=value)

    def setCacheNodeIds(self, value):
        """
        Sets the value of :py:attr:`cacheNodeIds`.
        """
        return self._set(cacheNodeIds=value)

    @since("1.4.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        return self._set(impurity=value)

    @since("1.4.0")
    def setNumTrees(self, value):
        """
        Sets the value of :py:attr:`numTrees`.
        """
        return self._set(numTrees=value)

    @since("1.4.0")
    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        return self._set(subsamplingRate=value)

    @since("2.4.0")
    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        return self._set(featureSubsetStrategy=value)
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2016-08-22 06:21:22 -04:00
|
|
|
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
                                      JavaMLReadable):
    """
    Model fitted by RandomForestClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the ensemble
        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
        and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
        """
        return self._call_java("featureImportances")

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        # Wrap each Java tree model in its Python counterpart.
        return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2018-12-07 16:53:35 -05:00
|
|
|
class GBTClassifierParams(GBTParams, HasVarianceImpurity):
    """
    Private class to track supported GBTClassifier params.

    Registers the ``lossType`` Param shared by the GBT classifier and
    its model.

    .. versionadded:: 3.0.0
    """
    # Valid values for the lossType Param (case-insensitive on the Java side).
    supportedLossTypes = ["logistic"]

    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(supportedLossTypes),
                     typeConverter=TypeConverters.toString)

    @since("1.4.0")
    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)
|
|
|
|
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
@inherit_doc
|
2018-12-07 16:53:35 -05:00
|
|
|
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
|
|
|
GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
|
2016-04-15 00:36:03 -04:00
|
|
|
JavaMLReadable):
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
2016-05-09 04:11:17 -04:00
|
|
|
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
2015-05-13 18:13:09 -04:00
|
|
|
learning algorithm for classification.
|
|
|
|
It supports binary labels, as well as both continuous and categorical features.
|
|
|
|
|
2016-05-09 04:11:17 -04:00
|
|
|
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
|
|
|
|
|
|
|
|
Notes on Gradient Boosting vs. TreeBoost:
|
|
|
|
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
|
|
|
|
- Both algorithms learn tree ensembles by minimizing loss functions.
|
|
|
|
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
|
|
|
|
based on the loss function, whereas the original gradient boosting method does not.
|
|
|
|
- We expect to implement TreeBoost in the future:
|
|
|
|
`SPARK-4240 <https://issues.apache.org/jira/browse/SPARK-4240>`_
|
|
|
|
|
2016-11-22 06:40:18 -05:00
|
|
|
.. note:: Multiclass labels are not currently supported.
|
|
|
|
|
2015-07-07 11:58:08 -04:00
|
|
|
>>> from numpy import allclose
|
2016-05-17 15:51:07 -04:00
|
|
|
>>> from pyspark.ml.linalg import Vectors
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> from pyspark.ml.feature import StringIndexer
|
2016-05-23 21:14:48 -04:00
|
|
|
>>> df = spark.createDataFrame([
|
2015-05-13 18:13:09 -04:00
|
|
|
... (1.0, Vectors.dense(1.0)),
|
|
|
|
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
|
|
|
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
|
|
|
|
>>> si_model = stringIndexer.fit(df)
|
|
|
|
>>> td = si_model.transform(df)
|
2019-08-23 18:18:35 -04:00
|
|
|
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
|
|
|
|
... leafCol="leafId")
|
2018-05-30 14:04:09 -04:00
|
|
|
>>> gbt.getFeatureSubsetStrategy()
|
|
|
|
'all'
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> model = gbt.fit(td)
|
2016-03-31 16:00:10 -04:00
|
|
|
>>> model.featureImportances
|
|
|
|
SparseVector(1, {0: 1.0})
|
2015-07-07 11:58:08 -04:00
|
|
|
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
|
|
|
|
True
|
2016-05-23 21:14:48 -04:00
|
|
|
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
2019-08-23 18:18:35 -04:00
|
|
|
>>> result = model.transform(test0).head()
|
|
|
|
>>> result.prediction
|
2015-05-13 18:13:09 -04:00
|
|
|
0.0
|
2019-08-23 18:18:35 -04:00
|
|
|
>>> result.leafId
|
|
|
|
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
|
2016-05-23 21:14:48 -04:00
|
|
|
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> model.transform(test1).head().prediction
|
|
|
|
1.0
|
2016-06-02 18:55:14 -04:00
|
|
|
>>> model.totalNumNodes
|
|
|
|
15
|
|
|
|
>>> print(model.toDebugString)
|
|
|
|
GBTClassificationModel (uid=...)...with 5 trees...
|
2016-04-15 00:36:03 -04:00
|
|
|
>>> gbtc_path = temp_path + "gbtc"
|
|
|
|
>>> gbt.save(gbtc_path)
|
|
|
|
>>> gbt2 = GBTClassifier.load(gbtc_path)
|
|
|
|
>>> gbt2.getMaxDepth()
|
|
|
|
2
|
|
|
|
>>> model_path = temp_path + "gbtc_model"
|
|
|
|
>>> model.save(model_path)
|
|
|
|
>>> model2 = GBTClassificationModel.load(model_path)
|
|
|
|
>>> model.featureImportances == model2.featureImportances
|
|
|
|
True
|
|
|
|
>>> model.treeWeights == model2.treeWeights
|
|
|
|
True
|
2016-06-20 19:28:11 -04:00
|
|
|
>>> model.trees
|
|
|
|
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
|
2018-05-15 17:16:31 -04:00
|
|
|
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
|
|
|
|
... ["indexed", "features"])
|
|
|
|
>>> model.evaluateEachIteration(validation)
|
|
|
|
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
|
2018-05-30 14:04:09 -04:00
|
|
|
>>> model.numClasses
|
|
|
|
2
|
2018-12-07 16:53:35 -05:00
|
|
|
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
|
|
|
|
>>> gbt.getValidationIndicatorCol()
|
|
|
|
'validationIndicator'
|
|
|
|
>>> gbt.getValidationTol()
|
|
|
|
0.01
|
2015-11-09 16:16:04 -05:00
|
|
|
|
|
|
|
.. versionadded:: 1.4.0
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
|
|
|
|
|
|
|
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
                 maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
                 featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
                 leafCol=""):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
                 impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                 validationIndicatorCol=None, leafCol="")
        """
        super(GBTClassifier, self).__init__()
        # Create the Py4J peer object that implements the actual algorithm on the JVM side.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.GBTClassifier", self.uid)
        # Register defaults separately from user-supplied values so isSet()/isDefined()
        # report correctly. NOTE(review): `seed` is absent here — presumably its default
        # comes from the shared HasSeed mixin; confirm.
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
                         impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                         leafCol="")
        # @keyword_only captured only the kwargs the caller passed explicitly.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
|
|
|
|
|
|
|
|
    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
                  validationIndicatorCol=None, leafCol=""):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
                  impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
                  validationIndicatorCol=None, leafCol="")
        Sets params for Gradient Boosted Tree Classification.
        """
        # Only explicitly supplied kwargs are applied (captured by @keyword_only), so
        # parameters the caller omitted keep their current/default values.
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
|
|
|
def _create_model(self, java_model):
|
|
|
|
return GBTClassificationModel(java_model)
|
|
|
|
|
2019-07-20 11:44:33 -04:00
|
|
|
def setMaxDepth(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`maxDepth`.
|
|
|
|
"""
|
|
|
|
return self._set(maxDepth=value)
|
|
|
|
|
|
|
|
def setMaxBins(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`maxBins`.
|
|
|
|
"""
|
|
|
|
return self._set(maxBins=value)
|
|
|
|
|
|
|
|
def setMinInstancesPerNode(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`minInstancesPerNode`.
|
|
|
|
"""
|
|
|
|
return self._set(minInstancesPerNode=value)
|
|
|
|
|
|
|
|
def setMinInfoGain(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`minInfoGain`.
|
|
|
|
"""
|
|
|
|
return self._set(minInfoGain=value)
|
|
|
|
|
|
|
|
def setMaxMemoryInMB(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`maxMemoryInMB`.
|
|
|
|
"""
|
|
|
|
return self._set(maxMemoryInMB=value)
|
|
|
|
|
|
|
|
def setCacheNodeIds(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`cacheNodeIds`.
|
|
|
|
"""
|
|
|
|
return self._set(cacheNodeIds=value)
|
|
|
|
|
|
|
|
@since("1.4.0")
|
|
|
|
def setImpurity(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`impurity`.
|
|
|
|
"""
|
|
|
|
return self._set(impurity=value)
|
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.4.0")
|
2015-05-13 18:13:09 -04:00
|
|
|
def setLossType(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`lossType`.
|
|
|
|
"""
|
2016-05-03 10:46:13 -04:00
|
|
|
return self._set(lossType=value)
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2019-07-20 11:44:33 -04:00
|
|
|
@since("1.4.0")
|
|
|
|
def setSubsamplingRate(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`subsamplingRate`.
|
|
|
|
"""
|
|
|
|
return self._set(subsamplingRate=value)
|
|
|
|
|
2018-05-30 14:04:09 -04:00
|
|
|
@since("2.4.0")
|
|
|
|
def setFeatureSubsetStrategy(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`featureSubsetStrategy`.
|
|
|
|
"""
|
|
|
|
return self._set(featureSubsetStrategy=value)
|
|
|
|
|
2018-12-07 16:53:35 -05:00
|
|
|
@since("3.0.0")
|
|
|
|
def setValidationIndicatorCol(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`validationIndicatorCol`.
|
|
|
|
"""
|
|
|
|
return self._set(validationIndicatorCol=value)
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2018-05-30 14:04:09 -04:00
|
|
|
class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
                             JavaMLReadable):
    """
    Model fitted by GBTClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        Each feature's importance is the average of its importance across all trees in the
        ensemble. The importance vector is normalized to sum to 1. This method is suggested by
        Hastie et al. (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning,
        2nd Edition." 2001.) and follows the implementation from scikit-learn.

        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
        """
        # Computed on the JVM side; returns a Vector.
        return self._call_java("featureImportances")

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        # GBT stages are regression trees even for classification, hence
        # DecisionTreeRegressionModel wrappers here.
        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

    @since("2.4.0")
    def evaluateEachIteration(self, dataset):
        """
        Method to compute error or loss for every iteration of gradient boosting.

        :param dataset:
            Test dataset to evaluate model on, where dataset is an
            instance of :py:class:`pyspark.sql.DataFrame`
        :return: per-iteration loss values (one entry per boosting stage).
        """
        return self._call_java("evaluateEachIteration", dataset)
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2015-07-31 02:03:48 -04:00
|
|
|
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
                 HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable):
    """
    Naive Bayes Classifiers.
    It supports both Multinomial and Bernoulli NB. `Multinomial NB
    <http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html>`_
    can handle finitely supported discrete data. For example, by converting documents into
    TF-IDF vectors, it can be used for document classification. By making every vector a
    binary (0/1) data, it can also be used as `Bernoulli NB
    <http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html>`_.
    The input feature values must be nonnegative.

    >>> from pyspark.sql import Row
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
    ...     Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
    ...     Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
    >>> model = nb.fit(df)
    >>> model.pi
    DenseVector([-0.81..., -0.58...])
    >>> model.theta
    DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
    >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
    >>> result = model.transform(test0).head()
    >>> result.prediction
    1.0
    >>> result.probability
    DenseVector([0.32..., 0.67...])
    >>> result.rawPrediction
    DenseVector([-1.72..., -0.99...])
    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
    >>> model.transform(test1).head().prediction
    1.0
    >>> nb_path = temp_path + "/nb"
    >>> nb.save(nb_path)
    >>> nb2 = NaiveBayes.load(nb_path)
    >>> nb2.getSmoothing()
    1.0
    >>> model_path = temp_path + "/nb_model"
    >>> model.save(model_path)
    >>> model2 = NaiveBayesModel.load(model_path)
    >>> model.pi == model2.pi
    True
    >>> model.theta == model2.theta
    True
    >>> nb = nb.setThresholds([0.01, 10.00])
    >>> model3 = nb.fit(df)
    >>> result = model3.transform(test0).head()
    >>> result.prediction
    0.0

    .. versionadded:: 1.5.0
    """

    # Algorithm-specific Params; shared ones (featuresCol, weightCol, ...) come from mixins.
    smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
                      "default is 1.0", typeConverter=TypeConverters.toFloat)
    modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
                      "(case-sensitive). Supported options: multinomial (default) and bernoulli.",
                      typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
                 modelType="multinomial", thresholds=None, weightCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
                 modelType="multinomial", thresholds=None, weightCol=None)
        """
        super(NaiveBayes, self).__init__()
        # Py4J peer object that performs the actual training on the JVM.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.NaiveBayes", self.uid)
        self._setDefault(smoothing=1.0, modelType="multinomial")
        # @keyword_only captured the caller's explicit kwargs in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.5.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
                  modelType="multinomial", thresholds=None, weightCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
                  modelType="multinomial", thresholds=None, weightCol=None)
        Sets params for Naive Bayes.
        """
        # Only explicitly supplied kwargs are applied.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Wrap the fitted Java model in its Python companion class.
        return NaiveBayesModel(java_model)

    @since("1.5.0")
    def setSmoothing(self, value):
        """
        Sets the value of :py:attr:`smoothing`.
        """
        return self._set(smoothing=value)

    @since("1.5.0")
    def getSmoothing(self):
        """
        Gets the value of smoothing or its default value.
        """
        return self.getOrDefault(self.smoothing)

    @since("1.5.0")
    def setModelType(self, value):
        """
        Sets the value of :py:attr:`modelType`.
        """
        return self._set(modelType=value)

    @since("1.5.0")
    def getModelType(self):
        """
        Gets the value of modelType or its default value.
        """
        return self.getOrDefault(self.modelType)
|
|
|
|
|
|
|
|
|
2016-08-22 06:21:22 -04:00
|
|
|
class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by NaiveBayes.

    .. versionadded:: 1.5.0
    """

    @property
    @since("2.0.0")
    def pi(self):
        """
        log of class priors.
        """
        # Fetched from the JVM model; returns a Vector (one entry per class).
        return self._call_java("pi")

    @property
    @since("2.0.0")
    def theta(self):
        """
        log of class conditional probabilities.
        """
        # Fetched from the JVM model; returns a Matrix (classes x features).
        return self._call_java("theta")
|
|
|
|
|
|
|
|
|
2015-09-11 11:52:28 -04:00
|
|
|
@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                                     HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
                                     JavaMLWritable, JavaMLReadable, HasProbabilityCol,
                                     HasRawPredictionCol):
    """
    Classifier trainer based on the Multilayer Perceptron.
    Each layer has sigmoid activation function, output layer has softmax.
    Number of inputs has to be equal to the size of feature vectors.
    Number of outputs has to be equal to the total number of labels.

    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([
    ...     (0.0, Vectors.dense([0.0, 0.0])),
    ...     (1.0, Vectors.dense([0.0, 1.0])),
    ...     (1.0, Vectors.dense([1.0, 0.0])),
    ...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
    >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 2, 2], blockSize=1, seed=123)
    >>> model = mlp.fit(df)
    >>> model.layers
    [2, 2, 2]
    >>> model.weights.size
    12
    >>> testDF = spark.createDataFrame([
    ...     (Vectors.dense([1.0, 0.0]),),
    ...     (Vectors.dense([0.0, 0.0]),)], ["features"])
    >>> model.transform(testDF).select("features", "prediction").show()
    +---------+----------+
    | features|prediction|
    +---------+----------+
    |[1.0,0.0]|       1.0|
    |[0.0,0.0]|       0.0|
    +---------+----------+
    ...
    >>> mlp_path = temp_path + "/mlp"
    >>> mlp.save(mlp_path)
    >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
    >>> mlp2.getBlockSize()
    1
    >>> model_path = temp_path + "/mlp_model"
    >>> model.save(model_path)
    >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
    >>> model.layers == model2.layers
    True
    >>> model.weights == model2.weights
    True
    >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
    >>> model3 = mlp2.fit(df)
    >>> model3.weights != model2.weights
    True
    >>> model3.layers == model.layers
    True

    .. versionadded:: 1.6.0
    """

    # Algorithm-specific Params; shared ones (maxIter, tol, seed, ...) come from the mixins.
    layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
                   "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
                   "neurons and output layer of 10 neurons.",
                   typeConverter=TypeConverters.toListInt)
    blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
                      "matrices. Data is stacked within partitions. If block size is more than " +
                      "remaining data in a partition then it is adjusted to the size of this " +
                      "data. Recommended size is between 10 and 1000, default is 128.",
                      typeConverter=TypeConverters.toInt)
    solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                   "options: l-bfgs, gd.", typeConverter=TypeConverters.toString)
    initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
                           typeConverter=TypeConverters.toVector)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
                 solver="l-bfgs", initialWeights=None, probabilityCol="probability",
                 rawPredictionCol="rawPrediction"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
                 solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
                 rawPredictionCol="rawPrediction")
        """
        super(MultilayerPerceptronClassifier, self).__init__()
        # Py4J peer object that performs the actual training on the JVM.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
        # @keyword_only captured the caller's explicit kwargs in _input_kwargs.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
                  solver="l-bfgs", initialWeights=None, probabilityCol="probability",
                  rawPredictionCol="rawPrediction"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
                  solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction")
        Sets params for MultilayerPerceptronClassifier.
        """
        # Only explicitly supplied kwargs are applied.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Wrap the fitted Java model in its Python companion class.
        return MultilayerPerceptronClassificationModel(java_model)

    @since("1.6.0")
    def setLayers(self, value):
        """
        Sets the value of :py:attr:`layers`.
        """
        return self._set(layers=value)

    @since("1.6.0")
    def getLayers(self):
        """
        Gets the value of layers or its default value.
        """
        return self.getOrDefault(self.layers)

    @since("1.6.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        return self._set(blockSize=value)

    @since("1.6.0")
    def getBlockSize(self):
        """
        Gets the value of blockSize or its default value.
        """
        return self.getOrDefault(self.blockSize)

    @since("2.0.0")
    def setStepSize(self, value):
        """
        Sets the value of :py:attr:`stepSize`.
        """
        return self._set(stepSize=value)

    @since("2.0.0")
    def getStepSize(self):
        """
        Gets the value of stepSize or its default value.
        """
        return self.getOrDefault(self.stepSize)

    @since("2.0.0")
    def setInitialWeights(self, value):
        """
        Sets the value of :py:attr:`initialWeights`.
        """
        return self._set(initialWeights=value)

    @since("2.0.0")
    def getInitialWeights(self):
        """
        Gets the value of initialWeights or its default value.
        """
        return self.getOrDefault(self.initialWeights)
|
|
|
|
|
2015-09-11 11:52:28 -04:00
|
|
|
|
2017-08-23 00:16:34 -04:00
|
|
|
class MultilayerPerceptronClassificationModel(JavaModel, JavaClassificationModel, JavaMLWritable,
                                              JavaMLReadable):
    """
    Model fitted by MultilayerPerceptronClassifier.

    .. versionadded:: 1.6.0
    """

    @property
    @since("1.6.0")
    def layers(self):
        """
        array of layer sizes including input and output layers.
        """
        # Calls the JVM-side "javaLayers" accessor — presumably a Python-friendly
        # representation of the layer sizes; confirm against the Scala model.
        return self._call_java("javaLayers")

    @property
    @since("2.0.0")
    def weights(self):
        """
        the weights of layers.
        """
        # Flattened weight vector covering all layers.
        return self._call_java("weights")
|
|
|
|
|
|
|
|
|
2019-03-02 10:09:28 -05:00
|
|
|
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol,
                      HasRawPredictionCol):
    """
    Parameters for OneVsRest and OneVsRestModel.
    """

    # The base binary classifier trained once per class. No typeConverter: the value is
    # an Estimator instance, not a primitive.
    classifier = Param(Params._dummy(), "classifier", "base binary classifier")

    @since("2.0.0")
    def setClassifier(self, value):
        """
        Sets the value of :py:attr:`classifier`.

        .. note:: Only LogisticRegression and NaiveBayes are supported now.
        """
        return self._set(classifier=value)

    @since("2.0.0")
    def getClassifier(self):
        """
        Gets the value of classifier or its default value.
        """
        return self.getOrDefault(self.classifier)
|
|
|
|
|
|
|
|
|
2016-04-15 15:58:38 -04:00
|
|
|
@inherit_doc
|
2017-09-12 13:02:27 -04:00
|
|
|
class OneVsRest(Estimator, OneVsRestParams, HasParallelism, JavaMLReadable, JavaMLWritable):
|
2016-04-15 15:58:38 -04:00
|
|
|
"""
|
|
|
|
Reduction of Multiclass Classification to Binary Classification.
|
|
|
|
Performs reduction using one against all strategy.
|
|
|
|
For a multiclass classification with k classes, train k models (one per class).
|
|
|
|
Each example is scored against all k models and the model with highest score
|
|
|
|
is picked to label the example.
|
|
|
|
|
|
|
|
>>> from pyspark.sql import Row
|
2016-05-17 15:51:07 -04:00
|
|
|
>>> from pyspark.ml.linalg import Vectors
|
2017-04-26 09:34:18 -04:00
|
|
|
>>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
|
|
|
|
>>> df = spark.read.format("libsvm").load(data_path)
|
|
|
|
>>> lr = LogisticRegression(regParam=0.01)
|
2016-04-15 15:58:38 -04:00
|
|
|
>>> ovr = OneVsRest(classifier=lr)
|
2019-03-02 10:09:28 -05:00
|
|
|
>>> ovr.getRawPredictionCol()
|
|
|
|
'rawPrediction'
|
2016-04-15 15:58:38 -04:00
|
|
|
>>> model = ovr.fit(df)
|
2017-04-26 09:34:18 -04:00
|
|
|
>>> model.models[0].coefficients
|
|
|
|
DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
|
|
|
|
>>> model.models[1].coefficients
|
|
|
|
DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
|
|
|
|
>>> model.models[2].coefficients
|
|
|
|
DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
|
2016-04-15 15:58:38 -04:00
|
|
|
>>> [x.intercept for x in model.models]
|
2017-04-26 09:34:18 -04:00
|
|
|
[-2.7..., -2.5..., -1.3...]
|
|
|
|
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
|
2016-04-15 15:58:38 -04:00
|
|
|
>>> model.transform(test0).head().prediction
|
|
|
|
0.0
|
2017-04-26 09:34:18 -04:00
|
|
|
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
|
|
|
|
>>> model.transform(test1).head().prediction
|
2016-04-15 15:58:38 -04:00
|
|
|
2.0
|
2017-04-26 09:34:18 -04:00
|
|
|
>>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
|
|
|
|
>>> model.transform(test2).head().prediction
|
|
|
|
0.0
|
2017-01-31 18:42:36 -05:00
|
|
|
>>> model_path = temp_path + "/ovr_model"
|
|
|
|
>>> model.save(model_path)
|
|
|
|
>>> model2 = OneVsRestModel.load(model_path)
|
|
|
|
>>> model2.transform(test0).head().prediction
|
2017-04-26 09:34:18 -04:00
|
|
|
0.0
|
2019-03-02 10:09:28 -05:00
|
|
|
>>> model.transform(test2).columns
|
|
|
|
['features', 'rawPrediction', 'prediction']
|
2016-04-15 15:58:38 -04:00
|
|
|
|
|
|
|
.. versionadded:: 2.0.0
|
|
|
|
"""
|
|
|
|
|
|
|
|
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
        """
        super(OneVsRest, self).__init__()
        self._setDefault(parallelism=1)
        # Pure-Python estimator: no Java peer is created here; kwargs captured by
        # @keyword_only are applied directly via _set.
        kwargs = self._input_kwargs
        self._set(**kwargs)
|
|
|
|
|
|
|
|
    @keyword_only
    @since("2.0.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1)
        Sets params for OneVsRest.
        """
        # Only explicitly supplied kwargs are applied.
        kwargs = self._input_kwargs
        return self._set(**kwargs)
|
|
|
|
|
|
|
|
    def _fit(self, dataset):
        """
        Train one binary model per class and bundle them into a OneVsRestModel.

        :param dataset: :py:class:`pyspark.sql.DataFrame` with label and feature columns.
        :return: fitted :py:class:`OneVsRestModel` with params copied from this estimator.
        """
        labelCol = self.getLabelCol()
        featuresCol = self.getFeaturesCol()
        predictionCol = self.getPredictionCol()
        classifier = self.getClassifier()

        # Infer the class count from the max label value; assumes labels are
        # 0-based consecutive doubles.
        numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

        # Only honor weightCol when the base classifier actually supports it;
        # otherwise warn and ignore rather than fail.
        weightCol = None
        if (self.isDefined(self.weightCol) and self.getWeightCol()):
            if isinstance(classifier, HasWeightCol):
                weightCol = self.getWeightCol()
            else:
                warnings.warn("weightCol is ignored, "
                              "as it is not supported by {} now.".format(classifier))

        # Project down to just the columns training needs before caching.
        if weightCol:
            multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
        else:
            multiclassLabeled = dataset.select(labelCol, featuresCol)

        # persist if underlying dataset is not persistent.
        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
        if handlePersistence:
            multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)

        def trainSingleClass(index):
            # Relabel to a one-vs-rest binary problem: 1.0 for `index`, else 0.0.
            binaryLabelCol = "mc2b$" + str(index)
            trainingDataset = multiclassLabeled.withColumn(
                binaryLabelCol,
                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
            paramMap = dict([(classifier.labelCol, binaryLabelCol),
                             (classifier.featuresCol, featuresCol),
                             (classifier.predictionCol, predictionCol)])
            if weightCol:
                paramMap[classifier.weightCol] = weightCol
            return classifier.fit(trainingDataset, paramMap)

        # Train the per-class models concurrently, bounded by the parallelism param.
        pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

        models = pool.map(trainSingleClass, range(numClasses))

        if handlePersistence:
            multiclassLabeled.unpersist()

        return self._copyValues(OneVsRestModel(models=models))
|
|
|
|
|
|
|
|
@since("2.0.0")
|
|
|
|
def copy(self, extra=None):
|
|
|
|
"""
|
|
|
|
Creates a copy of this instance with a randomly generated uid
|
|
|
|
and some extra params. This creates a deep copy of the embedded paramMap,
|
|
|
|
and copies the embedded and extra parameters over.
|
|
|
|
|
|
|
|
:param extra: Extra parameters to copy to the new instance
|
|
|
|
:return: Copy of this instance
|
|
|
|
"""
|
|
|
|
if extra is None:
|
|
|
|
extra = dict()
|
|
|
|
newOvr = Params.copy(self, extra)
|
|
|
|
if self.isSet(self.classifier):
|
|
|
|
newOvr.setClassifier(self.getClassifier().copy(extra))
|
|
|
|
return newOvr
|
|
|
|
|
2016-04-18 14:52:29 -04:00
|
|
|
@classmethod
|
|
|
|
def _from_java(cls, java_stage):
|
|
|
|
"""
|
|
|
|
Given a Java OneVsRest, create and return a Python wrapper of it.
|
|
|
|
Used for ML persistence.
|
|
|
|
"""
|
|
|
|
featuresCol = java_stage.getFeaturesCol()
|
|
|
|
labelCol = java_stage.getLabelCol()
|
|
|
|
predictionCol = java_stage.getPredictionCol()
|
2019-03-02 10:09:28 -05:00
|
|
|
rawPredictionCol = java_stage.getRawPredictionCol()
|
2016-04-18 14:52:29 -04:00
|
|
|
classifier = JavaParams._from_java(java_stage.getClassifier())
|
2017-09-12 13:02:27 -04:00
|
|
|
parallelism = java_stage.getParallelism()
|
2016-04-18 14:52:29 -04:00
|
|
|
py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol,
|
2019-03-02 10:09:28 -05:00
|
|
|
rawPredictionCol=rawPredictionCol, classifier=classifier,
|
|
|
|
parallelism=parallelism)
|
2016-04-18 14:52:29 -04:00
|
|
|
py_stage._resetUid(java_stage.uid())
|
|
|
|
return py_stage
|
|
|
|
|
|
|
|
def _to_java(self):
|
|
|
|
"""
|
|
|
|
Transfer this instance to a Java OneVsRest. Used for ML persistence.
|
|
|
|
|
|
|
|
:return: Java object equivalent to this instance.
|
|
|
|
"""
|
|
|
|
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
|
|
|
|
self.uid)
|
|
|
|
_java_obj.setClassifier(self.getClassifier()._to_java())
|
2017-09-12 13:02:27 -04:00
|
|
|
_java_obj.setParallelism(self.getParallelism())
|
2016-04-18 14:52:29 -04:00
|
|
|
_java_obj.setFeaturesCol(self.getFeaturesCol())
|
|
|
|
_java_obj.setLabelCol(self.getLabelCol())
|
|
|
|
_java_obj.setPredictionCol(self.getPredictionCol())
|
2019-03-02 10:09:28 -05:00
|
|
|
_java_obj.setRawPredictionCol(self.getRawPredictionCol())
|
2016-04-18 14:52:29 -04:00
|
|
|
return _java_obj
|
2016-04-15 15:58:38 -04:00
|
|
|
|
2017-07-17 13:07:32 -04:00
|
|
|
def _make_java_param_pair(self, param, value):
|
|
|
|
"""
|
|
|
|
Makes a Java param pair.
|
|
|
|
"""
|
|
|
|
sc = SparkContext._active_spark_context
|
|
|
|
param = self._resolveParam(param)
|
|
|
|
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
|
|
|
|
self.uid)
|
|
|
|
java_param = _java_obj.getParam(param.name)
|
|
|
|
if isinstance(value, JavaParams):
|
|
|
|
# used in the case of an estimator having another estimator as a parameter
|
|
|
|
# the reason why this is not in _py2java in common.py is that importing
|
|
|
|
# Estimator and Model in common.py results in a circular import with inherit_doc
|
|
|
|
java_value = value._to_java()
|
|
|
|
else:
|
|
|
|
java_value = _py2java(sc, value)
|
|
|
|
return java_param.w(java_value)
|
|
|
|
|
|
|
|
def _transfer_param_map_to_java(self, pyParamMap):
|
|
|
|
"""
|
|
|
|
Transforms a Python ParamMap into a Java ParamMap.
|
|
|
|
"""
|
|
|
|
paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
|
|
|
|
for param in self.params:
|
|
|
|
if param in pyParamMap:
|
|
|
|
pair = self._make_java_param_pair(param, pyParamMap[param])
|
|
|
|
paramMap.put([pair])
|
|
|
|
return paramMap
|
2016-04-18 14:52:29 -04:00
|
|
|
|
2017-07-17 13:07:32 -04:00
|
|
|
def _transfer_param_map_from_java(self, javaParamMap):
|
|
|
|
"""
|
|
|
|
Transforms a Java ParamMap into a Python ParamMap.
|
|
|
|
"""
|
|
|
|
sc = SparkContext._active_spark_context
|
|
|
|
paramMap = dict()
|
|
|
|
for pair in javaParamMap.toList():
|
|
|
|
param = pair.param()
|
|
|
|
if self.hasParam(str(param.name())):
|
|
|
|
if param.name() == "classifier":
|
|
|
|
paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
|
|
|
|
else:
|
|
|
|
paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
|
|
|
|
return paramMap
|
|
|
|
|
|
|
|
|
|
|
|
class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable):
    """
    Model fitted by OneVsRest.
    This stores the models resulting from training k binary classifiers: one for each class.
    Each example is scored against all k models, and the model with the highest score
    is picked to label the example.

    .. versionadded:: 2.0.0
    """

    def __init__(self, models):
        """
        :param models: the fitted binary classification models, one per class,
                       in class-index order.
        """
        super(OneVsRestModel, self).__init__()
        self.models = models
        # Mirror the Python-side models into a Java OneVsRestModel so that this
        # model can participate in JVM-backed ML persistence.
        java_models = [model._to_java() for model in self.models]
        sc = SparkContext._active_spark_context
        java_models_array = JavaWrapper._new_java_array(java_models,
                                                        sc._gateway.jvm.org.apache.spark.ml
                                                        .classification.ClassificationModel)
        # TODO: need to set metadata
        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
        self._java_obj = \
            JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
                                     self.uid, metadata.empty(), java_models_array)

    def _transform(self, dataset):
        """
        Score each row against all k binary models and output the index of the
        highest-scoring model as the prediction.
        """
        # determine the input columns: these need to be passed through
        origCols = dataset.columns

        # add an accumulator column to store predictions of all the models;
        # uuid suffix avoids clashing with any user column name
        accColName = "mbc$acc" + str(uuid.uuid4())
        initUDF = udf(lambda _: [], ArrayType(DoubleType()))
        newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

        # persist if underlying dataset is not persistent.
        # StorageLevel(False, False, False, False) is StorageLevel.NONE,
        # i.e. the dataset is not cached yet.
        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
        if handlePersistence:
            newDataset.persist(StorageLevel.MEMORY_AND_DISK)

        # update the accumulator column with the result of prediction of models
        aggregatedDataset = newDataset
        for index, model in enumerate(self.models):
            rawPredictionCol = self.getRawPredictionCol()

            columns = origCols + [rawPredictionCol, accColName]

            # add temporary column to store intermediate scores and update
            tmpColName = "mbc$tmp" + str(uuid.uuid4())
            # append this binary model's positive-class raw score (element 1 of
            # its rawPrediction vector) to the accumulated score list
            updateUDF = udf(
                lambda predictions, prediction: predictions + [prediction.tolist()[1]],
                ArrayType(DoubleType()))
            transformedDataset = model.transform(aggregatedDataset).select(*columns)
            updatedDataset = transformedDataset.withColumn(
                tmpColName,
                updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
            newColumns = origCols + [tmpColName]

            # switch out the intermediate column with the accumulator column
            aggregatedDataset = updatedDataset\
                .select(*newColumns).withColumnRenamed(tmpColName, accColName)

        if handlePersistence:
            newDataset.unpersist()

        if self.getRawPredictionCol():
            def func(predictions):
                # pack the per-class scores into a dense vector
                predArray = []
                for x in predictions:
                    predArray.append(x)
                return Vectors.dense(predArray)

            # NOTE(review): udf is given no explicit returnType here, so it
            # defaults to StringType and the rawPrediction column is emitted as
            # the string form of the dense vector -- confirm this is intended.
            rawPredictionUDF = udf(func)
            aggregatedDataset = aggregatedDataset.withColumn(
                self.getRawPredictionCol(), rawPredictionUDF(aggregatedDataset[accColName]))

        if self.getPredictionCol():
            # output the index of the classifier with highest confidence as prediction
            labelUDF = udf(lambda predictions: float(max(enumerate(predictions),
                                                         key=operator.itemgetter(1))[0]), DoubleType())
            aggregatedDataset = aggregatedDataset.withColumn(
                self.getPredictionCol(), labelUDF(aggregatedDataset[accColName]))
        # drop the internal accumulator column before handing the result back
        return aggregatedDataset.drop(accColName)

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This creates a deep copy of the embedded paramMap,
        and copies the embedded and extra parameters over.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newModel = Params.copy(self, extra)
        # deep-copy the underlying binary models as well
        newModel.models = [model.copy(extra) for model in self.models]
        return newModel

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java OneVsRestModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        featuresCol = java_stage.getFeaturesCol()
        labelCol = java_stage.getLabelCol()
        predictionCol = java_stage.getPredictionCol()
        classifier = JavaParams._from_java(java_stage.getClassifier())
        # wrap each Java binary model back into its Python counterpart
        models = [JavaParams._from_java(model) for model in java_stage.models()]
        py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
            .setFeaturesCol(featuresCol).setClassifier(classifier)
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java OneVsRestModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        sc = SparkContext._active_spark_context
        java_models = [model._to_java() for model in self.models]
        java_models_array = JavaWrapper._new_java_array(
            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
        # empty metadata placeholder, matching the constructor used in __init__
        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
                                             self.uid, metadata.empty(), java_models_array)
        _java_obj.set("classifier", self.getClassifier()._to_java())
        _java_obj.set("featuresCol", self.getFeaturesCol())
        _java_obj.set("labelCol", self.getLabelCol())
        _java_obj.set("predictionCol", self.getPredictionCol())
        return _java_obj
|
|
|
|
|
2016-04-15 15:58:38 -04:00
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
if __name__ == "__main__":
|
|
|
|
import doctest
|
2016-03-16 17:21:42 -04:00
|
|
|
import pyspark.ml.classification
|
2016-05-23 21:14:48 -04:00
|
|
|
from pyspark.sql import SparkSession
|
2016-03-16 17:21:42 -04:00
|
|
|
globs = pyspark.ml.classification.__dict__.copy()
|
2015-01-28 20:14:23 -05:00
|
|
|
# The small batch size here ensures that we see multiple batches,
|
|
|
|
# even in these small test examples:
|
2016-05-23 21:14:48 -04:00
|
|
|
spark = SparkSession.builder\
|
|
|
|
.master("local[2]")\
|
|
|
|
.appName("ml.classification tests")\
|
|
|
|
.getOrCreate()
|
|
|
|
sc = spark.sparkContext
|
2015-01-28 20:14:23 -05:00
|
|
|
globs['sc'] = sc
|
2016-05-23 21:14:48 -04:00
|
|
|
globs['spark'] = spark
|
2016-03-16 17:21:42 -04:00
|
|
|
import tempfile
|
|
|
|
temp_path = tempfile.mkdtemp()
|
|
|
|
globs['temp_path'] = temp_path
|
|
|
|
try:
|
|
|
|
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
|
2016-05-23 21:14:48 -04:00
|
|
|
spark.stop()
|
2016-03-16 17:21:42 -04:00
|
|
|
finally:
|
|
|
|
from shutil import rmtree
|
|
|
|
try:
|
|
|
|
rmtree(temp_path)
|
|
|
|
except OSError:
|
|
|
|
pass
|
2015-01-28 20:14:23 -05:00
|
|
|
if failure_count:
|
2018-03-08 06:38:34 -05:00
|
|
|
sys.exit(-1)
|