[SPARK-21481][ML][FOLLOWUP] HashingTF Cleanup

## What changes were proposed in this pull request?
Some cleanup and a small optimization:
1. Since the `transformImpl` method on the `.mllib` side is no longer used by the `.ml` side, its visibility can be narrowed (from `private[spark]` to `private`).
2. In `hashUDF`, the val `numOfFeatures` is never used, so it is removed.
3. It is inefficient to call the param getters (`$(numFeatures)` / `$(binary)`) inside the UDF, whether directly or via the `indexOf` method (which reads `$(numFeatures)`); instead, the getters should be called once outside the UDF and the resulting local values captured by its closure (see the sketch below).
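
As illustration of point 3, here is a minimal, self-contained sketch of the pattern (the object name `HashingUdfSketch`, the toy index function, and the literal param values are assumptions for the example, not the actual `HashingTF` code): the param values are read once into local vals, and the UDF closes over those locals instead of doing a param-map lookup per row.

```scala
import scala.collection.mutable

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object HashingUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Stand-ins for $(numFeatures) and $(binary): read once, outside the UDF.
    val localNumFeatures = 16
    val localBinary = false

    // The UDF closure captures only the two plain locals above, so evaluating
    // it per row never touches a param map.
    val hashUDF = udf { terms: Seq[String] =>
      val termFrequencies = mutable.HashMap.empty[Int, Double].withDefaultValue(0.0)
      terms.foreach { term =>
        // Toy index function; the real code hashes the term via indexOf.
        val i = java.lang.Math.floorMod(term.hashCode, localNumFeatures)
        if (localBinary) termFrequencies(i) = 1.0 else termFrequencies(i) += 1.0
      }
      Vectors.sparse(localNumFeatures, termFrequencies.toSeq)
    }

    val df = Seq(Tuple1(Seq("a", "a", "b", "c"))).toDF("words")
    df.withColumn("features", hashUDF(col("words"))).show(truncate = false)

    spark.stop()
  }
}
```

The diff below applies the same idea inside `HashingTF.transform`: `localNumFeatures` and `localBinary` are captured once, and the per-row UDF body only uses those locals.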

## How was this patch tested?
Existing suites.

Closes #25324 from zhengruifeng/hashingtf_cleanup.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
zhengruifeng 2019-08-09 10:04:39 -05:00 committed by Sean Owen
parent cbad616d4c
commit 8b08e14de7
3 changed files with 6 additions and 6 deletions


@@ -100,19 +100,20 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
+    val localNumFeatures = $(numFeatures)
+    val localBinary = $(binary)
     val hashUDF = udf { terms: Seq[_] =>
-      val numOfFeatures = $(numFeatures)
-      val isBinary = $(binary)
       val termFrequencies = mutable.HashMap.empty[Int, Double].withDefaultValue(0.0)
       terms.foreach { term =>
         val i = indexOf(term)
-        if (isBinary) {
+        if (localBinary) {
           termFrequencies(i) = 1.0
         } else {
           termFrequencies(i) += 1.0
         }
       }
-      Vectors.sparse($(numFeatures), termFrequencies.toSeq)
+      Vectors.sparse(localNumFeatures, termFrequencies.toSeq)
     }
     dataset.withColumn($(outputCol), hashUDF(col($(inputCol))),


@@ -98,7 +98,7 @@ class HashingTF(val numFeatures: Int) extends Serializable {
     Vectors.sparse(numFeatures, seq)
   }
-  private[spark] def transformImpl(document: Iterable[_]): Seq[(Int, Double)] = {
+  private def transformImpl(document: Iterable[_]): Seq[(Int, Double)] = {
     val termFrequencies = mutable.HashMap.empty[Int, Double]
     val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
     val hashFunc: Any => Int = getHashFunction


@@ -77,7 +77,6 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
   }
   test("indexOf method") {
     val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
-    val n = 100
     val hashingTF = new HashingTF()
       .setInputCol("words")