[SPARK-16281][SQL] Implement parse_url SQL function
## What changes were proposed in this pull request? This PR adds the parse_url SQL function in order to remove the Hive fallback. A new implementation of #13999. ## How was this patch tested? Passes the existing tests, including new test cases. Author: wujian <jan.chou.wu@gmail.com> Closes #14008 from janplus/SPARK-16281.
This commit is contained in:
parent
142df4834b
commit
f5fef69143
|
@ -288,6 +288,7 @@ object FunctionRegistry {
|
|||
expression[StringLPad]("lpad"),
|
||||
expression[StringTrimLeft]("ltrim"),
|
||||
expression[JsonTuple]("json_tuple"),
|
||||
expression[ParseUrl]("parse_url"),
|
||||
expression[FormatString]("printf"),
|
||||
expression[RegExpExtract]("regexp_extract"),
|
||||
expression[RegExpReplace]("regexp_replace"),
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.net.{MalformedURLException, URL}
|
||||
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
|
||||
import java.util.{HashMap, Locale, Map => JMap}
|
||||
import java.util.regex.Pattern
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
|
@ -654,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
|
|||
override def prettyName: String = "rpad"
|
||||
}
|
||||
|
||||
/**
 * Companion holding constants shared by the parse_url expression: the supported
 * part names (pre-encoded as UTF8String so rows can be compared without
 * conversion) and the regex fragments used to build a query-key lookup pattern.
 */
object ParseUrl {
  // Recognized values for the partToExtract argument.
  private val AUTHORITY = UTF8String.fromString("AUTHORITY")
  private val FILE = UTF8String.fromString("FILE")
  private val HOST = UTF8String.fromString("HOST")
  private val PATH = UTF8String.fromString("PATH")
  private val PROTOCOL = UTF8String.fromString("PROTOCOL")
  private val QUERY = UTF8String.fromString("QUERY")
  private val REF = UTF8String.fromString("REF")
  private val USERINFO = UTF8String.fromString("USERINFO")

  // Pieces wrapped around a key to form the query-extraction regex:
  // REGEXPREFIX + key + REGEXSUBFIX, where group 2 captures the value.
  private val REGEXPREFIX = "(&|^)"
  private val REGEXSUBFIX = "=([^&]*)"
}
|
||||
|
||||
/**
 * Extracts a part from a URL
 */
@ExpressionDescription(
  usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
  extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO.
    Key specifies which query to extract.
    Examples:
      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST')
      'spark.apache.org'
      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY')
      'query=1'
      > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query')
      '1'""")
case class ParseUrl(children: Seq[Expression])
  extends Expression with ExpectsInputTypes with CodegenFallback {

  import ParseUrl._

  override def nullable: Boolean = true
  override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
  override def dataType: DataType = StringType
  override def prettyName: String = "parse_url"

  // When the url argument is a literal, build the java.net.URL once instead of
  // converting UTF8String -> String -> URL on every row.
  @transient private lazy val cachedUrl = children(0) match {
    case Literal(url: UTF8String, _) if url ne null => getUrl(url)
    case _ => null
  }

  // When the key argument is a literal, compile its lookup Pattern once instead
  // of rebuilding it on every row. Only forced on the three-argument path, so
  // children(2) is safe to reference here.
  @transient private lazy val cachedPattern = children(2) match {
    case Literal(key: UTF8String, _) if key ne null => getPattern(key)
    case _ => null
  }

  // When the partToExtract argument is a literal, resolve the accessor function
  // once instead of re-matching the part name on every row.
  @transient private lazy val cachedExtractPartFunc = children(1) match {
    case Literal(part: UTF8String, _) => getExtractPartFunc(part)
    case _ => null
  }

  /** Requires exactly two or three string arguments. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size == 2 || children.size == 3) {
      super[ExpectsInputTypes].checkInputDataTypes()
    } else {
      TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
    }
  }

  /** Compiles the regex that locates `key` inside a query string. */
  private def getPattern(key: UTF8String): Pattern = {
    Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
  }

  /** Parses the url; a malformed url yields null rather than an exception. */
  private def getUrl(url: UTF8String): URL = {
    try {
      new URL(url.toString)
    } catch {
      case _: MalformedURLException => null
    }
  }

  /** Maps a part name to the matching URL accessor; unknown names map to null. */
  private def getExtractPartFunc(partToExtract: UTF8String): URL => String = {
    partToExtract match {
      case HOST => (url: URL) => url.getHost
      case PATH => (url: URL) => url.getPath
      case QUERY => (url: URL) => url.getQuery
      case REF => (url: URL) => url.getRef
      case PROTOCOL => (url: URL) => url.getProtocol
      case FILE => (url: URL) => url.getFile
      case AUTHORITY => (url: URL) => url.getAuthority
      case USERINFO => (url: URL) => url.getUserInfo
      case _ => (url: URL) => null
    }
  }

  /** Returns the value captured for the key's pattern, or null when absent. */
  private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
    val matcher = pattern.matcher(query.toString)
    if (matcher.find()) {
      // Group 2 is the "=([^&]*)" capture, i.e. the key's value.
      UTF8String.fromString(matcher.group(2))
    } else {
      null
    }
  }

  /** Applies the (possibly cached) part accessor to an already-parsed url. */
  private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = {
    val partFunc =
      if (cachedExtractPartFunc ne null) cachedExtractPartFunc
      else getExtractPartFunc(partToExtract)
    UTF8String.fromString(partFunc.apply(url))
  }

  /** Two-argument form: parse (or reuse the cached) url and extract one part. */
  private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
    val parsed = if (cachedUrl ne null) cachedUrl else getUrl(url)
    if (parsed ne null) {
      extractFromUrl(parsed, partToExtract)
    } else {
      null
    }
  }

  override def eval(input: InternalRow): Any = {
    val values = children.map { e => e.eval(input).asInstanceOf[UTF8String] }
    if (values.contains(null)) {
      // Any null argument makes the whole expression null.
      null
    } else if (values.size == 2) {
      parseUrlWithoutKey(values(0), values(1))
    } else {
      // 3-arg, i.e. QUERY with key
      assert(values.size == 3)
      if (values(1) != QUERY) {
        // A key is only meaningful for the QUERY part.
        null
      } else {
        val query = parseUrlWithoutKey(values(0), values(1))
        if (query eq null) {
          null
        } else {
          val pattern = if (cachedPattern ne null) cachedPattern else getPattern(values(2))
          extractValueFromQuery(query, pattern)
        }
      }
    }
  }
}
|
||||
|
||||
/**
|
||||
* Returns the input formatted according do printf-style format strings
|
||||
*/
|
||||
|
|
|
@ -726,6 +726,57 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
|
||||
}
|
||||
|
||||
test("ParseUrl") {
  // Two-argument form: extract a named part from a url literal.
  def checkPart(expected: String, urlStr: String, partToExtract: String): Unit = {
    checkEvaluation(ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
  }
  // Three-argument form: extract a single query parameter by key.
  def checkPartWithKey(
      expected: String,
      urlStr: String,
      partToExtract: String,
      key: String): Unit = {
    checkEvaluation(
      ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
  }

  checkPart("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
  checkPart("/path", "http://spark.apache.org/path?query=1", "PATH")
  checkPart("query=1", "http://spark.apache.org/path?query=1", "QUERY")
  checkPart("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
  checkPart("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
  checkPart("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
  checkPart("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
  checkPart("userinfo", "http://userinfo@spark.apache.org/path?query=1", "USERINFO")
  checkPartWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")

  // Null checking
  checkPart(null, null, "HOST")
  checkPart(null, "http://spark.apache.org/path?query=1", null)
  checkPart(null, null, null)
  checkPart(null, "test", "HOST")
  checkPart(null, "http://spark.apache.org/path?query=1", "NO")
  checkPart(null, "http://spark.apache.org/path?query=1", "USERINFO")
  checkPartWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
  checkPartWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
  checkPartWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)
  checkPartWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "")

  // exceptional cases: a key with regex metacharacters surfaces as a
  // PatternSyntaxException rather than being silently swallowed.
  intercept[java.util.regex.PatternSyntaxException] {
    evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
      Literal("QUERY"), Literal("???"))))
  }

  // arguments checking: only two or three string arguments are accepted.
  assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
  assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
    .checkInputDataTypes().isFailure)
  assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
  assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
  assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
}
|
||||
|
||||
test("Sentences") {
|
||||
val nullString = Literal.create(null, StringType)
|
||||
checkEvaluation(Sentences(nullString, nullString, nullString), null)
|
||||
|
|
|
@ -228,6 +228,21 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
Row("???hi", "hi???", "h", "h"))
|
||||
}
|
||||
|
||||
test("string parse_url function") {
  // One row exercising every extractable part: userinfo, host, path, query and ref.
  val urls = Seq[String](("http://userinfo@spark.apache.org/path?query=1#Ref"))
    .toDF("url")

  checkAnswer(
    urls.selectExpr(
      "parse_url(url, 'HOST')", "parse_url(url, 'PATH')",
      "parse_url(url, 'QUERY')", "parse_url(url, 'REF')",
      "parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')",
      "parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')",
      "parse_url(url, 'QUERY', 'query')"),
    Row("spark.apache.org", "/path", "query=1", "Ref",
      "http", "/path?query=1", "userinfo@spark.apache.org", "userinfo", "1"))
}
|
||||
|
||||
test("string repeat function") {
|
||||
val df = Seq(("hi", 2)).toDF("a", "b")
|
||||
|
||||
|
|
|
@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
|
|||
// str_to_map, windowingtablefunction.
|
||||
// Functions that still fall back to Hive. This span previously fused the
// diff's before/after lines, leaving "parse_url" (removed by this commit,
// which gives it a native implementation) and duplicating "percentile",
// "percentile_approx", "reflect" and "str_to_map"; this is the post-commit
// list with each entry appearing exactly once.
private val hiveFunctions = Seq(
  "hash", "java_method", "histogram_numeric",
  "percentile", "percentile_approx", "reflect", "str_to_map",
  "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
  "xpath_number", "xpath_short", "xpath_string"
)
|
||||
|
|
Loading…
Reference in a new issue