spark-instrumented-optimizer/python/pyspark/ml/wrapper.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABCMeta

from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer
from pyspark.mllib.common import inherit_doc


def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.
    """
    jvm = SparkContext._jvm
    if jvm:
        return jvm
    else:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
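
# Hypothetical usage sketch (illustration only, not part of the module):
# once a SparkContext is active, _jvm() exposes JVM packages via attribute
# access, so a Java class such as org.apache.spark.ml.param.ParamMap can be
# instantiated directly through the gateway:
#
#     sc = SparkContext(appName="example")   # assumed to exist already
#     jpm = _jvm().org.apache.spark.ml.param.ParamMap()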


@inherit_doc
class JavaWrapper(Params):
    """
    Utility class to help create wrapper classes from Java/Scala
    implementations of pipeline components.
    """

    __metaclass__ = ABCMeta

    #: Fully-qualified class name of the wrapped Java component.
    _java_class = None

    def _java_obj(self):
        """
        Returns or creates a Java object.
        """
        java_obj = _jvm()
        for name in self._java_class.split("."):
            java_obj = getattr(java_obj, name)
        return java_obj()

    def _transfer_params_to_java(self, params, java_obj):
        """
        Transfers the embedded params and additional params to the
        input Java object.

        :param params: additional params (overwriting embedded values)
        :param java_obj: Java object to receive the params
        """
        paramMap = self._merge_params(params)
        for param in self.params:
            if param in paramMap:
                java_obj.set(param.name, paramMap[param])

    def _empty_java_param_map(self):
        """
        Returns an empty Java ParamMap reference.
        """
        return _jvm().org.apache.spark.ml.param.ParamMap()

    def _create_java_param_map(self, params, java_obj):
        paramMap = self._empty_java_param_map()
        for param, value in params.items():
            if param.parent is self:
                paramMap.put(java_obj.getParam(param.name), value)
        return paramMap
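
# Hypothetical sketch (illustration only): a concrete wrapper points
# _java_class at the fully-qualified name of its Scala counterpart, and
# JavaWrapper._java_obj() walks that dotted path through the JVM gateway
# to construct the object:
#
#     @inherit_doc
#     class MyTokenizer(JavaTransformer):    # made-up name; JavaTransformer
#         _java_class = "org.apache.spark.ml.feature.Tokenizer"  # is defined below
#
# MyTokenizer()._java_obj() would then return a new JVM Tokenizer instance.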


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
    """
    Base class for :py:class:`Estimator`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def _create_model(self, java_model):
        """
        Creates a model from the input Java model reference.
        """
        return JavaModel(java_model)
    def _fit_java(self, dataset, params={}):
        """
        Fits a Java model to the input dataset.

        :param dataset: input dataset, which is an instance of
                        :py:class:`pyspark.sql.DataFrame`
        :param params: additional params (overwriting embedded values)
        :return: fitted Java model
        """
        java_obj = self._java_obj()
        self._transfer_params_to_java(params, java_obj)
        return java_obj.fit(dataset._jdf, self._empty_java_param_map())

    def fit(self, dataset, params={}):
        java_model = self._fit_java(dataset, params)
        return self._create_model(java_model)
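
# Hypothetical usage sketch (illustration only): fitting a wrapped estimator
# pushes the Python-side params onto the JVM object, calls its fit() on the
# underlying Java DataFrame (dataset._jdf), and wraps the fitted JVM model
# in a JavaModel:
#
#     est = SomeJavaEstimatorSubclass()              # made-up name
#     model = est.fit(training_df, {est.maxIter: 10})  # returns a JavaModel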


@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
    """
    Base class for :py:class:`Transformer`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def transform(self, dataset, params={}):
        java_obj = self._java_obj()
        self._transfer_params_to_java({}, java_obj)
        java_param_map = self._create_java_param_map(params, java_obj)
        return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
                         dataset.sql_ctx)
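
# Hypothetical usage sketch (illustration only): transform() forwards the
# call to the JVM object and rewraps the returned Java DataFrame in a
# Python DataFrame bound to the caller's SQLContext:
#
#     tok = SomeJavaTransformerSubclass()            # made-up name
#     out_df = tok.transform(in_df)                  # a pyspark DataFrame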


@inherit_doc
class JavaModel(JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations.
    """

    __metaclass__ = ABCMeta

    def __init__(self, java_model):
        super(JavaModel, self).__init__()
        self._java_model = java_model

    def _java_obj(self):
        return self._java_model
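
# Hypothetical end-to-end sketch (illustration only): JavaEstimator.fit()
# produces a JavaModel holding the already-fitted JVM object, so the model's
# _java_obj() override returns that existing reference instead of constructing
# a new one, and transform() reuses JavaTransformer's implementation unchanged:
#
#     model = estimator.fit(train_df)        # JavaModel wrapping the JVM model
#     predictions = model.transform(test_df)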