2015-01-28 20:14:23 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2015-11-02 19:12:04 -05:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
from pyspark import since
|
2016-03-16 17:21:42 -04:00
|
|
|
from pyspark.ml.util import *
|
2015-01-28 20:14:23 -05:00
|
|
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel
|
2016-03-23 14:20:44 -04:00
|
|
|
from pyspark.ml.param import TypeConverters
|
2015-05-13 18:13:09 -04:00
|
|
|
from pyspark.ml.param.shared import *
|
2015-07-07 11:58:08 -04:00
|
|
|
from pyspark.ml.regression import (
|
2015-10-27 16:55:03 -04:00
|
|
|
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
|
2015-02-20 05:31:32 -05:00
|
|
|
from pyspark.mllib.common import inherit_doc
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
|
2016-03-02 00:26:47 -05:00
|
|
|
__all__ = ['LogisticRegression', 'LogisticRegressionModel',
|
|
|
|
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
|
|
|
|
'GBTClassifier', 'GBTClassificationModel',
|
|
|
|
'RandomForestClassifier', 'RandomForestClassificationModel',
|
|
|
|
'NaiveBayes', 'NaiveBayesModel',
|
|
|
|
'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel']
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
|
|
|
|
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
2015-09-11 11:50:35 -04:00
|
|
|
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
|
2015-11-18 16:32:06 -05:00
|
|
|
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
|
2016-03-22 15:11:23 -04:00
|
|
|
HasWeightCol, JavaMLWritable, JavaMLReadable):
|
2015-01-28 20:14:23 -05:00
|
|
|
"""
|
|
|
|
Logistic regression.
|
2015-08-12 16:24:18 -04:00
|
|
|
Currently, this class only supports binary classification.
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
>>> from pyspark.sql import Row
|
|
|
|
>>> from pyspark.mllib.linalg import Vectors
|
2015-02-15 23:29:26 -05:00
|
|
|
>>> df = sc.parallelize([
|
2015-11-18 16:32:06 -05:00
|
|
|
... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
|
|
|
|
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
|
|
|
|
>>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
|
2015-02-15 23:29:26 -05:00
|
|
|
>>> model = lr.fit(df)
|
2015-12-03 14:37:34 -05:00
|
|
|
>>> model.coefficients
|
2015-05-14 21:13:58 -04:00
|
|
|
DenseVector([5.5...])
|
|
|
|
>>> model.intercept
|
|
|
|
-2.68...
|
2015-08-03 01:19:27 -04:00
|
|
|
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
|
|
|
|
>>> result = model.transform(test0).head()
|
|
|
|
>>> result.prediction
|
|
|
|
0.0
|
|
|
|
>>> result.probability
|
|
|
|
DenseVector([0.99..., 0.00...])
|
|
|
|
>>> result.rawPrediction
|
|
|
|
DenseVector([8.22..., -8.22...])
|
2015-02-15 23:29:26 -05:00
|
|
|
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
|
2015-04-16 19:20:57 -04:00
|
|
|
>>> model.transform(test1).head().prediction
|
2015-01-28 20:14:23 -05:00
|
|
|
1.0
|
2015-02-15 23:29:26 -05:00
|
|
|
>>> lr.setParams("vector")
|
|
|
|
Traceback (most recent call last):
|
|
|
|
...
|
|
|
|
TypeError: Method setParams forces keyword arguments.
|
2016-03-16 17:21:42 -04:00
|
|
|
>>> lr_path = temp_path + "/lr"
|
|
|
|
>>> lr.save(lr_path)
|
|
|
|
>>> lr2 = LogisticRegression.load(lr_path)
|
|
|
|
>>> lr2.getMaxIter()
|
|
|
|
5
|
|
|
|
>>> model_path = temp_path + "/lr_model"
|
|
|
|
>>> model.save(model_path)
|
|
|
|
>>> model2 = LogisticRegressionModel.load(model_path)
|
|
|
|
>>> model.coefficients[0] == model2.coefficients[0]
|
|
|
|
True
|
|
|
|
>>> model.intercept == model2.intercept
|
|
|
|
True
|
2015-11-09 16:16:04 -05:00
|
|
|
|
|
|
|
.. versionadded:: 1.3.0
|
2015-01-28 20:14:23 -05:00
|
|
|
"""
|
2015-05-18 15:02:18 -04:00
|
|
|
|
2015-08-12 17:27:13 -04:00
|
|
|
    # Binary-classification cutoff on the positive-class probability.
    # Kept alongside the more general ``thresholds`` Param for backwards
    # compatibility; when both are set they must agree (enforced by
    # _checkThresholdConsistency, called from __init__/setParams).
    threshold = Param(Params._dummy(), "threshold",
                      "Threshold in binary classification prediction, in range [0, 1]." +
                      " If threshold and thresholds are both set, they must match.",
                      typeConverter=TypeConverters.toFloat)
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2015-02-15 23:29:26 -05:00
|
|
|
    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 threshold=0.5, thresholds=None, probabilityCol="probability",
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 threshold=0.5, thresholds=None, probabilityCol="probability", \
                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        super(LogisticRegression, self).__init__()
        # Wrap the JVM-side estimator; fitting is delegated to it via Py4J.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.LogisticRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5)
        # @keyword_only stashes the caller's keyword arguments on the
        # function object as _input_kwargs; forward them all to setParams.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)
        # Fail fast if both threshold and thresholds were supplied but disagree.
        self._checkThresholdConsistency()
|
2015-02-15 23:29:26 -05:00
|
|
|
|
|
|
|
    @keyword_only
    @since("1.3.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  threshold=0.5, thresholds=None, probabilityCol="probability",
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  threshold=0.5, thresholds=None, probabilityCol="probability", \
                  rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
        Sets params for logistic regression.
        If the threshold and thresholds Params are both set, they must be equivalent.
        """
        # Only the keywords the caller actually passed (captured by
        # @keyword_only) are applied, so unset Params keep their defaults.
        kwargs = self.setParams._input_kwargs
        self._set(**kwargs)
        self._checkThresholdConsistency()
        return self
|
2015-02-15 23:29:26 -05:00
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
    def _create_model(self, java_model):
        """Wrap the fitted JVM model handle in its Python model class."""
        return LogisticRegressionModel(java_model)
|
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.4.0")
|
2015-05-13 18:13:09 -04:00
|
|
|
def setThreshold(self, value):
|
|
|
|
"""
|
2015-08-12 17:27:13 -04:00
|
|
|
Sets the value of :py:attr:`threshold`.
|
|
|
|
Clears value of :py:attr:`thresholds` if it has been set.
|
|
|
|
"""
|
|
|
|
self._paramMap[self.threshold] = value
|
|
|
|
if self.isSet(self.thresholds):
|
|
|
|
del self._paramMap[self.thresholds]
|
|
|
|
return self
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.4.0")
|
2015-08-12 17:27:13 -04:00
|
|
|
def getThreshold(self):
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
"""
|
2015-08-12 17:27:13 -04:00
|
|
|
Gets the value of threshold or its default value.
|
|
|
|
"""
|
|
|
|
self._checkThresholdConsistency()
|
|
|
|
if self.isSet(self.thresholds):
|
|
|
|
ts = self.getOrDefault(self.thresholds)
|
|
|
|
if len(ts) != 2:
|
|
|
|
raise ValueError("Logistic Regression getThreshold only applies to" +
|
|
|
|
" binary classification, but thresholds has length != 2." +
|
|
|
|
" thresholds: " + ",".join(ts))
|
|
|
|
return 1.0/(1.0 + ts[0]/ts[1])
|
|
|
|
else:
|
|
|
|
return self.getOrDefault(self.threshold)
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.5.0")
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
def setThresholds(self, value):
|
|
|
|
"""
|
|
|
|
Sets the value of :py:attr:`thresholds`.
|
2015-08-12 17:27:13 -04:00
|
|
|
Clears value of :py:attr:`threshold` if it has been set.
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
self._paramMap[self.thresholds] = value
|
2015-08-12 17:27:13 -04:00
|
|
|
if self.isSet(self.threshold):
|
|
|
|
del self._paramMap[self.threshold]
|
2015-05-13 18:13:09 -04:00
|
|
|
return self
|
|
|
|
|
2015-11-09 16:16:04 -05:00
|
|
|
@since("1.5.0")
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
def getThresholds(self):
    """
    If :py:attr:`thresholds` is set, return its value.
    Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary
    classification: (1-threshold, threshold).
    If neither are set, throw an error.
    """
    # Fail fast if threshold and thresholds were both set but disagree.
    self._checkThresholdConsistency()
    # Prefer an explicitly-set multiclass `thresholds`; only fall back to the
    # binary `threshold` param when `thresholds` is absent. When neither is
    # set, getOrDefault raises, satisfying the documented contract.
    if self.isSet(self.thresholds) or not self.isSet(self.threshold):
        return self.getOrDefault(self.thresholds)
    cutoff = self.getOrDefault(self.threshold)
    return [1.0 - cutoff, cutoff]
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
|
2015-08-12 17:27:13 -04:00
|
|
|
def _checkThresholdConsistency(self):
|
|
|
|
if self.isSet(self.threshold) and self.isSet(self.thresholds):
|
|
|
|
ts = self.getParam(self.thresholds)
|
|
|
|
if len(ts) != 2:
|
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification.
Note that the primary author of this PR is holdenk
Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:
3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
2015-08-04 13:12:22 -04:00
|
|
|
raise ValueError("Logistic Regression getThreshold only applies to" +
|
|
|
|
" binary classification, but thresholds has length != 2." +
|
2015-08-12 17:27:13 -04:00
|
|
|
" thresholds: " + ",".join(ts))
|
|
|
|
t = 1.0/(1.0 + ts[0]/ts[1])
|
|
|
|
t2 = self.getParam(self.threshold)
|
|
|
|
if abs(t2 - t) >= 1E-5:
|
|
|
|
raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
|
|
|
|
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
|
2015-05-13 18:13:09 -04:00
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2016-03-22 15:11:23 -04:00
|
|
|
class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by LogisticRegression.

    .. versionadded:: 1.3.0
    """

    @property
    @since("1.4.0")
    def weights(self):
        """
        Model weights.

        .. note:: Deprecated; use :py:attr:`coefficients` instead.
        """
        # Use the DeprecationWarning category so warning filters classify
        # this correctly (a bare warnings.warn defaults to UserWarning).
        warnings.warn("weights is deprecated. Use coefficients instead.",
                      DeprecationWarning)
        return self._call_java("weights")

    @property
    @since("1.6.0")
    def coefficients(self):
        """
        Model coefficients.
        """
        return self._call_java("coefficients")

    @property
    @since("1.4.0")
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")
|
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
class TreeClassifierParams(object):
    """
    Private class to track supported impurity measures.

    .. versionadded:: 1.4.0
    """
    # Impurity criteria accepted by the `impurity` param below.
    supportedImpurities = ["entropy", "gini"]

    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        # Route through _set, as the setParams methods in this file do,
        # instead of mutating _paramMap directly.
        self._set(impurity=value)
        return self

    @since("1.6.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)
|
|
|
|
|
|
|
|
|
|
|
|
class GBTParams(TreeEnsembleParams):
    """
    Private class to track supported GBT params.

    .. versionadded:: 1.4.0
    """
    # The only GBT loss type supported at present.
    supportedLossTypes = ["logistic"]
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                             HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
                             TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
                             JavaMLReadable):
    """
    `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
    learning algorithm for classification.
    It supports both binary and multiclass labels, as well as both continuous and categorical
    features.

    >>> from pyspark.mllib.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
    >>> model = dt.fit(td)
    >>> model.numNodes
    3
    >>> model.depth
    1
    >>> model.featureImportances
    SparseVector(1, {0: 1.0})
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> result = model.transform(test0).head()
    >>> result.prediction
    0.0
    >>> result.probability
    DenseVector([1.0, 0.0])
    >>> result.rawPrediction
    DenseVector([1.0, 0.0])
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0

    >>> dtc_path = temp_path + "/dtc"
    >>> dt.save(dtc_path)
    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
    >>> dt2.getMaxDepth()
    2
    >>> model_path = temp_path + "/dtc_model"
    >>> model.save(model_path)
    >>> model2 = DecisionTreeClassificationModel.load(model_path)
    >>> model.featureImportances == model2.featureImportances
    True

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                 seed=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                 seed=None)
        """
        super(DecisionTreeClassifier, self).__init__()
        # Peer JVM estimator that backs this Python wrapper.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="gini")
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="gini", seed=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  seed=None)
        Sets params for the DecisionTreeClassifier.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        # Wrap the fitted JVM model in its Python model class.
        return DecisionTreeClassificationModel(java_model)
|
|
|
|
|
|
|
|
|
2015-07-07 11:58:08 -04:00
|
|
|
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by DecisionTreeClassifier.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
         - importance(feature j) = sum (over nodes which split on feature j) of the gain,
           where gain is scaled by the number of instances passing through node
         - Normalize importances for tree to sum to 1.

        Note: Feature importance for single decision trees can have high variance due to
        correlated predictor variables. Consider using a :class:`RandomForestClassifier`
        to determine feature importance instead.
        """
        # Computed on the JVM side by the fitted model.
        return self._call_java("featureImportances")
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
|
|
|
@inherit_doc
|
|
|
|
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
|
2015-08-04 17:54:26 -04:00
|
|
|
HasRawPredictionCol, HasProbabilityCol,
|
2015-10-27 16:55:03 -04:00
|
|
|
RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
|
|
|
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
|
|
|
|
learning algorithm for classification.
|
|
|
|
It supports both binary and multiclass labels, as well as both continuous and categorical
|
|
|
|
features.
|
|
|
|
|
2015-08-04 17:54:26 -04:00
|
|
|
>>> import numpy
|
2015-07-07 11:58:08 -04:00
|
|
|
>>> from numpy import allclose
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> from pyspark.mllib.linalg import Vectors
|
|
|
|
>>> from pyspark.ml.feature import StringIndexer
|
|
|
|
>>> df = sqlContext.createDataFrame([
|
|
|
|
... (1.0, Vectors.dense(1.0)),
|
|
|
|
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
|
|
|
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
|
|
|
|
>>> si_model = stringIndexer.fit(df)
|
|
|
|
>>> td = si_model.transform(df)
|
2015-07-29 21:18:29 -04:00
|
|
|
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> model = rf.fit(td)
|
2016-03-11 02:54:23 -05:00
|
|
|
>>> model.featureImportances
|
|
|
|
SparseVector(1, {0: 1.0})
|
2015-07-29 21:18:29 -04:00
|
|
|
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
|
2015-07-07 11:58:08 -04:00
|
|
|
True
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
2015-08-04 17:54:26 -04:00
|
|
|
>>> result = model.transform(test0).head()
|
|
|
|
>>> result.prediction
|
2015-05-13 18:13:09 -04:00
|
|
|
0.0
|
2015-08-04 17:54:26 -04:00
|
|
|
>>> numpy.argmax(result.probability)
|
|
|
|
0
|
|
|
|
>>> numpy.argmax(result.rawPrediction)
|
|
|
|
0
|
2015-05-13 18:13:09 -04:00
|
|
|
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
|
|
|
>>> model.transform(test1).head().prediction
|
|
|
|
1.0
|
2015-11-09 16:16:04 -05:00
|
|
|
|
|
|
|
.. versionadded:: 1.4.0
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
|
|
|
|
|
|
|
@keyword_only
|
|
|
|
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
2015-08-04 17:54:26 -04:00
|
|
|
probabilityCol="probability", rawPredictionCol="rawPrediction",
|
2015-05-13 18:13:09 -04:00
|
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
|
|
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
|
2015-05-20 18:16:12 -04:00
|
|
|
numTrees=20, featureSubsetStrategy="auto", seed=None):
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
2015-05-14 21:16:22 -04:00
|
|
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
2015-08-04 17:54:26 -04:00
|
|
|
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
2015-05-14 21:16:22 -04:00
|
|
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
|
|
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
2015-05-20 18:16:12 -04:00
|
|
|
numTrees=20, featureSubsetStrategy="auto", seed=None)
|
2015-05-13 18:13:09 -04:00
|
|
|
"""
|
|
|
|
super(RandomForestClassifier, self).__init__()
|
2015-05-18 15:02:18 -04:00
|
|
|
self._java_obj = self._new_java_obj(
|
|
|
|
"org.apache.spark.ml.classification.RandomForestClassifier", self.uid)
|
2015-05-13 18:13:09 -04:00
|
|
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
2015-05-20 18:16:12 -04:00
|
|
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
|
2015-05-13 18:13:09 -04:00
|
|
|
impurity="gini", numTrees=20, featureSubsetStrategy="auto")
|
|
|
|
kwargs = self.__init__._input_kwargs
|
|
|
|
self.setParams(**kwargs)
|
|
|
|
|
|
|
|
@keyword_only
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
              probabilityCol="probability", rawPredictionCol="rawPrediction",
              maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
              maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
              impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
    """
    setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
              probabilityCol="probability", rawPredictionCol="rawPrediction", \
              maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
              maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
              impurity="gini", numTrees=20, featureSubsetStrategy="auto")
    Sets params for RandomForestClassifier.
    """
    # Fix: the docstring previously said "Sets params for linear classification."
    # — a copy-paste leftover from LogisticRegression.
    # @keyword_only stashes the explicitly-passed kwargs on the function object;
    # forward only those so unspecified params keep their current values.
    kwargs = self.setParams._input_kwargs
    return self._set(**kwargs)
|
|
|
|
|
|
|
|
def _create_model(self, java_model):
    """Wrap the fitted JVM model handle in its Python companion class."""
    fitted = RandomForestClassificationModel(java_model)
    return fitted
|
|
|
|
|
|
|
|
|
2015-07-07 11:58:08 -04:00
|
|
|
class RandomForestClassificationModel(TreeEnsembleModels):
    """
    Model produced by fitting a :py:class:`RandomForestClassifier`.

    .. versionadded:: 1.4.0
    """

    @property
    @since("2.0.0")
    def featureImportances(self):
        """
        Estimate of the importance of each feature.

        This generalizes the idea of "Gini" importance to other losses,
        following the explanation of Gini importance from "Random Forests" documentation
        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.

        This feature importance is calculated as follows:
         - Average over trees:
            - importance(feature j) = sum (over nodes which split on feature j) of the gain,
              where gain is scaled by the number of instances passing through node
            - Normalize importances for tree to sum to 1.
         - Normalize feature importance vector to sum to 1.
        """
        # The computation lives on the JVM side; this simply forwards the call.
        importances = self._call_java("featureImportances")
        return importances
|
|
|
|
|
2015-05-13 18:13:09 -04:00
|
|
|
|
|
|
|
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                    GBTParams, HasCheckpointInterval, HasStepSize, HasSeed):
    """
    `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)`
    learning algorithm for classification.
    It supports binary labels, as well as both continuous and categorical features.
    Note: Multiclass labels are not currently supported.

    >>> from numpy import allclose
    >>> from pyspark.mllib.linalg import Vectors
    >>> from pyspark.ml.feature import StringIndexer
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
    >>> si_model = stringIndexer.fit(df)
    >>> td = si_model.transform(df)
    >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
    >>> model = gbt.fit(td)
    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    True
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0

    .. versionadded:: 1.4.0
    """

    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
                     typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
                 maxIter=20, stepSize=0.1, seed=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
        """
        super(GBTClassifier, self).__init__()
        # Create the companion JVM estimator that this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.GBTClassifier", self.uid)
        # Register Python-side defaults matching the signature above.
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
        # @keyword_only stashes the explicitly-passed kwargs on the function object.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
        Sets params for Gradient Boosted Tree Classification.
        """
        # Only the kwargs the caller actually passed are forwarded.
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """Wrap the fitted JVM model handle in its Python companion class."""
        return GBTClassificationModel(java_model)

    @since("1.4.0")
    def setLossType(self, value):
        """
        Sets the value of :py:attr:`lossType`.
        """
        # Fix: go through _set() instead of writing self._paramMap directly, so the
        # Param's declared typeConverter (TypeConverters.toString) is applied.
        self._set(lossType=value)
        return self

    @since("1.4.0")
    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)
|
|
|
|
|
|
|
|
|
2015-07-07 11:58:08 -04:00
|
|
|
class GBTClassificationModel(TreeEnsembleModels):
    """
    Model produced by fitting a :py:class:`GBTClassifier`.

    .. versionadded:: 1.4.0
    """
|
|
|
|
|
|
|
|
|
2015-07-31 02:03:48 -04:00
|
|
|
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
                 HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
    """
    Naive Bayes Classifiers.
    It supports both Multinomial and Bernoulli NB. Multinomial NB
    (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`)
    can handle finitely supported discrete data. For example, by converting documents into
    TF-IDF vectors, it can be used for document classification. By making every vector a
    binary (0/1) data, it can also be used as Bernoulli NB
    (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`).
    The input feature values must be nonnegative.

    >>> from pyspark.sql import Row
    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
    ...     Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
    ...     Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
    >>> model = nb.fit(df)
    >>> model.pi
    DenseVector([-0.51..., -0.91...])
    >>> model.theta
    DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1)
    >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
    >>> result = model.transform(test0).head()
    >>> result.prediction
    1.0
    >>> result.probability
    DenseVector([0.42..., 0.57...])
    >>> result.rawPrediction
    DenseVector([-1.60..., -1.32...])
    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
    >>> model.transform(test1).head().prediction
    1.0
    >>> nb_path = temp_path + "/nb"
    >>> nb.save(nb_path)
    >>> nb2 = NaiveBayes.load(nb_path)
    >>> nb2.getSmoothing()
    1.0
    >>> model_path = temp_path + "/nb_model"
    >>> model.save(model_path)
    >>> model2 = NaiveBayesModel.load(model_path)
    >>> model.pi == model2.pi
    True
    >>> model.theta == model2.theta
    True

    .. versionadded:: 1.5.0
    """

    smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
                      "default is 1.0", typeConverter=TypeConverters.toFloat)
    modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
                      "(case-sensitive). Supported options: multinomial (default) and bernoulli.",
                      typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
                 modelType="multinomial"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
                 modelType="multinomial")
        """
        super(NaiveBayes, self).__init__()
        # Create the companion JVM estimator that this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.NaiveBayes", self.uid)
        # Register Python-side defaults matching the signature above.
        self._setDefault(smoothing=1.0, modelType="multinomial")
        # @keyword_only stashes the explicitly-passed kwargs on the function object.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.5.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
                  modelType="multinomial"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
                  modelType="multinomial")
        Sets params for Naive Bayes.
        """
        # Only the kwargs the caller actually passed are forwarded.
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """Wrap the fitted JVM model handle in its Python companion class."""
        return NaiveBayesModel(java_model)

    @since("1.5.0")
    def setSmoothing(self, value):
        """
        Sets the value of :py:attr:`smoothing`.
        """
        # Fix: go through _set() instead of writing self._paramMap directly, so the
        # Param's declared typeConverter (TypeConverters.toFloat) is applied.
        self._set(smoothing=value)
        return self

    @since("1.5.0")
    def getSmoothing(self):
        """
        Gets the value of smoothing or its default value.
        """
        return self.getOrDefault(self.smoothing)

    @since("1.5.0")
    def setModelType(self, value):
        """
        Sets the value of :py:attr:`modelType`.
        """
        # Fix: go through _set() so the declared typeConverter runs (see setSmoothing).
        self._set(modelType=value)
        return self

    @since("1.5.0")
    def getModelType(self):
        """
        Gets the value of modelType or its default value.
        """
        return self.getOrDefault(self.modelType)
|
|
|
|
|
|
|
|
|
2016-03-22 15:11:23 -04:00
|
|
|
class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
    """
    Model produced by fitting a :py:class:`NaiveBayes` estimator.

    .. versionadded:: 1.5.0
    """

    @property
    @since("1.5.0")
    def pi(self):
        """
        log of class priors.
        """
        # Forwarded from the underlying JVM model.
        priors = self._call_java("pi")
        return priors

    @property
    @since("1.5.0")
    def theta(self):
        """
        log of class conditional probabilities.
        """
        # Forwarded from the underlying JVM model.
        conditionals = self._call_java("theta")
        return conditionals
|
|
|
|
|
|
|
|
|
2015-09-11 11:52:28 -04:00
|
|
|
@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                                     HasMaxIter, HasTol, HasSeed):
    """
    Classifier trainer based on the Multilayer Perceptron.
    Each layer has sigmoid activation function, output layer has softmax.
    Number of inputs has to be equal to the size of feature vectors.
    Number of outputs has to be equal to the total number of labels.

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (0.0, Vectors.dense([0.0, 0.0])),
    ...     (1.0, Vectors.dense([0.0, 1.0])),
    ...     (1.0, Vectors.dense([1.0, 0.0])),
    ...     (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
    >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11)
    >>> model = mlp.fit(df)
    >>> model.layers
    [2, 5, 2]
    >>> model.weights.size
    27
    >>> testDF = sqlContext.createDataFrame([
    ...     (Vectors.dense([1.0, 0.0]),),
    ...     (Vectors.dense([0.0, 0.0]),)], ["features"])
    >>> model.transform(testDF).show()
    +---------+----------+
    | features|prediction|
    +---------+----------+
    |[1.0,0.0]|       1.0|
    |[0.0,0.0]|       0.0|
    +---------+----------+
    ...

    .. versionadded:: 1.6.0
    """

    layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
                   "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
                   "neurons and output layer of 10 neurons, default is [1, 1].",
                   typeConverter=TypeConverters.toListInt)
    blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
                      "matrices. Data is stacked within partitions. If block size is more than " +
                      "remaining data in a partition then it is adjusted to the size of this " +
                      "data. Recommended size is between 10 and 1000, default is 128.",
                      typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
        """
        super(MultilayerPerceptronClassifier, self).__init__()
        # Create the companion JVM estimator that this Python wrapper delegates to.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
        # Register Python-side defaults; layers=None in the signature is a
        # sentinel for the documented default [1, 1].
        self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128)
        # @keyword_only stashes the explicitly-passed kwargs on the function object.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("1.6.0")
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
        Sets params for MultilayerPerceptronClassifier.
        """
        kwargs = self.setParams._input_kwargs
        if layers is None:
            # None is the "use the default" sentinel: materialize [1, 1] explicitly.
            return self._set(**kwargs).setLayers([1, 1])
        else:
            return self._set(**kwargs)

    def _create_model(self, java_model):
        """Wrap the fitted JVM model handle in its Python companion class."""
        return MultilayerPerceptronClassificationModel(java_model)

    @since("1.6.0")
    def setLayers(self, value):
        """
        Sets the value of :py:attr:`layers`.
        """
        # Fix: go through _set() instead of writing self._paramMap directly, so the
        # Param's declared typeConverter (TypeConverters.toListInt) is applied.
        self._set(layers=value)
        return self

    @since("1.6.0")
    def getLayers(self):
        """
        Gets the value of layers or its default value.
        """
        return self.getOrDefault(self.layers)

    @since("1.6.0")
    def setBlockSize(self, value):
        """
        Sets the value of :py:attr:`blockSize`.
        """
        # Fix: go through _set() so the declared typeConverter runs (see setLayers).
        self._set(blockSize=value)
        return self

    @since("1.6.0")
    def getBlockSize(self):
        """
        Gets the value of blockSize or its default value.
        """
        return self.getOrDefault(self.blockSize)
|
|
|
|
|
|
|
|
|
|
|
|
class MultilayerPerceptronClassificationModel(JavaModel):
    """
    Model produced by fitting a :py:class:`MultilayerPerceptronClassifier`.

    .. versionadded:: 1.6.0
    """

    @property
    @since("1.6.0")
    def layers(self):
        """
        array of layer sizes including input and output layers.
        """
        # The JVM model exposes the Python-friendly view as "javaLayers".
        sizes = self._call_java("javaLayers")
        return sizes

    @property
    @since("1.6.0")
    def weights(self):
        """
        vector of initial weights for the model that consists of the weights of layers.
        """
        wts = self._call_java("weights")
        return wts
|
|
|
|
|
|
|
|
|
2015-01-28 20:14:23 -05:00
|
|
|
if __name__ == "__main__":
    # Run this module's doctests against a local SparkContext.
    import doctest
    import pyspark.ml.classification
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = pyspark.ml.classification.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.classification tests")
    sqlContext = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlContext'] = sqlContext
    import tempfile
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Fix: stop the SparkContext in the finally block (previously it was only
        # stopped on the success path, leaking the context if a doctest raised),
        # and always clean up the scratch directory.
        sc.stop()
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            pass
    if failure_count:
        exit(-1)
|