2015-01-28 20:14:23 -05:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
|
|
|
from abc import ABCMeta
|
|
|
|
|
|
|
|
from pyspark import SparkContext
|
|
|
|
from pyspark.sql import DataFrame
|
|
|
|
from pyspark.ml.param import Params
|
2015-05-05 14:45:37 -04:00
|
|
|
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
|
2015-02-20 05:31:32 -05:00
|
|
|
from pyspark.mllib.common import inherit_doc
|
2015-01-28 20:14:23 -05:00
|
|
|
|
|
|
|
|
|
|
|
def _jvm():
    """
    Fetches the JVM view stored on :py:class:`SparkContext`.

    Must only be called once a SparkContext has been initialized.

    :return: the value of ``SparkContext._jvm``
    :raises AttributeError: if ``SparkContext._jvm`` is unset (falsy)
    """
    jvm_view = SparkContext._jvm
    # Guard clause: a falsy view means no JVM has been attached yet.
    if not jvm_view:
        raise AttributeError(
            "Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return jvm_view
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaWrapper(Params):
    """
    Helper base class for building Python wrappers around pipeline
    components that are implemented in Java/Scala.
    """

    __metaclass__ = ABCMeta

    #: Fully-qualified class name of the wrapped Java component.
    _java_class = None

    def _java_obj(self):
        """
        Builds a new Java object of the class named by ``_java_class``.

        The dotted class name is resolved one segment at a time,
        starting from the JVM view, and the resolved class is then
        called with no arguments.
        """
        target = _jvm()
        for segment in self._java_class.split("."):
            target = getattr(target, segment)
        return target()

    def _transfer_params_to_java(self, params, java_obj):
        """
        Transfers the embedded params and additional params to the
        input Java object.

        Only params listed in ``self.params`` that are present in the
        extracted param map are set on the Java object.

        :param params: additional params (overwriting embedded values)
        :param java_obj: Java object to receive the params
        """
        merged = self.extractParamMap(params)
        for param in self.params:
            if param not in merged:
                # Nothing to transfer for this param.
                continue
            java_obj.set(param.name, merged[param])

    def _empty_java_param_map(self):
        """
        Creates a new, empty Java ParamMap and returns a reference
        to it.
        """
        param_package = _jvm().org.apache.spark.ml.param
        return param_package.ParamMap()

    def _create_java_param_map(self, params, java_obj):
        """
        Builds a Java ParamMap from the entries of ``params`` whose
        param belongs to this instance.

        :param params: mapping from param to value; entries whose
                       ``parent`` is not this instance are ignored
        :param java_obj: Java object used to look up the Java-side
                         param by name via ``getParam``
        :return: Java ParamMap reference holding the selected entries
        """
        java_map = self._empty_java_param_map()
        for param, value in params.items():
            if param.parent is not self:
                # Skip params owned by other pipeline components.
                continue
            java_map.put(java_obj.getParam(param.name), value)
        return java_map
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _create_model(self, java_model):
        """
        Wraps the given Java model reference in a
        :py:class:`JavaModel`.

        :param java_model: reference to the fitted Java model
        :return: :py:class:`JavaModel` holding ``java_model``
        """
        return JavaModel(java_model)

    def _fit_java(self, dataset, params={}):
        """
        Fits a Java model to the input dataset.

        The params are first set on a Java object obtained from
        ``_java_obj``; the Java ``fit`` is then invoked with an empty
        Java ParamMap.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        java_estimator = self._java_obj()
        self._transfer_params_to_java(params, java_estimator)
        # Params were already set on the Java object above, so the
        # ParamMap handed to fit() is left empty.
        return java_estimator.fit(dataset._jdf, self._empty_java_param_map())

    def fit(self, dataset, params={}):
        """
        Fits the Java model via ``_fit_java`` and returns the result
        of passing it to ``_create_model``.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: the model produced by ``_create_model``
        """
        return self._create_model(self._fit_java(dataset, params))
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def transform(self, dataset, params={}):
        """
        Runs the Java ``transform`` on the input dataset.

        The embedded params are set directly on the Java object, while
        the entries of ``params`` that belong to this instance are
        passed to the Java call as a Java ParamMap.

        :param dataset: input dataset; its ``_jdf`` and ``sql_ctx``
                        attributes are used
        :param params: additional params passed through a Java ParamMap
        :return: :py:class:`DataFrame` wrapping the Java result
        """
        java_transformer = self._java_obj()
        # Only embedded values are set on the Java object here; the
        # extra params travel separately in the ParamMap built below.
        self._transfer_params_to_java({}, java_transformer)
        extra_params = self._create_java_param_map(params, java_transformer)
        java_result = java_transformer.transform(dataset._jdf, extra_params)
        return DataFrame(java_result, dataset.sql_ctx)
|
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaModel(JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model):
        """
        Stores the reference to the wrapped Java model.

        :param java_model: reference to the Java model to wrap
        """
        # Name this class in super() so the whole MRO is initialized.
        # Naming JavaTransformer here would start the lookup *after*
        # JavaTransformer and skip any __init__ it defines.
        super(JavaModel, self).__init__()
        self._java_model = java_model

    def _java_obj(self):
        """
        Returns the Java model reference given at construction time
        instead of creating a new Java object.
        """
        return self._java_model
|
2015-05-05 14:45:37 -04:00
|
|
|
|
|
|
|
|
|
|
|
@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
    """
    Base class for :py:class:`Evaluator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def evaluate(self, dataset, params={}):
        """
        Runs the Java ``evaluate`` on the input dataset.

        The params are first set on a Java object obtained from
        ``_java_obj``; the Java ``evaluate`` is then invoked with an
        empty Java ParamMap.

        :param dataset: input dataset; its ``_jdf`` attribute is
                        passed to the Java call
        :param params: additional params (overwriting embedded values)
        :return: whatever the Java ``evaluate`` call returns
        """
        java_evaluator = self._java_obj()
        self._transfer_params_to_java(params, java_evaluator)
        # Params were already set on the Java object above, so the
        # ParamMap handed to evaluate() is left empty.
        return java_evaluator.evaluate(dataset._jdf, self._empty_java_param_map())
|