[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml
For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema be public instead of private to ml. This would be nice to include in Spark 1.3 CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #4682 from jkbradley/SPARK-5902 and squashes the following commits: 6f02357 [Joseph K. Bradley] Made transformSchema public 0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
This commit is contained in:
parent
8ca3418e1b
commit
a5fed34355
|
@ -20,7 +20,7 @@ package org.apache.spark.ml
|
|||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.annotation.AlphaComponent
|
||||
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
|
||||
import org.apache.spark.ml.param.{Param, ParamMap}
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
|
|||
abstract class PipelineStage extends Serializable with Logging {
|
||||
|
||||
/**
|
||||
* :: DeveloperAPI ::
|
||||
*
|
||||
* Derives the output schema from the input schema and parameters.
|
||||
* The schema describes the columns and types of the data.
|
||||
*
|
||||
* @param schema Input schema to this stage
|
||||
* @param paramMap Parameters passed to this stage
|
||||
* @return Output schema from this stage
|
||||
*/
|
||||
private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
|
||||
@DeveloperApi
|
||||
def transformSchema(schema: StructType, paramMap: ParamMap): StructType
|
||||
|
||||
/**
|
||||
* Derives the output schema from the input schema and parameters, optionally with logging.
|
||||
|
@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
|
|||
new PipelineModel(this, map, transformers.toArray)
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
val map = this.paramMap ++ paramMap
|
||||
val theStages = map(stages)
|
||||
require(theStages.toSet.size == theStages.size,
|
||||
|
@ -171,7 +179,7 @@ class PipelineModel private[ml] (
|
|||
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
|
||||
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
|
||||
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
|
||||
|
|
|
@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
|
|||
model
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
val map = this.paramMap ++ paramMap
|
||||
val inputType = schema(map(inputCol)).dataType
|
||||
require(inputType.isInstanceOf[VectorUDT],
|
||||
|
@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
|
|||
dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
val map = this.paramMap ++ paramMap
|
||||
val inputType = schema(map(inputCol)).dataType
|
||||
require(inputType.isInstanceOf[VectorUDT],
|
||||
|
|
|
@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
|
|||
@DeveloperApi
|
||||
protected def featuresDataType: DataType = new VectorUDT
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
|
||||
}
|
||||
|
||||
|
@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
|
|||
@DeveloperApi
|
||||
protected def featuresDataType: DataType = new VectorUDT
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
|
||||
}
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ class ALSModel private[ml] (
|
|||
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
|
||||
}
|
||||
|
||||
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
validateAndTransformSchema(schema, paramMap)
|
||||
}
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
|
|||
model
|
||||
}
|
||||
|
||||
override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
validateAndTransformSchema(schema, paramMap)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
|
|||
cvModel
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
val map = this.paramMap ++ paramMap
|
||||
map(estimator).transformSchema(schema, paramMap)
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
|
|||
bestModel.transform(dataset, paramMap)
|
||||
}
|
||||
|
||||
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
|
||||
bestModel.transformSchema(schema, paramMap)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue