# [SPARK-15610][ML] update error message for k in pca
## What changes were proposed in this pull request?

Fix the wrong bound of `k` in `PCA`:
`require(k <= sources.first().size, ...)` -> `require(k < sources.first().size, ...)`

BTW, remove an unused import in `ml.ElementwiseProduct`.

## How was this patch tested?

Manual tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #13356 from zhengruifeng/fix_pca.
commit 9893dc9757 (parent 88c9c467a3)
**mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala**

```diff
@@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.sql.types.DataType
 
```
@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
|
|||
*/
|
||||
@Since("1.4.0")
|
||||
def fit(sources: RDD[Vector]): PCAModel = {
|
||||
require(k <= sources.first().size,
|
||||
s"source vector size is ${sources.first().size} must be greater than k=$k")
|
||||
val numFeatures = sources.first().size
|
||||
require(k <= numFeatures,
|
||||
s"source vector size $numFeatures must be no less than k=$k")
|
||||
|
||||
val mat = new RowMatrix(sources)
|
||||
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
|
||||
|
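For context, here is a minimal usage sketch of the RDD-based `PCA` showing the precondition the new `require` enforces. The sample data and the `sc` SparkContext are assumptions for illustration, not part of this commit:

```scala
import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext `sc` (e.g. from spark-shell).
val sources: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 3.0),
  Vectors.dense(2.0, 1.0, 0.0),
  Vectors.dense(4.0, 2.0, 1.0)
))

// numFeatures = 3, so any k <= 3 passes the check.
val model = new PCA(2).fit(sources)

// k = 4 would now fail with the corrected message:
//   "source vector size 3 must be no less than k=4"
// new PCA(4).fit(sources)
```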
@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
|
|||
case m =>
|
||||
throw new IllegalArgumentException("Unsupported matrix format. Expected " +
|
||||
s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
|
||||
|
||||
}
|
||||
val denseExplainedVariance = explainedVariance match {
|
||||
case dv: DenseVector =>
|
||||
|
|
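The catch-all `case m =>` arm above is the standard Scala way to fail fast on unexpected `Matrix` implementations. A self-contained sketch of the same idiom, with a hypothetical helper name (`toDense` here is for illustration, not from this commit):

```scala
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, SparseMatrix}

// Hypothetical helper mirroring the match in PCA.fit: return a dense copy
// of the principal-components matrix, or throw on unsupported subtypes.
def toDense(m: Matrix): DenseMatrix = m match {
  case dm: DenseMatrix => dm
  case sm: SparseMatrix => sm.toDense
  case other =>
    throw new IllegalArgumentException("Unsupported matrix format. Expected " +
      s"SparseMatrix or DenseMatrix. Instead got: ${other.getClass}")
}
```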