[SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression
## What changes were proposed in this pull request? * ```GeneralizedLinearRegression``` API docs enhancement. * The default value of ```GeneralizedLinearRegression``` ```linkPredictionCol``` is not set rather than empty. This will consistent with other similar params such as ```weightCol``` * Make some methods more private. * Fix a minor bug of LinearRegression. * Fix some other issues. ## How was this patch tested? Existing tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #13129 from yanboliang/spark-15339.
This commit is contained in:
parent
5e203505f1
commit
c94b34ebbf
|
@ -89,8 +89,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
|
|||
def getQuantilesCol: String = $(quantilesCol)
|
||||
|
||||
/** Checks whether the input has quantiles column name. */
|
||||
protected[regression] def hasQuantilesCol: Boolean = {
|
||||
isDefined(quantilesCol) && $(quantilesCol) != ""
|
||||
private[regression] def hasQuantilesCol: Boolean = {
|
||||
isDefined(quantilesCol) && $(quantilesCol).nonEmpty
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
|
||||
with HasSolver with Logging {
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
/**
|
||||
* Param for the name of family which is a description of the error distribution
|
||||
* to be used in the model.
|
||||
|
@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
@Since("2.0.0")
|
||||
final val family: Param[String] = new Param(this, "family",
|
||||
"The name of family which is a description of the error distribution to be used in the " +
|
||||
"model. Supported options: gaussian(default), binomial, poisson and gamma.",
|
||||
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
|
||||
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
|
||||
ParamValidators.inArray[String](supportedFamilyNames.toArray))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
|
@ -71,9 +73,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
@Since("2.0.0")
|
||||
final val link: Param[String] = new Param(this, "link", "The name of link function " +
|
||||
"which provides the relationship between the linear predictor and the mean of the " +
|
||||
"distribution function. Supported options: identity, log, inverse, logit, probit, " +
|
||||
"cloglog and sqrt.",
|
||||
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
|
||||
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
|
||||
ParamValidators.inArray[String](supportedLinkNames.toArray))
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
|
@ -81,19 +82,23 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
|
||||
/**
|
||||
* Param for link prediction (linear predictor) column name.
|
||||
* Default is empty, which means we do not output link prediction.
|
||||
* Default is not set, which means we do not output link prediction.
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
|
||||
"link prediction (linear predictor) column name")
|
||||
setDefault(linkPredictionCol, "")
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.0.0")
|
||||
def getLinkPredictionCol: String = $(linkPredictionCol)
|
||||
|
||||
/** Checks whether we should output link prediction. */
|
||||
private[regression] def hasLinkPredictionCol: Boolean = {
|
||||
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
|
||||
}
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
|
||||
@Since("2.0.0")
|
||||
|
@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
s"with ${$(family)} family does not support ${$(link)} link function.")
|
||||
}
|
||||
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
|
||||
if ($(linkPredictionCol).nonEmpty) {
|
||||
if (hasLinkPredictionCol) {
|
||||
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
|
||||
} else {
|
||||
newSchema
|
||||
|
@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
/**
|
||||
* Sets the value of param [[weightCol]].
|
||||
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||
* Default is empty, so all instances have weight one.
|
||||
* Default is not set, so all instances have weight one.
|
||||
*
|
||||
* @group setParam
|
||||
*/
|
||||
|
@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
|
||||
/**
|
||||
* Sets the solver algorithm used for optimization.
|
||||
* Currently only support "irls" which is also the default solver.
|
||||
* Currently only supports "irls" which is also the default solver.
|
||||
*
|
||||
* @group setParam
|
||||
*/
|
||||
|
@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
}
|
||||
val familyAndLink = new FamilyAndLink(familyObj, linkObj)
|
||||
|
||||
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
|
||||
.map { case Row(features: Vector) =>
|
||||
features.size
|
||||
}.first()
|
||||
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
|
||||
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
|
||||
val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
|
||||
s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
|
||||
|
@ -294,7 +296,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def load(path: String): GeneralizedLinearRegression = super.load(path)
|
||||
|
||||
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
|
||||
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
|
||||
private[regression] lazy val supportedFamilyAndLinkPairs = Set(
|
||||
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
|
||||
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
|
||||
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
|
||||
|
@ -302,17 +304,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
)
|
||||
|
||||
/** Set of family names that GeneralizedLinearRegression supports. */
|
||||
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
|
||||
private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
|
||||
|
||||
/** Set of link names that GeneralizedLinearRegression supports. */
|
||||
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
|
||||
private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
|
||||
|
||||
private[ml] val epsilon: Double = 1E-16
|
||||
private[regression] val epsilon: Double = 1E-16
|
||||
|
||||
/**
|
||||
* Wrapper of family and link combination used in the model.
|
||||
*/
|
||||
private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
|
||||
private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
|
||||
|
||||
/** Linear predictor based on given mu. */
|
||||
def predict(mu: Double): Double = link.link(family.project(mu))
|
||||
|
@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
*
|
||||
* @param name the name of the family.
|
||||
*/
|
||||
private[ml] abstract class Family(val name: String) extends Serializable {
|
||||
private[regression] abstract class Family(val name: String) extends Serializable {
|
||||
|
||||
/** The default link instance of this family. */
|
||||
val defaultLink: Link
|
||||
|
@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
def project(mu: Double): Double = mu
|
||||
}
|
||||
|
||||
private[ml] object Family {
|
||||
private[regression] object Family {
|
||||
|
||||
/**
|
||||
* Gets the [[Family]] object from its name.
|
||||
|
@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
* Gaussian exponential family distribution.
|
||||
* The default link for the Gaussian family is the identity link.
|
||||
*/
|
||||
private[ml] object Gaussian extends Family("gaussian") {
|
||||
private[regression] object Gaussian extends Family("gaussian") {
|
||||
|
||||
val defaultLink: Link = Identity
|
||||
|
||||
|
@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
* Binomial exponential family distribution.
|
||||
* The default link for the Binomial family is the logit link.
|
||||
*/
|
||||
private[ml] object Binomial extends Family("binomial") {
|
||||
private[regression] object Binomial extends Family("binomial") {
|
||||
|
||||
val defaultLink: Link = Logit
|
||||
|
||||
|
@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
* Poisson exponential family distribution.
|
||||
* The default link for the Poisson family is the log link.
|
||||
*/
|
||||
private[ml] object Poisson extends Family("poisson") {
|
||||
private[regression] object Poisson extends Family("poisson") {
|
||||
|
||||
val defaultLink: Link = Log
|
||||
|
||||
|
@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
* Gamma exponential family distribution.
|
||||
* The default link for the Gamma family is the inverse link.
|
||||
*/
|
||||
private[ml] object Gamma extends Family("gamma") {
|
||||
private[regression] object Gamma extends Family("gamma") {
|
||||
|
||||
val defaultLink: Link = Inverse
|
||||
|
||||
|
@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
*
|
||||
* @param name the name of link function.
|
||||
*/
|
||||
private[ml] abstract class Link(val name: String) extends Serializable {
|
||||
private[regression] abstract class Link(val name: String) extends Serializable {
|
||||
|
||||
/** The link function. */
|
||||
def link(mu: Double): Double
|
||||
|
@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
def unlink(eta: Double): Double
|
||||
}
|
||||
|
||||
private[ml] object Link {
|
||||
private[regression] object Link {
|
||||
|
||||
/**
|
||||
* Gets the [[Link]] object from its name.
|
||||
|
@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
}
|
||||
}
|
||||
|
||||
private[ml] object Identity extends Link("identity") {
|
||||
private[regression] object Identity extends Link("identity") {
|
||||
|
||||
override def link(mu: Double): Double = mu
|
||||
|
||||
|
@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = eta
|
||||
}
|
||||
|
||||
private[ml] object Logit extends Link("logit") {
|
||||
private[regression] object Logit extends Link("logit") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
|
||||
|
||||
|
@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
|
||||
}
|
||||
|
||||
private[ml] object Log extends Link("log") {
|
||||
private[regression] object Log extends Link("log") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(mu)
|
||||
|
||||
|
@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = math.exp(eta)
|
||||
}
|
||||
|
||||
private[ml] object Inverse extends Link("inverse") {
|
||||
private[regression] object Inverse extends Link("inverse") {
|
||||
|
||||
override def link(mu: Double): Double = 1.0 / mu
|
||||
|
||||
|
@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = 1.0 / eta
|
||||
}
|
||||
|
||||
private[ml] object Probit extends Link("probit") {
|
||||
private[regression] object Probit extends Link("probit") {
|
||||
|
||||
override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
|
||||
|
||||
|
@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
|
||||
}
|
||||
|
||||
private[ml] object CLogLog extends Link("cloglog") {
|
||||
private[regression] object CLogLog extends Link("cloglog") {
|
||||
|
||||
override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
|
||||
|
||||
|
@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
|
||||
}
|
||||
|
||||
private[ml] object Sqrt extends Link("sqrt") {
|
||||
private[regression] object Sqrt extends Link("sqrt") {
|
||||
|
||||
override def link(mu: Double): Double = math.sqrt(mu)
|
||||
|
||||
|
@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] (
|
|||
if ($(predictionCol).nonEmpty) {
|
||||
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
|
||||
}
|
||||
if ($(linkPredictionCol).nonEmpty) {
|
||||
if (hasLinkPredictionCol) {
|
||||
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
|
||||
}
|
||||
output.toDF()
|
||||
|
@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
|
|||
*/
|
||||
@Since("2.0.0")
|
||||
val predictionCol: String = {
|
||||
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
|
||||
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) {
|
||||
origModel.getPredictionCol
|
||||
} else {
|
||||
"prediction_" + java.util.UUID.randomUUID.toString
|
||||
|
|
|
@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
|
|||
setDefault(isotonic -> true, featureIndex -> 0)
|
||||
|
||||
/** Checks whether the input has weight column. */
|
||||
protected[ml] def hasWeightCol: Boolean = {
|
||||
isDefined(weightCol) && $(weightCol) != ""
|
||||
private[regression] def hasWeightCol: Boolean = {
|
||||
isDefined(weightCol) && $(weightCol).nonEmpty
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -161,9 +161,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
|
||||
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
|
||||
// Extract the number of features before deciding optimization solver.
|
||||
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
|
||||
case Row(features: Vector) => features.size
|
||||
}.first()
|
||||
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
|
||||
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
|
||||
|
@ -242,7 +240,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
val coefficients = Vectors.sparse(numFeatures, Seq())
|
||||
val intercept = yMean
|
||||
|
||||
val model = new LinearRegressionModel(uid, coefficients, intercept)
|
||||
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
|
||||
// Handle possible missing or invalid prediction columns
|
||||
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
|
||||
|
||||
|
@ -254,7 +252,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
model,
|
||||
Array(0D),
|
||||
Array(0D))
|
||||
return copyValues(model.setSummary(trainingSummary))
|
||||
return model.setSummary(trainingSummary)
|
||||
} else {
|
||||
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
|
||||
"Model cannot be regularized.")
|
||||
|
|
|
@ -610,20 +610,31 @@ class LinearRegressionSuite
|
|||
val model1 = new LinearRegression()
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setWeightCol("weight")
|
||||
.setPredictionCol("myPrediction")
|
||||
.setSolver(solver)
|
||||
.fit(datasetWithWeightConstantLabel)
|
||||
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
|
||||
model1.coefficients(1))
|
||||
assert(actual1 ~== expected(idx) absTol 1e-4)
|
||||
|
||||
// Schema of summary.predictions should be a superset of the input dataset
|
||||
assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol)
|
||||
.subsetOf(model1.summary.predictions.schema.fieldNames.toSet))
|
||||
|
||||
val model2 = new LinearRegression()
|
||||
.setFitIntercept(fitIntercept)
|
||||
.setWeightCol("weight")
|
||||
.setPredictionCol("myPrediction")
|
||||
.setSolver(solver)
|
||||
.fit(datasetWithWeightZeroLabel)
|
||||
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
|
||||
model2.coefficients(1))
|
||||
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
|
||||
|
||||
// Schema of summary.predictions should be a superset of the input dataset
|
||||
assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol)
|
||||
.subsetOf(model2.summary.predictions.schema.fieldNames.toSet))
|
||||
|
||||
idx += 1
|
||||
}
|
||||
}
|
||||
|
@ -672,7 +683,7 @@ class LinearRegressionSuite
|
|||
|
||||
test("linear regression model training summary") {
|
||||
Seq("auto", "l-bfgs", "normal").foreach { solver =>
|
||||
val trainer = new LinearRegression().setSolver(solver)
|
||||
val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction")
|
||||
val model = trainer.fit(datasetWithDenseFeature)
|
||||
val trainerNoPredictionCol = trainer.setPredictionCol("")
|
||||
val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature)
|
||||
|
@ -682,7 +693,7 @@ class LinearRegressionSuite
|
|||
assert(modelNoPredictionCol.hasSummary)
|
||||
|
||||
// Schema should be a superset of the input dataset
|
||||
assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf(
|
||||
assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf(
|
||||
model.summary.predictions.schema.fieldNames.toSet))
|
||||
// Validate that we re-insert a prediction column for evaluation
|
||||
val modelNoPredictionColFieldNames
|
||||
|
|
Loading…
Reference in a new issue