diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index dfad25d57c..a9c2e23717 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} +import breeze.linalg.{norm => brzNorm} import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** * :: Experimental :: @@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = brzNorm(vector.toBreeze, p) + val norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. // However, for sparse vector, the `index` array will not be changed, // so we can re-use it to save memory. - vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) - case sv: BSV[Double] => - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size var i = 0 - while (i < output.data.length) { - output.data(i) /= norm + while (i < size) { + values(i) /= norm i += 1 } - Vectors.fromBreeze(output) + Vectors.dense(values) + case sv: SparseVector => + val values = sv.values.clone() + val nnz = values.size + var i = 0 + while (i < nnz) { + values(i) /= norm + i += 1 + } + Vectors.sparse(sv.size, sv.indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else {