[SPARK-9156][SQL] codegen StringSplit
Jira: https://issues.apache.org/jira/browse/SPARK-9156 Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7547 from tarekauel/SPARK-9156 and squashes the following commits: 0be2700 [Tarek Auel] [SPARK-9156][SQL] indention fix b860eaf [Tarek Auel] [SPARK-9156][SQL] codegen StringSplit 5ad6a1f [Tarek Auel] [SPARK-9156] codegen StringSplit
This commit is contained in:
parent
047ccc8c9a
commit
6853ac7c8c
|
@ -615,7 +615,7 @@ case class StringSpace(child: Expression)
|
|||
* Splits str around pat (pattern is a regular expression).
|
||||
*/
|
||||
case class StringSplit(str: Expression, pattern: Expression)
|
||||
extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback {
|
||||
extends BinaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
// Left operand of this binary expression: the expression producing the string to split.
override def left: Expression = str
|
||||
// Right operand of this binary expression: the expression producing the split pattern.
override def right: Expression = pattern
|
||||
|
@ -623,9 +623,13 @@ case class StringSplit(str: Expression, pattern: Expression)
|
|||
// Both operands (string and pattern) are declared as StringType.
// NOTE(review): presumably consumed by ImplicitCastInputTypes to coerce the children -- confirm.
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
||||
|
||||
/**
 * Splits `string` around matches of `regex` and returns the pieces as a Seq.
 *
 * Both arguments are cast to UTF8String; the work is delegated to
 * `UTF8String.split`, so no intermediate java.lang.String-based result is
 * computed and thrown away here.
 *
 * @param string the value to split (cast to UTF8String)
 * @param regex  the regular expression to split around (cast to UTF8String)
 * @return the split pieces, converted with `toSeq`
 */
override def nullSafeEval(string: Any, regex: Any): Any = {
  // A limit of -1 is passed through to the underlying split so that trailing
  // empty strings are kept in the result.
  string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq
}
|
||||
|
||||
/**
 * Emits the code for this expression as a string. The generated statement calls
 * `split` on the string operand with the pattern operand and a limit of -1, wraps
 * the resulting array with `java.util.Arrays.asList`, converts that list via
 * `scala.collection.JavaConversions.asScalaBuffer`, and assigns it to `ev.primitive`.
 *
 * NOTE(review): `nullSafeCodeGen` presumably surrounds the snippet with null checks
 * for both operands -- confirm against its definition.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
nullSafeCodeGen(ctx, ev, (str, pattern) =>
|
||||
s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer(
|
||||
java.util.Arrays.asList($str.split($pattern, -1)));""")
|
||||
}
|
||||
|
||||
// Short name for this expression: "split".
override def prettyName: String = "split"
|
||||
|
|
|
@ -487,6 +487,15 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
|
|||
return fromBytes(result);
|
||||
}
|
||||
|
||||
public UTF8String[] split(UTF8String pattern, int limit) {
|
||||
String[] splits = toString().split(pattern.toString(), limit);
|
||||
UTF8String[] res = new UTF8String[splits.length];
|
||||
for (int i = 0; i < res.length; i++) {
|
||||
res[i] = fromString(splits[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
try {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.unsafe.types;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -271,6 +272,16 @@ public class UTF8StringSuite {
|
|||
fromString("数据砖头").rpad(12, fromString("孙行者")));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void split() {
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
|
||||
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
|
||||
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
|
||||
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void levenshteinDistance() {
|
||||
assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0);
|
||||
|
|
Loading…
Reference in a new issue