[SPARK-21481][ML][FOLLOWUP] HashingTF Cleanup
## What changes were proposed in this pull request? Some cleanup and a tiny optimization: 1, since the `transformImpl` method on the .mllib side is no longer used on the .ml side, its scope should be limited; 2, in the `hashUDF`, the val `numOfFeatures` is never used; 3, in the udf, it is inefficient to invoke a param getter (`$(numFeatures)`/`$(binary)`) directly or via the method `indexOf` (`$(numFeatures)`); instead, the getters should be called outside of the udf. ## How was this patch tested? Existing suites. Closes #25324 from zhengruifeng/hashingtf_cleanup. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
cbad616d4c
commit
8b08e14de7
|
@ -100,19 +100,20 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
|
|||
@Since("2.0.0")
|
||||
override def transform(dataset: Dataset[_]): DataFrame = {
|
||||
val outputSchema = transformSchema(dataset.schema)
|
||||
val localNumFeatures = $(numFeatures)
|
||||
val localBinary = $(binary)
|
||||
|
||||
val hashUDF = udf { terms: Seq[_] =>
|
||||
val numOfFeatures = $(numFeatures)
|
||||
val isBinary = $(binary)
|
||||
val termFrequencies = mutable.HashMap.empty[Int, Double].withDefaultValue(0.0)
|
||||
terms.foreach { term =>
|
||||
val i = indexOf(term)
|
||||
if (isBinary) {
|
||||
if (localBinary) {
|
||||
termFrequencies(i) = 1.0
|
||||
} else {
|
||||
termFrequencies(i) += 1.0
|
||||
}
|
||||
}
|
||||
Vectors.sparse($(numFeatures), termFrequencies.toSeq)
|
||||
Vectors.sparse(localNumFeatures, termFrequencies.toSeq)
|
||||
}
|
||||
|
||||
dataset.withColumn($(outputCol), hashUDF(col($(inputCol))),
|
||||
|
|
|
@ -98,7 +98,7 @@ class HashingTF(val numFeatures: Int) extends Serializable {
|
|||
Vectors.sparse(numFeatures, seq)
|
||||
}
|
||||
|
||||
private[spark] def transformImpl(document: Iterable[_]): Seq[(Int, Double)] = {
|
||||
private def transformImpl(document: Iterable[_]): Seq[(Int, Double)] = {
|
||||
val termFrequencies = mutable.HashMap.empty[Int, Double]
|
||||
val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
|
||||
val hashFunc: Any => Int = getHashFunction
|
||||
|
|
|
@ -77,7 +77,6 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
|
|||
}
|
||||
|
||||
test("indexOf method") {
|
||||
val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
|
||||
val n = 100
|
||||
val hashingTF = new HashingTF()
|
||||
.setInputCol("words")
|
||||
|
|
Loading…
Reference in a new issue