[SPARK-30660][ML][PYSPARK] LinearRegression blockify input vectors

### What changes were proposed in this pull request?
1. Add a new param `blockSize` (a minimal usage sketch follows this list);
2. Add a new class `InstanceBlock`;
3. If `blockSize == 1`, keep the original behavior; if `blockSize > 1`, stack the input vectors into blocks (as ALS/MLP already do);
4. If `blockSize > 1`, standardize the input outside of the optimization procedure.
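
A minimal usage sketch of the new param (assuming only the public API touched by this PR; the toy dataset below is made up for illustration):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("blockify-sketch").getOrCreate()
import spark.implicits._

// Toy data; any DataFrame with "label" and "features" columns works.
val training = Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (1.0, Vectors.dense(2.0, 1.3, 1.0))
).toDF("label", "features")

val lir = new LinearRegression()
  .setLoss("huber")   // huber (or high-dimensional squaredError) goes through the L-BFGS path where blockSize applies
  .setBlockSize(64)   // blockSize > 1: stack rows into InstanceBlocks; 1 (the default) keeps the old row-wise path
  .setMaxIter(10)

val model = lir.fit(training)
println(s"coefficients=${model.coefficients}, intercept=${model.intercept}")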

### Why are the changes needed?
It yields a performance gain on dense datasets such as `epsilon`:
1. Less RAM is needed to persist the training dataset (about 40% savings);
2. Level-2 BLAS routines can be used, giving up to a ~6x (`squaredError`) to ~12x (`huber`) speedup; see the sketch after this list.
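
To illustrate the Level-2 BLAS point, here is a standalone sketch (not code from this patch): stacking `blockSize` rows into one matrix lets a single matrix-vector multiply (GEMV) produce all margins at once, instead of one dot product per row; `Matrix.multiply` below dispatches to the same `BLAS.gemv` the new aggregators use.

import org.apache.spark.ml.linalg.{DenseVector, Matrices}

val coef = new DenseVector(Array(0.5, -1.0, 2.0))

// Four feature rows stacked into a 4 x 3 block (values are column-major).
val block = Matrices.dense(4, 3, Array(
  1.0, 2.0, 3.0, 4.0,   // feature 0 of rows 0..3
  0.0, 1.0, 0.0, 1.0,   // feature 1
  2.0, 2.0, 2.0, 2.0))  // feature 2

// Blocked path: one Level-2 BLAS call computes all four margins.
val margins = block.multiply(coef)

// Row-at-a-time path: one dot product per row.
val rowMargins = (0 until 4).map { i => (0 until 3).map(j => block(i, j) * coef(j)).sum }

println(margins)                   // [4.5,4.0,5.5,5.0]
println(rowMargins.mkString(","))  // same values, computed row by row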

### Does this PR introduce _any_ user-facing change?
Yes, a new param (`blockSize`) is added.

### How was this patch tested?
Existing and newly added test suites.

Closes #28471 from zhengruifeng/blockify_lir_II.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
zhengruifeng 2020-05-08 10:52:01 +08:00
parent 24fac1e0c7
commit 97332f26bf
12 changed files with 605 additions and 218 deletions

View file

@@ -281,7 +281,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
values
}
private[ml] def getTransformFunc(
private[spark] def getTransformFunc(
shift: Array[Double],
scale: Array[Double],
withShift: Boolean,

View file

@@ -17,8 +17,8 @@
package org.apache.spark.ml.optim.aggregator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg._
/**
* HuberAggregator computes the gradient and loss for a huber loss function,
@@ -74,15 +74,12 @@ private[ml] class HuberAggregator(
extends DifferentiableLossAggregator[Instance, HuberAggregator] {
protected override val dim: Int = bcParameters.value.size
private val numFeatures: Int = if (fitIntercept) dim - 2 else dim - 1
private val sigma: Double = bcParameters.value(dim - 1)
private val intercept: Double = if (fitIntercept) {
bcParameters.value(dim - 2)
} else {
0.0
}
private val numFeatures = if (fitIntercept) dim - 2 else dim - 1
private val sigma = bcParameters.value(dim - 1)
private val intercept = if (fitIntercept) bcParameters.value(dim - 2) else 0.0
// make transient so we do not serialize between aggregation stages
@transient private lazy val coefficients = bcParameters.value.toArray.slice(0, numFeatures)
@transient private lazy val coefficients = bcParameters.value.toArray.take(numFeatures)
/**
* Add a new training instance to this HuberAggregator, and update the loss and gradient
@@ -150,3 +147,101 @@ private[ml] class HuberAggregator(
}
}
}
/**
* BlockHuberAggregator computes the gradient and loss for Huber loss function
* as used in linear regression for blocks in sparse or dense matrix in an online fashion.
*
* Two BlockHuberAggregators can be merged together to have a summary of loss and gradient
* of the corresponding joint dataset.
*
* NOTE: The feature values are expected to be standardized before computation.
*
* @param fitIntercept Whether to fit an intercept term.
*/
private[ml] class BlockHuberAggregator(
fitIntercept: Boolean,
epsilon: Double)(bcParameters: Broadcast[Vector])
extends DifferentiableLossAggregator[InstanceBlock, BlockHuberAggregator] {
protected override val dim: Int = bcParameters.value.size
private val numFeatures = if (fitIntercept) dim - 2 else dim - 1
private val sigma = bcParameters.value(dim - 1)
private val intercept = if (fitIntercept) bcParameters.value(dim - 2) else 0.0
// make transient so we do not serialize between aggregation stages
@transient private lazy val linear = Vectors.dense(bcParameters.value.toArray.take(numFeatures))
/**
* Add a new training instance block to this BlockHuberAggregator, and update the loss and
* gradient of the objective function.
*
* @param block The instance block of data point to be added.
* @return This BlockHuberAggregator object.
*/
def add(block: InstanceBlock): BlockHuberAggregator = {
require(block.matrix.isTransposed)
require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
require(block.weightIter.forall(_ >= 0),
s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")
if (block.weightIter.forall(_ == 0)) return this
val size = block.size
// vec here represents margins or dotProducts
val vec = if (fitIntercept) {
Vectors.dense(Array.fill(size)(intercept)).toDense
} else {
Vectors.zeros(size).toDense
}
BLAS.gemv(1.0, block.matrix, linear, 1.0, vec)
// in-place convert margins to multipliers
// then, vec represents multipliers
var sigmaGradSum = 0.0
var i = 0
while (i < size) {
val weight = block.getWeight(i)
if (weight > 0) {
weightSum += weight
val label = block.getLabel(i)
val margin = vec(i)
val linearLoss = label - margin
if (math.abs(linearLoss) <= sigma * epsilon) {
lossSum += 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma)
val linearLossDivSigma = linearLoss / sigma
val multiplier = -1.0 * weight * linearLossDivSigma
vec.values(i) = multiplier
sigmaGradSum += 0.5 * weight * (1.0 - math.pow(linearLossDivSigma, 2.0))
} else {
lossSum += 0.5 * weight *
(sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon)
val sign = if (linearLoss >= 0) -1.0 else 1.0
val multiplier = weight * sign * epsilon
vec.values(i) = multiplier
sigmaGradSum += 0.5 * weight * (1.0 - epsilon * epsilon)
}
} else { vec.values(i) = 0.0 }
i += 1
}
block.matrix match {
case dm: DenseMatrix =>
BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
vec.values, 1, 1.0, gradientSumArray, 1)
case sm: SparseMatrix =>
val linearGradSumVec = Vectors.zeros(numFeatures).toDense
BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
gradientSumArray, 1)
}
gradientSumArray(dim - 1) += sigmaGradSum
if (fitIntercept) gradientSumArray(dim - 2) += vec.values.sum
this
}
}
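
To map the two branches of `add` above back to the Huber objective, here is a sketch of the per-instance contribution derived directly from the code (writing $r_i = y_i - m_i$ for the linear loss at margin $m_i$, $w_i$ for the instance weight, $\sigma$ for the scale and $\epsilon$ for the threshold):
$$
\ell_i =
\begin{cases}
\tfrac{w_i}{2}\left(\sigma + r_i^2/\sigma\right) & \text{if } |r_i| \le \sigma\epsilon \\
\tfrac{w_i}{2}\left(\sigma + 2\epsilon|r_i| - \sigma\epsilon^2\right) & \text{otherwise}
\end{cases}
\qquad
\frac{\partial \ell_i}{\partial m_i} =
\begin{cases}
-w_i\, r_i/\sigma & \text{if } |r_i| \le \sigma\epsilon \\
-\operatorname{sign}(r_i)\, w_i\, \epsilon & \text{otherwise}
\end{cases}
$$
The `vec` array holds the $\partial \ell_i/\partial m_i$ multipliers, so the closing `dgemv`/`gemv` accumulates the coefficient gradient for the whole block in one Level-2 BLAS call, while `sigmaGradSum` accumulates $\partial \ell_i/\partial \sigma$, which is $\tfrac{w_i}{2}(1 - r_i^2/\sigma^2)$ or $\tfrac{w_i}{2}(1 - \epsilon^2)$ in the two branches.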

View file

@@ -17,8 +17,8 @@
package org.apache.spark.ml.optim.aggregator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg._
/**
* LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
@@ -222,3 +222,92 @@ private[ml] class LeastSquaresAggregator(
}
}
}
/**
* BlockLeastSquaresAggregator computes the gradient and loss for LeastSquares loss function
* as used in linear regression for blocks in sparse or dense matrix in an online fashion.
*
* Two BlockLeastSquaresAggregators can be merged together to have a summary of loss and gradient
* of the corresponding joint dataset.
*
* NOTE: The feature values are expected to be standardized before computation.
*
* @param bcCoefficients The coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
*/
private[ml] class BlockLeastSquaresAggregator(
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector])
extends DifferentiableLossAggregator[InstanceBlock, BlockLeastSquaresAggregator] {
require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " +
s"deviation to be positive.")
private val numFeatures = bcFeaturesStd.value.length
protected override val dim: Int = numFeatures
// make transient so we do not serialize between aggregation stages
@transient private lazy val effectiveCoefAndOffset = {
val coefficientsArray = bcCoefficients.value.toArray.clone()
val featuresMean = bcFeaturesMean.value
val featuresStd = bcFeaturesStd.value
var sum = 0.0
var i = 0
val len = coefficientsArray.length
while (i < len) {
if (featuresStd(i) != 0.0) {
sum += coefficientsArray(i) / featuresStd(i) * featuresMean(i)
} else {
coefficientsArray(i) = 0.0
}
i += 1
}
val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
(Vectors.dense(coefficientsArray), offset)
}
// do not use tuple assignment above because it will circumvent the @transient tag
@transient private lazy val effectiveCoefficientsVec = effectiveCoefAndOffset._1
@transient private lazy val offset = effectiveCoefAndOffset._2
/**
* Add a new training instance block to this BlockLeastSquaresAggregator, and update the loss
* and gradient of the objective function.
*
* @param block The instance block of data point to be added.
* @return This BlockLeastSquaresAggregator object.
*/
def add(block: InstanceBlock): BlockLeastSquaresAggregator = {
require(block.matrix.isTransposed)
require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
require(block.weightIter.forall(_ >= 0),
s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")
if (block.weightIter.forall(_ == 0)) return this
val size = block.size
// vec here represents diffs
val vec = new DenseVector(Array.tabulate(size)(i => offset - block.getLabel(i) / labelStd))
BLAS.gemv(1.0, block.matrix, effectiveCoefficientsVec, 1.0, vec)
// in-place convert diffs to multipliers
// then, vec represents multipliers
var i = 0
while (i < size) {
val weight = block.getWeight(i)
val diff = vec(i)
lossSum += weight * diff * diff / 2
weightSum += weight
val multiplier = weight * diff
vec.values(i) = multiplier
i += 1
}
val gradSumVec = new DenseVector(gradientSumArray)
BLAS.gemv(1.0, block.matrix.transpose, vec, 1.0, gradSumVec)
this
}
}
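
A sketch of what `effectiveCoefAndOffset` and the `diff` values above amount to, derived from the code (with $\beta_j$ the coefficients, $\bar{x}_j$/$\hat{x}_j$ the feature means/stds, $\bar{y}$/$\hat{y}$ the label mean/std; block features arrive already divided by $\hat{x}_j$, and coefficients of constant features are zeroed out):
$$
\text{offset} = \frac{\bar{y}}{\hat{y}} - \sum_j \frac{\beta_j \bar{x}_j}{\hat{x}_j},
\qquad
\text{diff}_i = \sum_j \beta_j \frac{x_{ij} - \bar{x}_j}{\hat{x}_j} - \frac{y_i - \bar{y}}{\hat{y}}
\quad (\text{offset} = 0 \text{ and means are not subtracted when fitIntercept is false})
$$
Each row then contributes $\tfrac{1}{2} w_i\,\text{diff}_i^2$ to the loss and $w_i\,\text{diff}_i$ as its multiplier, so the single `gemv` against the transposed block matrix accumulates the whole gradient at once; this matches the scaled-space objective quoted in the training comments of LinearRegression below.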

View file

@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import breeze.optimize.{CachedDiffFunction, DiffFunction, FirstOrderMinimizer, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import breeze.stats.distributions.StudentsT
import org.apache.hadoop.fs.Path
@@ -28,10 +28,11 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PipelineStage, PredictorParams}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.BLAS._
import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggregator}
import org.apache.spark.ml.optim.aggregator._
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
@@ -42,6 +43,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -54,7 +56,7 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
with HasAggregationDepth with HasLoss {
with HasAggregationDepth with HasLoss with HasBlockSize {
import LinearRegression._
@@ -315,49 +317,54 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setEpsilon(value: Double): this.type = set(epsilon, value)
setDefault(epsilon -> 1.35)
/**
* Set block size for stacking input data in matrices.
* If blockSize == 1, then stacking will be skipped, and each vector is treated individually;
* If blockSize &gt; 1, then vectors will be stacked to blocks, and high-level BLAS routines
* will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
* Recommended size is between 10 and 1000. An appropriate choice of the block size depends
* on the sparsity and dim of input datasets, the underlying BLAS implementation (for example,
* f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads).
* Note that existing BLAS implementations are mainly optimized for dense matrices, if the
* input dataset is sparse, stacking may bring no performance gain, the worse is possible
* performance regression.
* Default is 1.
*
* @group expertSetParam
*/
@Since("3.1.0")
def setBlockSize(value: Int): this.type = set(blockSize, value)
setDefault(blockSize -> 1)
override protected def train(dataset: Dataset[_]): LinearRegressionModel = instrumented { instr =>
// Extract the number of features before deciding optimization solver.
val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
val instances = extractInstances(dataset)
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, solver, tol,
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss,
epsilon)
epsilon, blockSize)
// Extract the number of features before deciding optimization solver.
val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)
if ($(loss) == SquaredError && (($(solver) == Auto &&
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal)) {
// For low dimensional data, WeightedLeastSquares is more efficient since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
elasticNetParam = $(elasticNetParam), $(standardization), true,
solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol))
val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
// When it is trained by WeightedLeastSquares, training summary does not
// attach returned model.
val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
model.objectiveHistory)
return lrModel.setSummary(Some(trainingSummary))
return trainWithNormal(dataset, instr)
}
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val instances = extractInstances(dataset)
.setName("training instances")
val (featuresSummarizer, ySummarizer) =
val (featuresSummarizer, ySummarizer) = if ($(blockSize) == 1) {
if (dataset.storageLevel == StorageLevel.NONE) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
}
Summarizer.getRegressionSummarizers(instances, $(aggregationDepth))
} else {
// instances will be standardized and converted to blocks, so no need to cache instances.
Summarizer.getRegressionSummarizers(instances, $(aggregationDepth),
Seq("mean", "std", "count", "numNonZeros"))
}
val yMean = ySummarizer.mean(0)
val rawYStd = ySummarizer.std(0)
@@ -366,40 +373,20 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean)
instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd)
instr.logSumOfWeights(featuresSummarizer.weightSum)
if ($(blockSize) > 1) {
val scale = 1.0 / featuresSummarizer.count / numFeatures
val sparsity = 1 - featuresSummarizer.numNonzeros.toArray.map(_ * scale).sum
instr.logNamedValue("sparsity", sparsity.toString)
if (sparsity > 0.5) {
instr.logWarning(s"sparsity of input dataset is $sparsity, " +
s"which may hurt performance in high-level BLAS.")
}
}
if (rawYStd == 0.0) {
if ($(fitIntercept) || yMean == 0.0) {
// If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with
// zero coefficient; as a result, training is not needed.
// Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of
// the fitIntercept.
if (yMean == 0.0) {
instr.logWarning(s"Mean and standard deviation of the label are zero, so the " +
s"coefficients and the intercept will all be zero; as a result, training is not " +
s"needed.")
} else {
instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " +
s"will be zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
if (handlePersistence) instances.unpersist()
val coefficients = Vectors.sparse(numFeatures, Seq.empty)
val intercept = yMean
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
model,
Array(0D),
Array(0D))
return model.setSummary(Some(trainingSummary))
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
return trainWithConstantLabel(dataset, instr, numFeatures, yMean)
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
@@ -413,8 +400,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
val featuresMean = featuresSummarizer.mean.toArray
val featuresStd = featuresSummarizer.std.toArray
val bcFeaturesMean = instances.context.broadcast(featuresMean)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
@@ -437,21 +422,105 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
Some(new L2Regularization(effectiveL2RegParam, shouldApply,
if ($(standardization)) None else Some(getFeaturesStd)))
} else {
None
}
} else None
val costFun = $(loss) match {
val optimizer = createOptimizer(effectiveRegParam, effectiveL1RegParam,
numFeatures, featuresStd)
val initialValues = $(loss) match {
case SquaredError =>
val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
bcFeaturesStd, bcFeaturesMean)(_)
new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
Vectors.zeros(numFeatures)
case Huber =>
val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_)
new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
Vectors.dense(Array.fill(dim)(1.0))
}
val optimizer = $(loss) match {
val (parameters, objectiveHistory) = if ($(blockSize) == 1) {
trainOnRows(instances, yMean, yStd, featuresMean, featuresStd,
initialValues, regularization, optimizer)
} else {
trainOnBlocks(instances, yMean, yStd, featuresMean, featuresStd,
initialValues, regularization, optimizer)
}
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
if (parameters == null) {
val msg = s"${optimizer.getClass.getName} failed."
instr.logError(msg)
throw new SparkException(msg)
}
val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd)
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol),
model, Array(0.0), objectiveHistory)
model.setSummary(Some(trainingSummary))
}
private def trainWithNormal(
dataset: Dataset[_],
instr: Instrumentation): LinearRegressionModel = {
// For low dimensional data, WeightedLeastSquares is more efficient since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
elasticNetParam = $(elasticNetParam), $(standardization), true,
solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol))
val instances = extractInstances(dataset)
.setName("training instances")
val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
// When it is trained by WeightedLeastSquares, training summary does not
// attach returned model.
val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol),
summaryModel, model.diagInvAtWA.toArray, model.objectiveHistory)
lrModel.setSummary(Some(trainingSummary))
}
private def trainWithConstantLabel(
dataset: Dataset[_],
instr: Instrumentation,
numFeatures: Int,
yMean: Double): LinearRegressionModel = {
// If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with
// zero coefficient; as a result, training is not needed.
// Also, if rawYStd==0 and yMean==0, all the coefficients are zero regardless of
// the fitIntercept.
if (yMean == 0.0) {
instr.logWarning(s"Mean and standard deviation of the label are zero, so the " +
s"coefficients and the intercept will all be zero; as a result, training is not " +
s"needed.")
} else {
instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " +
s"will be zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
val coefficients = Vectors.sparse(numFeatures, Seq.empty)
val intercept = yMean
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol),
model, Array(0.0), Array(0.0))
model.setSummary(Some(trainingSummary))
}
private def createOptimizer(
effectiveRegParam: Double,
effectiveL1RegParam: Double,
numFeatures: Int,
featuresStd: Array[Double]) = {
$(loss) match {
case SquaredError =>
if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -479,105 +548,162 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val upperBounds = BDV[Double](Array.fill(dim)(Double.MaxValue))
new BreezeLBFGSB(lowerBounds, upperBounds, $(maxIter), 10, $(tol))
}
}
val initialValues = $(loss) match {
private def trainOnRows(
instances: RDD[Instance],
yMean: Double,
yStd: Double,
featuresMean: Array[Double],
featuresStd: Array[Double],
initialValues: Vector,
regularization: Option[L2Regularization],
optimizer: FirstOrderMinimizer[BDV[Double], DiffFunction[BDV[Double]]]) = {
val bcFeaturesMean = instances.context.broadcast(featuresMean)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = $(loss) match {
case SquaredError =>
Vectors.zeros(numFeatures)
val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
bcFeaturesStd, bcFeaturesMean)(_)
new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
case Huber =>
val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
Vectors.dense(Array.fill(dim)(1.0))
val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_)
new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
}
val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialValues.asBreeze.toDenseVector)
val (coefficients, intercept, scale, objectiveHistory) = {
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
<blockquote>
$$
L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2
+ regTerms \\
$$
</blockquote>
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
while (states.hasNext) {
state = states.next()
arrayBuilder += state.adjustedValue
}
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
instr.logError(msg)
throw new SparkException(msg)
}
bcFeaturesMean.destroy()
bcFeaturesStd.destroy()
val parameters = state.x.toArray.clone()
/*
The coefficients are trained in the scaled space; we're converting them back to
the original space.
*/
val rawCoefficients: Array[Double] = $(loss) match {
case SquaredError => parameters
case Huber => parameters.slice(0, numFeatures)
}
var i = 0
val len = rawCoefficients.length
val multiplier = $(loss) match {
case SquaredError => yStd
case Huber => 1.0
}
while (i < len) {
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 }
i += 1
}
val interceptValue: Double = if ($(fitIntercept)) {
$(loss) match {
case SquaredError =>
/*
The intercept of squared error in R's GLMNET is computed using closed form
after the coefficients are converged. See the following discussion for detail.
http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*/
yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean))
case Huber => parameters(numFeatures)
}
} else {
0.0
}
val scaleValue: Double = $(loss) match {
case SquaredError => 1.0
case Huber => parameters.last
}
(Vectors.dense(rawCoefficients).compressed, interceptValue, scaleValue, arrayBuilder.result())
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
<blockquote>
$$
L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2
+ regTerms \\
$$
</blockquote>
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
while (states.hasNext) {
state = states.next()
arrayBuilder += state.adjustedValue
}
if (handlePersistence) instances.unpersist()
bcFeaturesMean.destroy()
bcFeaturesStd.destroy()
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
(if (state == null) null else state.x.toArray, arrayBuilder.result)
}
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
model,
Array(0D),
objectiveHistory)
private def trainOnBlocks(
instances: RDD[Instance],
yMean: Double,
yStd: Double,
featuresMean: Array[Double],
featuresStd: Array[Double],
initialValues: Vector,
regularization: Option[L2Regularization],
optimizer: FirstOrderMinimizer[BDV[Double], DiffFunction[BDV[Double]]]) = {
val bcFeaturesMean = instances.context.broadcast(featuresMean)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
model.setSummary(Some(trainingSummary))
val standardized = instances.mapPartitions { iter =>
val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 }
val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true)
iter.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) }
}
val blocks = InstanceBlock.blokify(standardized, $(blockSize))
.persist(StorageLevel.MEMORY_AND_DISK)
.setName(s"training dataset (blockSize=${$(blockSize)})")
val costFun = $(loss) match {
case SquaredError =>
val getAggregatorFunc = new BlockLeastSquaresAggregator(yStd, yMean, $(fitIntercept),
bcFeaturesStd, bcFeaturesMean)(_)
new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth))
case Huber =>
val getAggregatorFunc = new BlockHuberAggregator($(fitIntercept), $(epsilon))(_)
new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth))
}
val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialValues.asBreeze.toDenseVector)
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
<blockquote>
$$
L &= 1/2n||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2
+ regTerms \\
$$
</blockquote>
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
while (states.hasNext) {
state = states.next()
arrayBuilder += state.adjustedValue
}
blocks.unpersist()
bcFeaturesMean.destroy()
bcFeaturesStd.destroy()
(if (state == null) null else state.x.toArray, arrayBuilder.result)
}
private def createModel(
parameters: Array[Double],
yMean: Double,
yStd: Double,
featuresMean: Array[Double],
featuresStd: Array[Double]): LinearRegressionModel = {
val numFeatures = featuresStd.length
/*
The coefficients are trained in the scaled space; we're converting them back to
the original space.
*/
val rawCoefficients = $(loss) match {
case SquaredError => parameters.clone()
case Huber => parameters.take(numFeatures)
}
var i = 0
val len = rawCoefficients.length
val multiplier = $(loss) match {
case SquaredError => yStd
case Huber => 1.0
}
while (i < len) {
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 }
i += 1
}
val intercept = if ($(fitIntercept)) {
$(loss) match {
case SquaredError =>
/*
The intercept of squared error in R's GLMNET is computed using closed form
after the coefficients are converged. See the following discussion for detail.
http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*/
yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean))
case Huber => parameters(numFeatures)
}
} else 0.0
val scale = $(loss) match {
case SquaredError => 1.0
case Huber => parameters.last
}
val coefficients = Vectors.dense(rawCoefficients).compressed
copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale))
}
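
For the squared-error branch, the GLMNET-style closed form referenced in the comment above boils down to (with $\hat{\beta}_j$ the de-scaled `rawCoefficients` and $\bar{x}_j$, $\bar{y}$ the feature/label means):
$$
b = \bar{y} - \sum_j \hat{\beta}_j\, \bar{x}_j
$$
which is exactly `yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean))`; for the huber loss the intercept is not derived this way but read directly from the optimized parameter vector (`parameters(numFeatures)`).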
@Since("1.4.0")
@@ -655,7 +781,7 @@ class LinearRegressionModel private[ml] (
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
$(labelCol), $(featuresCol), summaryModel, Array(0D))
$(labelCol), $(featuresCol), summaryModel, Array(0.0))
}
/**

View file

@@ -208,9 +208,10 @@ object Summarizer extends Logging {
/** Get regression feature and label summarizers for provided data. */
private[ml] def getRegressionSummarizers(
instances: RDD[Instance],
aggregationDepth: Int = 2): (SummarizerBuffer, SummarizerBuffer) = {
aggregationDepth: Int = 2,
requested: Seq[String] = Seq("mean", "std", "count")) = {
instances.treeAggregate(
(Summarizer.createSummarizerBuffer("mean", "std"),
(Summarizer.createSummarizerBuffer(requested: _*),
Summarizer.createSummarizerBuffer("mean", "std", "count")))(
seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
(c._1.add(instance.features, instance.weight),
@@ -223,7 +224,7 @@
}
/** Get classification feature and label summarizers for provided data. */
private[ml] def getClassificationSummarizers(
private[spark] def getClassificationSummarizers(
instances: RDD[Instance],
aggregationDepth: Int = 2,
requested: Seq[String] = Seq("mean", "std", "count")) = {
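
As a hedged sketch of the new `requested` parameter (mirroring the call added in `LinearRegression.train` above; these summarizer helpers are package-private, so this only compiles inside Spark's `org.apache.spark.ml` packages, and `instances` is a hypothetical RDD[Instance]):

import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.rdd.RDD

def summarizeForBlocks(instances: RDD[Instance], aggregationDepth: Int) = {
  // The block path additionally requests "numNonZeros" so that train() can log the
  // dataset sparsity and warn when stacking is unlikely to pay off in high-level BLAS.
  Summarizer.getRegressionSummarizers(
    instances, aggregationDepth, Seq("mean", "std", "count", "numNonZeros"))
}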

View file

@@ -63,22 +63,7 @@ class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
new HingeAggregator(bcFeaturesStd, fitIntercept)(bcCoefficients)
}
private def standardize(instances: Array[Instance]): Array[Instance] = {
val (featuresSummarizer, _) =
Summarizer.getClassificationSummarizers(sc.parallelize(instances))
val stdArray = featuresSummarizer.std.toArray
val numFeatures = stdArray.length
instances.map { case Instance(label, weight, features) =>
val standardized = Array.ofDim[Double](numFeatures)
features.foreachNonZero { (i, v) =>
val std = stdArray(i)
if (std != 0) standardized(i) = v / std
}
Instance(label, weight, Vectors.dense(standardized).compressed)
}
}
/** Get summary statistics for some data and create a new BlockHingeAggregator. */
private def getNewBlockAggregator(
coefficients: Vector,
fitIntercept: Boolean): BlockHingeAggregator = {

View file

@@ -17,7 +17,7 @@
package org.apache.spark.ml.optim.aggregator
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util.TestingUtils._
@@ -28,6 +28,7 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var instances: Array[Instance] = _
@transient var instancesConstantFeature: Array[Instance] = _
@transient var instancesConstantFeatureFiltered: Array[Instance] = _
@transient var standardizedInstances: Array[Instance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -46,6 +47,7 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
Instance(1.0, 0.5, Vectors.dense(1.0)),
Instance(2.0, 0.3, Vectors.dense(0.5))
)
standardizedInstances = standardize(instances)
}
/** Get summary statistics for some data and create a new HuberAggregator. */
@@ -61,6 +63,15 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
new HuberAggregator(fitIntercept, epsilon, bcFeaturesStd)(bcParameters)
}
/** Get summary statistics for some data and create a new BlockHuberAggregator. */
private def getNewBlockAggregator(
parameters: Vector,
fitIntercept: Boolean,
epsilon: Double): BlockHuberAggregator = {
val bcParameters = spark.sparkContext.broadcast(parameters)
new BlockHuberAggregator(fitIntercept, epsilon)(bcParameters)
}
test("aggregator add method should check input size") {
val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0)
val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35)
@@ -147,6 +158,23 @@ class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(loss ~== agg.loss relTol 0.01)
assert(gradient ~== agg.gradient relTol 0.01)
Seq(1, 2, 4).foreach { blockSize =>
val blocks1 = standardizedInstances
.grouped(blockSize)
.map(seq => InstanceBlock.fromInstances(seq))
.toArray
val blocks2 = blocks1.map { block =>
new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor)
}
Seq(blocks1, blocks2).foreach { blocks =>
val blockAgg = getNewBlockAggregator(parameters, fitIntercept = true, epsilon)
blocks.foreach(blockAgg.add)
assert(agg.loss ~== blockAgg.loss relTol 1e-9)
assert(agg.gradient ~== blockAgg.gradient relTol 1e-9)
}
}
}
test("check with zero standard deviation") {

View file

@@ -17,7 +17,7 @@
package org.apache.spark.ml.optim.aggregator
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util.TestingUtils._
@@ -28,6 +28,7 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
@transient var instances: Array[Instance] = _
@transient var instancesConstantFeature: Array[Instance] = _
@transient var instancesConstantLabel: Array[Instance] = _
@transient var standardizedInstances: Array[Instance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -46,6 +47,7 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)),
Instance(1.0, 0.3, Vectors.dense(4.0, 0.5))
)
standardizedInstances = standardize(instances)
}
/** Get summary statistics for some data and create a new LeastSquaresAggregator. */
@@ -66,6 +68,24 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
bcFeaturesMean)(bcCoefficients)
}
/** Get summary statistics for some data and create a new BlockLeastSquaresAggregator. */
private def getNewBlockAggregator(
instances: Array[Instance],
coefficients: Vector,
fitIntercept: Boolean): BlockLeastSquaresAggregator = {
val (featuresSummarizer, ySummarizer) =
Summarizer.getRegressionSummarizers(sc.parallelize(instances))
val yStd = ySummarizer.std(0)
val yMean = ySummarizer.mean(0)
val featuresStd = featuresSummarizer.std.toArray
val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd)
val featuresMean = featuresSummarizer.mean
val bcFeaturesMean = spark.sparkContext.broadcast(featuresMean.toArray)
val bcCoefficients = spark.sparkContext.broadcast(coefficients)
new BlockLeastSquaresAggregator(yStd, yMean, fitIntercept, bcFeaturesStd,
bcFeaturesMean)(bcCoefficients)
}
test("aggregator add method input size") {
val coefficients = Vectors.dense(1.0, 2.0)
val agg = getNewAggregator(instances, coefficients, fitIntercept = true)
@@ -142,6 +162,23 @@ class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkConte
BLAS.scal(1.0 / weightSum, expectedGradient)
assert(agg.loss ~== (expectedLoss.sum / weightSum) relTol 1e-5)
assert(agg.gradient ~== expectedGradient relTol 1e-5)
Seq(1, 2, 4).foreach { blockSize =>
val blocks1 = standardizedInstances
.grouped(blockSize)
.map(seq => InstanceBlock.fromInstances(seq))
.toArray
val blocks2 = blocks1.map { block =>
new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor)
}
Seq(blocks1, blocks2).foreach { blocks =>
val blockAgg = getNewBlockAggregator(instances, coefficients, fitIntercept = true)
blocks.foreach(blockAgg.add)
assert(agg.loss ~== blockAgg.loss relTol 1e-9)
assert(agg.gradient ~== blockAgg.gradient relTol 1e-9)
}
}
}
test("check with zero standard deviation") {

View file

@@ -79,21 +79,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
new BlockLogisticAggregator(numFeatures, numClasses, fitIntercept, multinomial)(bcCoefficients)
}
private def standardize(instances: Array[Instance]): Array[Instance] = {
val (featuresSummarizer, _) =
Summarizer.getClassificationSummarizers(sc.parallelize(instances))
val stdArray = featuresSummarizer.std.toArray
val numFeatures = stdArray.length
instances.map { case Instance(label, weight, features) =>
val standardized = Array.ofDim[Double](numFeatures)
features.foreachNonZero { (i, v) =>
val std = stdArray(i)
if (std != 0) standardized(i) = v / std
}
Instance(label, weight, Vectors.dense(standardized).compressed)
}
}
test("aggregator add method input size") {
val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
val interceptArray = Array(4.0, 2.0, -3.0)

View file

@@ -660,6 +660,26 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
testPredictionModelSinglePrediction(model, datasetWithDenseFeature)
}
test("LinearRegression on blocks") {
for (dataset <- Seq(datasetWithDenseFeature, datasetWithStrongNoise,
datasetWithDenseFeatureWithoutIntercept, datasetWithSparseFeature, datasetWithWeight,
datasetWithWeightConstantLabel, datasetWithWeightZeroLabel, datasetWithOutlier);
fitIntercept <- Seq(true, false);
loss <- Seq("squaredError", "huber")) {
val lir = new LinearRegression()
.setFitIntercept(fitIntercept)
.setLoss(loss)
.setMaxIter(3)
val model = lir.fit(dataset)
Seq(4, 16, 64).foreach { blockSize =>
val model2 = lir.setBlockSize(blockSize).fit(dataset)
assert(model.intercept ~== model2.intercept relTol 1e-9)
assert(model.coefficients ~== model2.coefficients relTol 1e-9)
assert(model.scale ~== model2.scale relTol 1e-9)
}
}
}
test("linear regression model with constant label") {
/*
R code:

View file

@@ -22,6 +22,8 @@ import java.io.File
import org.scalatest.Suite
import org.apache.spark.SparkContext
import org.apache.spark.ml.feature._
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.util.Utils
@@ -66,4 +68,13 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
protected object testImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}
private[spark] def standardize(instances: Array[Instance]): Array[Instance] = {
val (featuresSummarizer, _) =
Summarizer.getClassificationSummarizers(sc.parallelize(instances))
val inverseStd = featuresSummarizer.std.toArray
.map { std => if (std != 0) 1.0 / std else 0.0 }
val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true)
instances.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) }
}
}

View file

@@ -87,7 +87,7 @@ class _JavaRegressionModel(RegressionModel, JavaPredictionModel):
class _LinearRegressionParams(_PredictorParams, HasRegParam, HasElasticNetParam, HasMaxIter,
HasTol, HasFitIntercept, HasStandardization, HasWeightCol, HasSolver,
HasAggregationDepth, HasLoss):
HasAggregationDepth, HasLoss, HasBlockSize):
"""
Params for :py:class:`LinearRegression` and :py:class:`LinearRegressionModel`.
@@ -155,6 +155,8 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable,
LinearRegressionModel...
>>> model.getMaxIter()
5
>>> model.getBlockSize()
1
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.predict(test0.head().features) - (-1.0)) < 0.001
True
@@ -194,17 +196,18 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable,
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
loss="squaredError", epsilon=1.35):
loss="squaredError", epsilon=1.35, blockSize=1):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
loss="squaredError", epsilon=1.35)
loss="squaredError", epsilon=1.35, blockSize=1)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.LinearRegression", self.uid)
self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35)
self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, loss="squaredError", epsilon=1.35,
blockSize=1)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -213,12 +216,12 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable,
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto", weightCol=None, aggregationDepth=2,
loss="squaredError", epsilon=1.35):
loss="squaredError", epsilon=1.35, blockSize=1):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
loss="squaredError", epsilon=1.35)
loss="squaredError", epsilon=1.35, blockSize=1)
Sets params for linear regression.
"""
kwargs = self._input_kwargs
@@ -294,6 +297,13 @@ class LinearRegression(_JavaRegressor, _LinearRegressionParams, JavaMLWritable,
"""
return self._set(lossType=value)
@since("3.1.0")
def setBlockSize(self, value):
"""
Sets the value of :py:attr:`blockSize`.
"""
return self._set(blockSize=value)
class LinearRegressionModel(_JavaRegressionModel, _LinearRegressionParams, GeneralJavaMLWritable,
JavaMLReadable, HasTrainingSummary):