diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index f98b0b536d..b9621530ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint]) : M = { val nfeatures: Int = input.first().features.length - val initialWeights = Array.fill(nfeatures)(1.0) + val initialWeights = new Array[Double](nfeatures) run(input, initialWeights) } @@ -134,15 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - // Add a extra variable consisting of all 1.0's for the intercept. + // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - Array(1.0, initialWeights:_*) + initialWeights.+:(1.0) } else { initialWeights }