[MINOR][ML] ML cleanup
### What changes were proposed in this pull request?
1. Remove unused imports and variables.
2. Use `.iterator` instead of `.view` to avoid IDEA warnings.
3. Remove resolved TODOs.

### Why are the changes needed?
Cleanup.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #27600 from zhengruifeng/nits.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
Parent: c46c067f39
Commit: e086a78706
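The recurring change in the hunks below swaps `.view` for `.iterator` at call sites that traverse a sequence exactly once. A minimal, standalone sketch (not part of the patch) of why the two give the same result here while the iterator is the lighter choice:

```scala
// Standalone sketch, not from the patch: both traversals are lazy and print the
// same output, but a view keeps a re-traversable wrapper around the collection,
// while an iterator is a one-shot cursor -- all these single-pass loops need.
object ViewVsIterator {
  def main(args: Array[String]): Unit = {
    val stages = Seq("tokenizer", "hashingTF", "lr")

    // Before: lazy view.
    stages.view.zipWithIndex.foreach { case (s, i) => println(s"$i -> $s") }

    // After: single-pass iterator, no wrapper object kept around.
    stages.iterator.zipWithIndex.foreach { case (s, i) => println(s"$i -> $s") }
  }
}
```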
@@ -682,7 +682,6 @@ private[spark] object BLAS extends Serializable {
 
         val xTemp = xValues(k) * alpha
         while (i < indEnd) {
-          val rowIndex = Arows(i)
           yValues(Arows(i)) += Avals(i) * xTemp
           i += 1
         }
@@ -734,8 +733,7 @@ private[spark] object BLAS extends Serializable {
         val indEnd = Acols(colCounterForA + 1)
         val xVal = xValues(colCounterForA) * alpha
         while (i < indEnd) {
-          val rowIndex = Arows(i)
-          yValues(rowIndex) += Avals(i) * xVal
+          yValues(Arows(i)) += Avals(i) * xVal
           i += 1
         }
         colCounterForA += 1
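For orientation, a hedged, self-contained sketch of the column-major (CSC) `y += alpha * A * x` accumulation these hunks touch, with made-up values; the surrounding Spark routine (apparently the sparse `gemv` path) is unchanged apart from dropping the unused `rowIndex` local:

```scala
// Hedged standalone sketch: CSC sparse matrix-vector accumulation. For column j,
// the non-zeros sit at positions [Acols(j), Acols(j + 1)) of Arows/Avals, so the
// row index Arows(i) is used exactly once per entry -- no local needed.
object SparseGemvSketch {
  def main(args: Array[String]): Unit = {
    // 2x2 matrix [[1.0, 0.0], [2.0, 3.0]] in CSC form.
    val Acols = Array(0, 2, 3)
    val Arows = Array(0, 1, 1)
    val Avals = Array(1.0, 2.0, 3.0)
    val x = Array(1.0, 1.0)
    val alpha = 1.0
    val y = Array(0.0, 0.0)

    var j = 0
    while (j < x.length) {
      val xTemp = x(j) * alpha
      var i = Acols(j)
      val indEnd = Acols(j + 1)
      while (i < indEnd) {
        y(Arows(i)) += Avals(i) * xTemp
        i += 1
      }
      j += 1
    }
    println(y.mkString(", "))  // 1.0, 5.0
  }
}
```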
@@ -139,7 +139,7 @@ class Pipeline @Since("1.4.0") (
     val theStages = $(stages)
     // Search for the last estimator.
     var indexOfLastEstimator = -1
-    theStages.view.zipWithIndex.foreach { case (stage, index) =>
+    theStages.iterator.zipWithIndex.foreach { case (stage, index) =>
       stage match {
         case _: Estimator[_] =>
           indexOfLastEstimator = index
@@ -148,7 +148,7 @@ class Pipeline @Since("1.4.0") (
     }
     var curDataset = dataset
     val transformers = ListBuffer.empty[Transformer]
-    theStages.view.zipWithIndex.foreach { case (stage, index) =>
+    theStages.iterator.zipWithIndex.foreach { case (stage, index) =>
       if (index <= indexOfLastEstimator) {
         val transformer = stage match {
           case estimator: Estimator[_] =>
@@ -67,14 +67,12 @@ class AttributeGroup private (
   /**
    * Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined.
    */
-  val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) =>
-    attr.withIndex(i)
-  }.toArray)
+  val attributes: Option[Array[Attribute]] = attrs.map(_.iterator.zipWithIndex
+    .map { case (attr, i) => attr.withIndex(i) }.toArray)
 
   private lazy val nameToIndex: Map[String, Int] = {
-    attributes.map(_.view.flatMap { attr =>
-      attr.name.map(_ -> attr.index.get)
-    }.toMap).getOrElse(Map.empty)
+    attributes.map(_.iterator.flatMap { attr => attr.name.map(_ -> attr.index.get)}.toMap)
+      .getOrElse(Map.empty)
   }
 
   /** Size of the attribute group. Returns -1 if the size is unknown. */
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.col
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -17,8 +17,6 @@
 
 package org.apache.spark.ml.classification
 
-import scala.collection.JavaConverters._
-
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
@@ -256,7 +256,7 @@ class RandomForestClassificationModel private[ml] (
       // Classifies using majority votes.
       // Ignore the tree weights since all are 1.0 for now.
       val votes = Array.ofDim[Double](numClasses)
-      _trees.view.foreach { tree =>
+      _trees.foreach { tree =>
        val classCounts = tree.rootNode.predictImpl(features).impurityStats.stats
        val total = classCounts.sum
        if (total != 0) {
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MetadataUtils, SchemaUtils}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -104,7 +104,9 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
       SchemaUtils.checkNumericType(schema, $(weightCol))
     }
 
-    // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
+    MetadataUtils.getNumFeatures(schema($(rawPredictionCol)))
+      .foreach(n => require(n == 2, s"rawPredictionCol vectors must have length=2, but got $n"))
+
     val scoreAndLabelsWithWeights =
       dataset.select(
         col($(rawPredictionCol)),
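The second hunk above turns a stale TODO into an actual check: when the column metadata records the rawPrediction vector length, require it to be 2. A hedged sketch of the same idea using the public `AttributeGroup` API (the patch itself goes through the `private[spark]` `MetadataUtils` helper, so the helper name below is illustrative):

```scala
// Hedged sketch, not the patch's code path: ML attaches the vector size of a
// column to its StructField metadata via AttributeGroup; -1 means "unknown",
// in which case no check is performed.
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.sql.types.StructField

object RawPredictionLengthCheck {
  def numFeatures(field: StructField): Option[Int] = {
    val size = AttributeGroup.fromStructField(field).size
    if (size >= 0) Some(size) else None
  }

  def main(args: Array[String]): Unit = {
    // Build a field whose metadata says the vector has 2 elements.
    val field = new AttributeGroup("rawPrediction", 2).toStructField()
    numFeatures(field).foreach(n => require(n == 2, s"expected length=2, got $n"))
    println(s"rawPrediction length from metadata: ${numFeatures(field)}")
  }
}
```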
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
 import org.apache.spark.ml.util._
@@ -144,8 +144,6 @@ object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] {
 *
 * @param originalMin min value for each original column during fitting
 * @param originalMax max value for each original column during fitting
- *
- * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529).
 */
 @Since("1.5.0")
 class MinMaxScalerModel private[ml] (
@@ -25,7 +25,6 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.regression._
@@ -1049,7 +1049,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         .join(userFactors)
         .mapPartitions({ items =>
           items.flatMap { case (_, (ids, factors)) =>
-            ids.view.zip(factors)
+            ids.iterator.zip(factors.iterator)
           }
         // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
         // and userFactors.
@@ -1061,7 +1061,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         .join(itemFactors)
         .mapPartitions({ items =>
           items.flatMap { case (_, (ids, factors)) =>
-            ids.view.zip(factors)
+            ids.iterator.zip(factors.iterator)
           }
         }, preservesPartitioning = true)
         .setName("itemFactors")
@@ -1376,7 +1376,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         Iterator.empty
       }
     } ++ {
-      builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+      builders.iterator.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
        val srcBlockId = idx % srcPart.numPartitions
        val dstBlockId = idx / srcPart.numPartitions
        ((srcBlockId, dstBlockId), block.build())
@@ -1695,7 +1695,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
     val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
     val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
       case (srcBlockId, (srcOutBlock, srcFactors)) =>
-        srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+        srcOutBlock.iterator.zipWithIndex.map { case (activeIndices, dstBlockId) =>
          (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
        }
     }
@@ -40,7 +40,6 @@ import org.apache.spark.mllib.optimization.{Gradient, GradientDescent, SquaredL2
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.col
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -26,9 +26,8 @@ import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -825,7 +825,7 @@ private[spark] object RandomForest extends Logging with Serializable {
     }
 
     val validFeatureSplits =
-      Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+      Iterator.range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
        featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
          .getOrElse((featureIndexIdx, featureIndexIdx))
      }.withFilter { case (_, featureIndex) =>
@@ -95,7 +95,7 @@ class StreamingKMeansModel @Since("1.2.0") (
     val discount = timeUnit match {
       case StreamingKMeans.BATCHES => decayFactor
       case StreamingKMeans.POINTS =>
-        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+        val numNewPoints = pointStats.iterator.map { case (_, (_, n)) =>
          n
        }.sum
        math.pow(decayFactor, numNewPoints)
@@ -125,9 +125,8 @@ class StreamingKMeansModel @Since("1.2.0") (
     }
 
     // Check whether the smallest cluster is dying. If so, split the largest cluster.
-    val weightsWithIndex = clusterWeights.view.zipWithIndex
-    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
-    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    val (maxWeight, largest) = clusterWeights.iterator.zipWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = clusterWeights.iterator.zipWithIndex.minBy(_._1)
     if (minWeight < 1e-8 * maxWeight) {
       logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
       val weight = (maxWeight + minWeight) / 2.0
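The second StreamingKMeans hunk collapses the intermediate `weightsWithIndex` view into two direct passes. A small standalone sketch with made-up weights showing what `maxBy`/`minBy` over `iterator.zipWithIndex` return:

```scala
// Standalone sketch with illustrative weights: zipWithIndex pairs each weight
// with its cluster index, so maxBy/minBy on the weight yield (weight, index).
object DyingClusterCheck {
  def main(args: Array[String]): Unit = {
    val clusterWeights = Array(120.0, 3.0e-9, 45.0)
    val (maxWeight, largest) = clusterWeights.iterator.zipWithIndex.maxBy(_._1)
    val (minWeight, smallest) = clusterWeights.iterator.zipWithIndex.minBy(_._1)
    if (minWeight < 1e-8 * maxWeight) {
      println(s"Cluster $smallest is dying; split the largest cluster $largest.")
    }
  }
}
```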
@@ -74,8 +74,9 @@ class ChiSqSelectorModel @Since("1.3.0") (
       }
     }
 
-  private[spark] def compressSparse(indices: Array[Int],
-      values: Array[Double]): (Array[Int], Array[Double]) = {
+  private[spark] def compressSparse(
+      indices: Array[Int],
+      values: Array[Double]): (Array[Int], Array[Double]) = {
     val newValues = new ArrayBuilder.ofDouble
     val newIndices = new ArrayBuilder.ofInt
     var i = 0
@@ -64,8 +64,9 @@ class ElementwiseProduct @Since("1.4.0") (
     newValues
   }
 
-  private[spark] def transformSparse(indices: Array[Int],
-      values: Array[Double]): (Array[Int], Array[Double]) = {
+  private[spark] def transformSparse(
+      indices: Array[Int],
+      values: Array[Double]): (Array[Int], Array[Double]) = {
     val newValues = values.clone()
     val dim = newValues.length
     var i = 0
@@ -226,8 +226,9 @@ private[spark] object IDFModel {
       }
     }
 
-  private[spark] def transformDense(idf: Vector,
-      values: Array[Double]): Array[Double] = {
+  private[spark] def transformDense(
+      idf: Vector,
+      values: Array[Double]): Array[Double] = {
     val n = values.length
     val newValues = new Array[Double](n)
     var j = 0
@@ -238,9 +239,10 @@ private[spark] object IDFModel {
     newValues
   }
 
-  private[spark] def transformSparse(idf: Vector,
-      indices: Array[Int],
-      values: Array[Double]): (Array[Int], Array[Double]) = {
+  private[spark] def transformSparse(
+      idf: Vector,
+      indices: Array[Int],
+      values: Array[Double]): (Array[Int], Array[Double]) = {
     val nnz = indices.length
     val newValues = new Array[Double](nnz)
     var k = 0
@@ -663,7 +663,6 @@ private[spark] object BLAS extends Serializable with Logging {
 
         val xTemp = xValues(k) * alpha
         while (i < indEnd) {
-          val rowIndex = Arows(i)
           yValues(Arows(i)) += Avals(i) * xTemp
           i += 1
         }
@@ -715,8 +714,7 @@ private[spark] object BLAS extends Serializable with Logging {
         val indEnd = Acols(colCounterForA + 1)
         val xVal = xValues(colCounterForA) * alpha
         while (i < indEnd) {
-          val rowIndex = Arows(i)
-          yValues(rowIndex) += Avals(i) * xVal
+          yValues(Arows(i)) += Avals(i) * xVal
          i += 1
        }
        colCounterForA += 1
@@ -292,7 +292,7 @@ object GradientDescent extends Logging {
       miniBatchFraction: Double,
       initialWeights: Vector): (Vector, Array[Double]) =
     GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
       regParam, miniBatchFraction, initialWeights, 0.001)
 
 
   private def isConverged(
@@ -46,9 +46,7 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
   override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
     // ((columnIndex, value), rowUid)
     val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
-      vec.toArray.view.zipWithIndex.map { case (v, j) =>
-        ((j, v), uid)
-      }
+      vec.iterator.map(t => (t, uid))
     }
     // global sort by (columnIndex, value)
     val sorted = colBased.sortByKey()
@@ -25,7 +25,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
 @Since("1.0.0")
 object Entropy extends Impurity {
 
-  private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+  private val _log2 = scala.math.log(2)
+
+  private[tree] def log2(x: Double) = scala.math.log(x) / _log2
 
   /**
    * :: DeveloperApi ::
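The Entropy change hoists `log(2)` into a constant so `log2` does one division per call instead of recomputing the denominator inside per-bin impurity loops. A minimal standalone sketch of the same pattern:

```scala
// Standalone sketch: log base 2 via change of base, with the constant cached.
object Log2 {
  private val _log2 = math.log(2)

  def log2(x: Double): Double = math.log(x) / _log2

  def main(args: Array[String]): Unit = {
    println(log2(8.0))  // ≈ 3.0
  }
}
```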
@@ -292,7 +292,7 @@ private[tree] sealed class TreeEnsembleModel(
   */
  private def predictByVoting(features: Vector): Double = {
    val votes = mutable.Map.empty[Int, Double]
-    trees.view.zip(treeWeights).foreach { case (tree, weight) =>
+    trees.iterator.zip(treeWeights.iterator).foreach { case (tree, weight) =>
      val prediction = tree.predict(features).toInt
      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
    }
@@ -20,8 +20,6 @@ package org.apache.spark.mllib.util
 import scala.collection.JavaConverters._
 import scala.util.Random
 
-import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.linalg.{BLAS, Vectors}