[SPARK-16281][SQL] Implement parse_url SQL function

## What changes were proposed in this pull request?

This PR adds the parse_url SQL function as a native implementation, so that it no longer needs the Hive fallback.

This is a new implementation of #13999.
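For illustration, a minimal usage sketch once the function is registered (assumes a running SparkSession named spark; not part of this patch):

// Hypothetical usage sketch: exercising the new native parse_url
// through the SQL interface.
spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST')").show()
// spark.apache.org
spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query')").show()
// 1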

## How was this patch tested?

Passes the existing tests, including new test cases.

Author: wujian <jan.chou.wu@gmail.com>

Closes #14008 from janplus/SPARK-16281.
Authored by wujian on 2016-07-08 14:38:05 -07:00, committed by Reynold Xin
parent 142df4834b
commit f5fef69143
5 changed files with 218 additions and 1 deletion


@@ -288,6 +288,7 @@ object FunctionRegistry {
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
expression[ParseUrl]("parse_url"),
expression[FormatString]("printf"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
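For context, a rough sketch of the lookup this registration line enables — not Spark's actual FunctionRegistry internals, just the idea of a name-to-builder map (ParseUrl takes its arguments as a single Seq[Expression]):

// Illustrative sketch only, not Spark's real registry code: registering
// a builder under "parse_url" and resolving a call by name.
import org.apache.spark.sql.catalyst.expressions.{Expression, ParseUrl}
import scala.collection.mutable

val builders = mutable.Map.empty[String, Seq[Expression] => Expression]
builders("parse_url") = (args: Seq[Expression]) => ParseUrl(args)

// The analyzer would resolve parse_url(...) by looking up the builder:
// val resolved = builders("parse_url")(parsedArgumentExpressions)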


@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
import java.net.{MalformedURLException, URL}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer
@@ -654,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
override def prettyName: String = "rpad"
}
object ParseUrl {
private val HOST = UTF8String.fromString("HOST")
private val PATH = UTF8String.fromString("PATH")
private val QUERY = UTF8String.fromString("QUERY")
private val REF = UTF8String.fromString("REF")
private val PROTOCOL = UTF8String.fromString("PROTOCOL")
private val FILE = UTF8String.fromString("FILE")
private val AUTHORITY = UTF8String.fromString("AUTHORITY")
private val USERINFO = UTF8String.fromString("USERINFO")
private val REGEXPREFIX = "(&|^)"
private val REGEXSUBFIX = "=([^&]*)"
}
/**
* Extracts a part from a URL
*/
@ExpressionDescription(
usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO.
Key specifies which query to extract.
Examples:
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST')
'spark.apache.org'
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY')
'query=1'
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query')
'1'""")
case class ParseUrl(children: Seq[Expression])
extends Expression with ExpectsInputTypes with CodegenFallback {
override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType
override def prettyName: String = "parse_url"
// If the url is a constant, cache the URL object so that we don't need to convert url
// from UTF8String to String to URL for every row.
@transient private lazy val cachedUrl = children(0) match {
case Literal(url: UTF8String, _) if url ne null => getUrl(url)
case _ => null
}
// If the key is a constant, cache the Pattern object so that we don't need to convert key
// from UTF8String to String to StringBuilder to String to Pattern for every row.
@transient private lazy val cachedPattern = children(2) match {
case Literal(key: UTF8String, _) if key ne null => getPattern(key)
case _ => null
}
// If the partToExtract is a constant, cache the Extract part function so that we don't need
// to check the partToExtract for every row.
@transient private lazy val cachedExtractPartFunc = children(1) match {
case Literal(part: UTF8String, _) => getExtractPartFunc(part)
case _ => null
}
import ParseUrl._
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
} else {
super[ExpectsInputTypes].checkInputDataTypes()
}
}
private def getPattern(key: UTF8String): Pattern = {
Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
}
private def getUrl(url: UTF8String): URL = {
try {
new URL(url.toString)
} catch {
case e: MalformedURLException => null
}
}
private def getExtractPartFunc(partToExtract: UTF8String): URL => String = {
partToExtract match {
case HOST => _.getHost
case PATH => _.getPath
case QUERY => _.getQuery
case REF => _.getRef
case PROTOCOL => _.getProtocol
case FILE => _.getFile
case AUTHORITY => _.getAuthority
case USERINFO => _.getUserInfo
case _ => (url: URL) => null
}
}
private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
val m = pattern.matcher(query.toString)
if (m.find()) {
UTF8String.fromString(m.group(2))
} else {
null
}
}
private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = {
if (cachedExtractPartFunc ne null) {
UTF8String.fromString(cachedExtractPartFunc.apply(url))
} else {
UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
}
}
private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
if (cachedUrl ne null) {
extractFromUrl(cachedUrl, partToExtract)
} else {
val currentUrl = getUrl(url)
if (currentUrl ne null) {
extractFromUrl(currentUrl, partToExtract)
} else {
null
}
}
}
override def eval(input: InternalRow): Any = {
val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
if (evaluated.contains(null)) return null
if (evaluated.size == 2) {
parseUrlWithoutKey(evaluated(0), evaluated(1))
} else {
// 3-arg, i.e. QUERY with key
assert(evaluated.size == 3)
if (evaluated(1) != QUERY) {
return null
}
val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
if (query eq null) {
return null
}
if (cachedPattern ne null) {
extractValueFromQuery(query, cachedPattern)
} else {
extractValueFromQuery(query, getPattern(evaluated(2)))
}
}
}
}
/**
 * Returns the input formatted according to printf-style format strings
*/
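To make the control flow of ParseUrl concrete, here is a standalone sketch of the same extraction steps on the plain JVM (no Spark classes; it mirrors getUrl, getExtractPartFunc, getPattern and extractValueFromQuery):

import java.net.URL
import java.util.regex.Pattern

val url = new URL("http://spark.apache.org/path?a=1&query=2")

// Part extraction delegates to java.net.URL accessors:
val host = url.getHost                          // spark.apache.org
val file = url.getFile                          // /path?a=1&query=2

// For a key k the compiled pattern is "(&|^)k=([^&]*)": group(1) anchors
// at '&' or at the start of the query string, group(2) captures the value.
val p = Pattern.compile("(&|^)" + "query" + "=([^&]*)")
val m = p.matcher(url.getQuery)                 // query string: a=1&query=2
val value = if (m.find()) m.group(2) else null  // "2"

The Literal-based caching in cachedUrl, cachedPattern and cachedExtractPartFunc exists so that this URL construction and Pattern.compile happen once per query rather than once per row whenever the corresponding argument is a constant.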


@@ -726,6 +726,57 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}
test("ParseUrl") {
def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
checkEvaluation(
ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
}
def checkParseUrlWithKey(
expected: String,
urlStr: String,
partToExtract: String,
key: String): Unit = {
checkEvaluation(
ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
}
checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH")
checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY")
checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
checkParseUrl("userinfo", "http://userinfo@spark.apache.org/path?query=1", "USERINFO")
checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")
// Null checking
checkParseUrl(null, null, "HOST")
checkParseUrl(null, "http://spark.apache.org/path?query=1", null)
checkParseUrl(null, null, null)
checkParseUrl(null, "test", "HOST")
checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO")
checkParseUrl(null, "http://spark.apache.org/path?query=1", "USERINFO")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "")
// exceptional cases
intercept[java.util.regex.PatternSyntaxException] {
evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
Literal("QUERY"), Literal("???"))))
}
// arguments checking
assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
.checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
}
test("Sentences") {
val nullString = Literal.create(null, StringType)
checkEvaluation(Sentences(nullString, nullString, nullString), null)
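A short sketch of why the '???' key in the ParseUrl test above surfaces as an exception rather than returning null: the key is spliced into the regex unquoted, so regex metacharacters make Pattern.compile fail (plain java.util.regex, no other assumptions):

import java.util.regex.{Pattern, PatternSyntaxException}

try {
  // The same pattern ParseUrl.getPattern would build for key "???":
  Pattern.compile("(&|^)" + "???" + "=([^&]*)")
} catch {
  case e: PatternSyntaxException =>
    println("invalid key as regex: " + e.getDescription) // dangling '?'
}

Pattern.quote on the key would make it match literally; this patch keeps the unquoted behavior, and the test pins the resulting exception down.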


@@ -228,6 +228,21 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("???hi", "hi???", "h", "h"))
}
test("string parse_url function") {
val df = Seq[String](("http://userinfo@spark.apache.org/path?query=1#Ref"))
.toDF("url")
checkAnswer(
df.selectExpr(
"parse_url(url, 'HOST')", "parse_url(url, 'PATH')",
"parse_url(url, 'QUERY')", "parse_url(url, 'REF')",
"parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')",
"parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')",
"parse_url(url, 'QUERY', 'query')"),
Row("spark.apache.org", "/path", "query=1", "Ref",
"http", "/path?query=1", "userinfo@spark.apache.org", "userinfo", "1"))
}
test("string repeat function") {
val df = Seq(("hi", 2)).toDF("a", "b")
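Worth noting for the parse_url DataFrame test above: this patch adds no typed parse_url helper to org.apache.spark.sql.functions, so DataFrame code reaches the expression through selectExpr (as in the test) or expr(). A minimal sketch, assuming spark.implicits._ is in scope:

import org.apache.spark.sql.functions.expr

val urls = Seq("http://userinfo@spark.apache.org/path?query=1#Ref").toDF("url")
urls.select(expr("parse_url(url, 'HOST')"), expr("parse_url(url, 'QUERY', 'query')")).show()
// spark.apache.org, 1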


@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
"parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
"percentile", "percentile_approx", "reflect", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string"
)
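With parse_url removed from the Hive fallback list, the name now resolves to the native Catalyst expression, and the ExpressionDescription added above is what DESCRIBE FUNCTION reports. A sketch, again assuming a SparkSession named spark:

spark.sql("DESCRIBE FUNCTION EXTENDED parse_url").show(truncate = false)
// prints the usage string and examples from the ExpressionDescription annotation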