[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml

For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema public instead of private to ml.  This would be nice to include in Spark 1.3

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4682 from jkbradley/SPARK-5902 and squashes the following commits:

6f02357 [Joseph K. Bradley] Made transformSchema public
0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well
fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
This commit is contained in:
Joseph K. Bradley 2015-02-19 12:46:27 -08:00 committed by Xiangrui Meng
parent 8ca3418e1b
commit a5fed34355
5 changed files with 20 additions and 12 deletions

View file

@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
* :: DeveloperApi ::
*
* Derives the output schema from the input schema and parameters.
* The schema describes the columns and types of the data.
*
* @param schema Input schema to this stage
* @param paramMap Parameters passed to this stage
* @return Output schema from this stage
*/
private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
@DeveloperApi
def transformSchema(schema: StructType, paramMap: ParamMap): StructType
/**
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
new PipelineModel(this, map, transformers.toArray)
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))

View file

@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
model
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],

View file

@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
}
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
}

View file

@@ -188,7 +188,7 @@ class ALSModel private[ml] (
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
}
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
model
}
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}

View file

@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
cvModel
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transformSchema(schema, paramMap)
}
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
bestModel.transform(dataset, paramMap)
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transformSchema(schema, paramMap)
}
}