[SPARK-27965][SQL] Add extractors for v2 catalog transforms.

## What changes were proposed in this pull request?

Add extractors for v2 catalog transforms.

These extractors are used to match transforms that are equivalent to Spark's internal case classes. This makes it easier to work with v2 transforms.
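For illustration only (not part of this patch), a hedged sketch of how the extractors might be used: the `summarize` helper below is hypothetical, and because the extractors are `private[sql]`, code like this would have to live under the `org.apache.spark.sql` package.

```scala
import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, IdentityTransform, Transform, YearsTransform}

// Hypothetical helper: describe v2 partition transforms by pattern matching with
// the new extractors instead of inspecting transform.name and arguments by hand.
def summarize(partitioning: Seq[Transform]): Seq[String] = partitioning.map {
  case IdentityTransform(ref) => s"identity(${ref.describe})"
  case YearsTransform(ref) => s"years(${ref.describe})"
  case BucketTransform(numBuckets, ref) => s"bucket($numBuckets, ${ref.describe})"
  case other => s"unsupported: ${other.describe}"
}
```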

## How was this patch tested?

Added a test suite for the new extractors.

Closes #24812 from rdblue/SPARK-27965-add-transform-extractors.

Authored-by: Ryan Blue <blue@apache.org>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>

@@ -94,6 +94,17 @@ private[sql] final case class BucketTransform(
  override def toString: String = describe
}

private[sql] object BucketTransform {
  def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match {
    case NamedTransform("bucket", Seq(
        Lit(value: Int, IntegerType),
        Ref(seq: Seq[String]))) =>
      Some((value, FieldReference(seq)))
    case _ =>
      None
  }
}

private[sql] final case class ApplyTransform(
    name: String,
    args: Seq[Expression]) extends Transform {
@@ -111,32 +122,104 @@ private[sql] final case class ApplyTransform(
  override def toString: String = describe
}

/**
 * Convenience extractor for any Literal.
 */
private object Lit {
  def unapply[T](literal: Literal[T]): Some[(T, DataType)] = {
    Some((literal.value, literal.dataType))
  }
}

/**
 * Convenience extractor for any NamedReference.
 */
private object Ref {
  def unapply(named: NamedReference): Some[Seq[String]] = {
    Some(named.fieldNames)
  }
}

/**
 * Convenience extractor for any Transform.
 */
private object NamedTransform {
  def unapply(transform: Transform): Some[(String, Seq[Expression])] = {
    Some((transform.name, transform.arguments))
  }
}

private[sql] final case class IdentityTransform(
    ref: NamedReference) extends SingleColumnTransform(ref) {
  override val name: String = "identity"
  override def describe: String = ref.describe
}

private[sql] object IdentityTransform {
  def unapply(transform: Transform): Option[FieldReference] = transform match {
    case NamedTransform("identity", Seq(Ref(parts))) =>
      Some(FieldReference(parts))
    case _ =>
      None
  }
}

private[sql] final case class YearsTransform(
    ref: NamedReference) extends SingleColumnTransform(ref) {
  override val name: String = "years"
}

private[sql] object YearsTransform {
  def unapply(transform: Transform): Option[FieldReference] = transform match {
    case NamedTransform("years", Seq(Ref(parts))) =>
      Some(FieldReference(parts))
    case _ =>
      None
  }
}

private[sql] final case class MonthsTransform(
    ref: NamedReference) extends SingleColumnTransform(ref) {
  override val name: String = "months"
}

private[sql] object MonthsTransform {
  def unapply(transform: Transform): Option[FieldReference] = transform match {
    case NamedTransform("months", Seq(Ref(parts))) =>
      Some(FieldReference(parts))
    case _ =>
      None
  }
}

private[sql] final case class DaysTransform(
    ref: NamedReference) extends SingleColumnTransform(ref) {
  override val name: String = "days"
}

private[sql] object DaysTransform {
  def unapply(transform: Transform): Option[FieldReference] = transform match {
    case NamedTransform("days", Seq(Ref(parts))) =>
      Some(FieldReference(parts))
    case _ =>
      None
  }
}

private[sql] final case class HoursTransform(
    ref: NamedReference) extends SingleColumnTransform(ref) {
  override val name: String = "hours"
}

private[sql] object HoursTransform {
  def unapply(transform: Transform): Option[FieldReference] = transform match {
    case NamedTransform("hours", Seq(Ref(parts))) =>
      Some(FieldReference(parts))
    case _ =>
      None
  }
}

private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
  override def describe: String = {
    if (dataType.isInstanceOf[StringType]) {

@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalog.v2.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.types.DataType

class TransformExtractorSuite extends SparkFunSuite {
  /**
   * Creates a Literal using an anonymous class.
   */
  private def lit[T](literal: T): Literal[T] = new Literal[T] {
    override def value: T = literal
    override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
    override def describe: String = literal.toString
  }

  /**
   * Creates a NamedReference using an anonymous class.
   */
  private def ref(names: String*): NamedReference = new NamedReference {
    override def fieldNames: Array[String] = names.toArray
    override def describe: String = names.mkString(".")
  }

  /**
   * Creates a Transform using an anonymous class.
   */
  private def transform(func: String, ref: NamedReference): Transform = new Transform {
    override def name: String = func
    override def references: Array[NamedReference] = Array(ref)
    override def arguments: Array[Expression] = Array(ref)
    override def describe: String = ref.describe
  }

  test("Identity extractor") {
    transform("identity", ref("a", "b")) match {
      case IdentityTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match IdentityTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case IdentityTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Years extractor") {
    transform("years", ref("a", "b")) match {
      case YearsTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match YearsTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case YearsTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Months extractor") {
    transform("months", ref("a", "b")) match {
      case MonthsTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match MonthsTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case MonthsTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Days extractor") {
    transform("days", ref("a", "b")) match {
      case DaysTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match DaysTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case DaysTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Hours extractor") {
    transform("hours", ref("a", "b")) match {
      case HoursTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match HoursTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case HoursTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Bucket extractor") {
    val col = ref("a", "b")
    val bucketTransform = new Transform {
      override def name: String = "bucket"
      override def references: Array[NamedReference] = Array(col)
      override def arguments: Array[Expression] = Array(lit(16), col)
      override def describe: String = s"bucket(16, ${col.describe})"
    }

    bucketTransform match {
      case BucketTransform(numBuckets, FieldReference(seq)) =>
        assert(numBuckets === 16)
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match BucketTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case BucketTransform(_, _) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }
}