[SPARK-27965][SQL] Add extractors for v2 catalog transforms.
## What changes were proposed in this pull request?

Add extractors for v2 catalog transforms. These extractors are used to match transforms that are equivalent to Spark's internal case classes. This makes it easier to work with v2 transforms.

## How was this patch tested?

Added test suite for the new extractors.

Closes #24812 from rdblue/SPARK-27965-add-transform-extractors.

Authored-by: Ryan Blue <blue@apache.org>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
This commit is contained in:
parent
eee3467b1e
commit
b30655bdef
|
@ -94,6 +94,17 @@ private[sql] final case class BucketTransform(
|
|||
override def toString: String = describe
|
||||
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of a bucket transform.
 *
 * A transform matches when its name is "bucket" and its arguments are exactly an Int
 * literal whose data type is IntegerType followed by a named reference. On a match the
 * literal's value is returned together with a FieldReference built from the reference's
 * field names; any other transform yields None.
 */
private[sql] object BucketTransform {
  def unapply(transform: Transform): Option[(Int, NamedReference)] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("bucket", Seq(
          Lit(numBuckets: Int, IntegerType),
          Ref(fieldNames))) =>
        (numBuckets, FieldReference(fieldNames))
    }
}
|
||||
|
||||
private[sql] final case class ApplyTransform(
|
||||
name: String,
|
||||
args: Seq[Expression]) extends Transform {
|
||||
|
@ -111,32 +122,104 @@ private[sql] final case class ApplyTransform(
|
|||
override def toString: String = describe
|
||||
}
|
||||
|
||||
/**
 * Extractor that splits any Literal into its value and its data type.
 *
 * The result type is Some, so the pattern always succeeds for a Literal argument.
 */
private object Lit {
  def unapply[T](literal: Literal[T]): Some[(T, DataType)] =
    Some((literal.value, literal.dataType))
}
|
||||
|
||||
/**
 * Extractor that exposes the field names of any NamedReference as a Seq.
 *
 * The result type is Some, so the pattern always succeeds for a NamedReference argument.
 */
private object Ref {
  def unapply(named: NamedReference): Some[Seq[String]] = {
    // The type ascription converts the reference's field names to the Seq the pattern expects.
    val names: Seq[String] = named.fieldNames
    Some(names)
  }
}
|
||||
|
||||
/**
 * Extractor that splits any Transform into its name and its argument list.
 *
 * The result type is Some, so the pattern always succeeds; callers narrow the match by
 * putting patterns on the extracted name and arguments.
 */
private object NamedTransform {
  def unapply(transform: Transform): Some[(String, Seq[Expression])] = {
    val transformName: String = transform.name
    // The type ascription converts the transform's arguments to a Seq for Seq(...) patterns.
    val transformArgs: Seq[Expression] = transform.arguments
    Some((transformName, transformArgs))
  }
}
|
||||
|
||||
/**
 * Transform named "identity" over the single reference `ref`.
 *
 * Behaviour not overridden here comes from SingleColumnTransform, which is defined
 * elsewhere in this file.
 */
private[sql] final case class IdentityTransform(ref: NamedReference)
  extends SingleColumnTransform(ref) {

  override val name: String = "identity"

  // Described by the reference's own description.
  override def describe: String = ref.describe
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of an identity transform.
 *
 * A transform matches when its name is "identity" and its only argument is a named
 * reference. The result is a FieldReference built from that reference's field names;
 * any other transform yields None.
 */
private[sql] object IdentityTransform {
  def unapply(transform: Transform): Option[FieldReference] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("identity", Seq(Ref(fieldNames))) => FieldReference(fieldNames)
    }
}
|
||||
|
||||
/**
 * Transform named "years" over the single reference `ref`.
 *
 * Only the name is defined here; everything else comes from SingleColumnTransform,
 * which is defined elsewhere in this file.
 */
private[sql] final case class YearsTransform(ref: NamedReference)
  extends SingleColumnTransform(ref) {

  override val name: String = "years"
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of a years transform.
 *
 * A transform matches when its name is "years" and its only argument is a named
 * reference. The result is a FieldReference built from that reference's field names;
 * any other transform yields None.
 */
private[sql] object YearsTransform {
  def unapply(transform: Transform): Option[FieldReference] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("years", Seq(Ref(fieldNames))) => FieldReference(fieldNames)
    }
}
|
||||
|
||||
/**
 * Transform named "months" over the single reference `ref`.
 *
 * Only the name is defined here; everything else comes from SingleColumnTransform,
 * which is defined elsewhere in this file.
 */
private[sql] final case class MonthsTransform(ref: NamedReference)
  extends SingleColumnTransform(ref) {

  override val name: String = "months"
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of a months transform.
 *
 * A transform matches when its name is "months" and its only argument is a named
 * reference. The result is a FieldReference built from that reference's field names;
 * any other transform yields None.
 */
private[sql] object MonthsTransform {
  def unapply(transform: Transform): Option[FieldReference] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("months", Seq(Ref(fieldNames))) => FieldReference(fieldNames)
    }
}
|
||||
|
||||
/**
 * Transform named "days" over the single reference `ref`.
 *
 * Only the name is defined here; everything else comes from SingleColumnTransform,
 * which is defined elsewhere in this file.
 */
private[sql] final case class DaysTransform(ref: NamedReference)
  extends SingleColumnTransform(ref) {

  override val name: String = "days"
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of a days transform.
 *
 * A transform matches when its name is "days" and its only argument is a named
 * reference. The result is a FieldReference built from that reference's field names;
 * any other transform yields None.
 */
private[sql] object DaysTransform {
  def unapply(transform: Transform): Option[FieldReference] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("days", Seq(Ref(fieldNames))) => FieldReference(fieldNames)
    }
}
|
||||
|
||||
/**
 * Transform named "hours" over the single reference `ref`.
 *
 * Only the name is defined here; everything else comes from SingleColumnTransform,
 * which is defined elsewhere in this file.
 */
private[sql] final case class HoursTransform(ref: NamedReference)
  extends SingleColumnTransform(ref) {

  override val name: String = "hours"
}
|
||||
|
||||
/**
 * Extractor for transforms that have the shape of an hours transform.
 *
 * A transform matches when its name is "hours" and its only argument is a named
 * reference. The result is a FieldReference built from that reference's field names;
 * any other transform yields None.
 */
private[sql] object HoursTransform {
  def unapply(transform: Transform): Option[FieldReference] =
    PartialFunction.condOpt(transform) {
      case NamedTransform("hours", Seq(Ref(fieldNames))) => FieldReference(fieldNames)
    }
}
|
||||
|
||||
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
|
||||
override def describe: String = {
|
||||
if (dataType.isInstanceOf[StringType]) {
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalog.v2.expressions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
/**
 * Tests for the extractors (`unapply`) of the v2 catalog transforms.
 *
 * All fixtures are anonymous implementations of Literal, NamedReference and Transform
 * rather than the transform case classes themselves, so a successful match shows that the
 * extractor works from a transform's name and arguments alone.
 */
class TransformExtractorSuite extends SparkFunSuite {
  /**
   * Creates a Literal using an anonymous class.
   *
   * The data type is whatever catalyst's Literal reports for the same value.
   */
  private def lit[T](literal: T): Literal[T] = new Literal[T] {
    override def value: T = literal
    override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
    override def describe: String = literal.toString
  }

  /**
   * Creates a NamedReference using an anonymous class.
   */
  private def ref(names: String*): NamedReference = new NamedReference {
    override def fieldNames: Array[String] = names.toArray
    override def describe: String = names.mkString(".")
  }

  /**
   * Creates a single-argument Transform using an anonymous class.
   *
   * The given reference is both the only referenced column and the only argument.
   */
  private def transform(func: String, ref: NamedReference): Transform = new Transform {
    override def name: String = func
    override def references: Array[NamedReference] = Array(ref)
    override def arguments: Array[Expression] = Array(ref)
    override def describe: String = ref.describe
  }

  test("Identity extractor") {
    transform("identity", ref("a", "b")) match {
      case IdentityTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match IdentityTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case IdentityTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Years extractor") {
    transform("years", ref("a", "b")) match {
      case YearsTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match YearsTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case YearsTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Months extractor") {
    transform("months", ref("a", "b")) match {
      case MonthsTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match MonthsTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case MonthsTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Days extractor") {
    transform("days", ref("a", "b")) match {
      case DaysTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match DaysTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case DaysTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Hours extractor") {
    transform("hours", ref("a", "b")) match {
      case HoursTransform(FieldReference(seq)) =>
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match HoursTransform extractor")
    }

    transform("unknown", ref("a", "b")) match {
      case HoursTransform(FieldReference(_)) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }
  }

  test("Bucket extractor") {
    val col = ref("a", "b")

    // Builds a transform with bucket-shaped arguments (Int literal, column) under any name,
    // so the name check and the argument check can be exercised independently.
    def bucketShaped(func: String): Transform = new Transform {
      override def name: String = func
      override def references: Array[NamedReference] = Array(col)
      override def arguments: Array[Expression] = Array(lit(16), col)
      override def describe: String = s"$func(16, ${col.describe})"
    }

    bucketShaped("bucket") match {
      case BucketTransform(numBuckets, FieldReference(seq)) =>
        assert(numBuckets === 16)
        assert(seq === Seq("a", "b"))
      case _ =>
        fail("Did not match BucketTransform extractor")
    }

    // Wrong name and wrong argument count.
    transform("unknown", ref("a", "b")) match {
      case BucketTransform(_, _) =>
        fail("Matched unknown transform")
      case _ =>
        // expected
    }

    // Correct argument shape but wrong name: only the name check can reject this one.
    bucketShaped("unknown") match {
      case BucketTransform(_, _) =>
        fail("Matched unknown transform with bucket arguments")
      case _ =>
        // expected
    }

    // Correct name but a single argument: only the argument check can reject this one.
    transform("bucket", ref("a", "b")) match {
      case BucketTransform(_, _) =>
        fail("Matched bucket transform without a bucket count")
      case _ =>
        // expected
    }
  }
}
|
Loading…
Reference in a new issue