[SPARK-34432][SQL][TESTS] Add JavaSimpleWritableDataSource
### What changes were proposed in this pull request?

This is a follow-up of https://github.com/apache/spark/pull/19269. In #19269, only a Scala implementation of the simple writable data source was added to `DataSourceV2Suite`. This PR adds a Java implementation of it.

### Why are the changes needed?

To improve test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test suites.

Closes #31560 from kevincmchen/SPARK-34432.

Lead-authored-by: kevincmchen <kevincmchen@tencent.com>
Co-authored-by: Kevin Pis <68981916+kevincmchen@users.noreply.github.com>
Co-authored-by: Kevin Pis <kc4163568@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 23a5996a46
commit 9767041153
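For context, here is a minimal sketch of how a `DataSourceV2Suite`-style test can exercise both the Scala and the Java implementation of this source. The helpers `withTempPath` and `checkAnswer` come from Spark's test harness; the explicit casts and the exact assertions below are illustrative rather than copied from the suite.

```scala
// A minimal sketch, not the suite's exact test body. It assumes a test class
// with SharedSparkSession, withTempPath and checkAnswer available.
Seq(classOf[SimpleWritableDataSource], classOf[JavaSimpleWritableDataSource]).foreach { cls =>
  withTempPath { file =>
    val path = file.getCanonicalPath
    // An empty target directory reads back as an empty result.
    assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)

    // Write ten (i, j) pairs through the V2 batch writer, then read them back.
    val input = spark.range(10).selectExpr("cast(id as int) as i", "cast(-id as int) as j")
    input.write.format(cls.getName).option("path", path).mode("append").save()
    checkAnswer(spark.read.format(cls.getName).option("path", path).load(), input)
  }
}
```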
JavaSimpleBatchTable.java (modified):

```diff
@@ -17,7 +17,7 @@
 
 package test.org.apache.spark.sql.connector;
 
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -28,11 +28,8 @@ import org.apache.spark.sql.connector.catalog.TableCapability;
 import org.apache.spark.sql.types.StructType;
 
 abstract class JavaSimpleBatchTable implements Table, SupportsRead {
-  private static final Set<TableCapability> CAPABILITIES = new HashSet<>(Arrays.asList(
-      TableCapability.BATCH_READ,
-      TableCapability.BATCH_WRITE,
-      TableCapability.TRUNCATE));
-
+  private static final Set<TableCapability> CAPABILITIES =
+      new HashSet<>(Collections.singletonList(TableCapability.BATCH_READ));
   @Override
   public StructType schema() {
     return TestingV2Source.schema();
```
JavaSimpleWritableDataSource.java (new file, @@ -0,0 +1,371 @@):

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql.connector;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;

import org.apache.spark.deploy.SparkHadoopUtil;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.connector.SimpleCounter;
import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.SessionConfigSupport;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.*;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.util.SerializableConfiguration;

/**
 * An HDFS-based transactional writable data source implemented in Java.
 * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
 * Each job moves files from `target/_temporary/uniqueId/` to `target`.
 */
public class JavaSimpleWritableDataSource implements TestingV2Source, SessionConfigSupport {

  @Override
  public String keyPrefix() {
    return "javaSimpleWritableDataSource";
  }

  static class MyScanBuilder extends JavaSimpleScanBuilder {

    private final String path;
    private final Configuration conf;

    MyScanBuilder(String path, Configuration conf) {
      this.path = path;
      this.conf = conf;
    }

    @Override
    public InputPartition[] planInputPartitions() {
      Path dataPath = new Path(this.path);
      try {
        FileSystem fs = dataPath.getFileSystem(conf);
        if (fs.exists(dataPath)) {
          return Arrays.stream(fs.listStatus(dataPath))
              .filter(status -> {
                String name = status.getPath().getName();
                return !name.startsWith("_") && !name.startsWith(".");
              })
              .map(f -> new JavaCSVInputPartitionReader(f.getPath().toUri().toString()))
              .toArray(InputPartition[]::new);
        } else {
          return new InputPartition[0];
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public PartitionReaderFactory createReaderFactory() {
      SerializableConfiguration serializableConf = new SerializableConfiguration(conf);
      return new JavaCSVReaderFactory(serializableConf);
    }
  }

  static class MyWriteBuilder implements WriteBuilder, SupportsTruncate {

    private final String path;
    private final String queryId;
    private boolean needTruncate = false;

    MyWriteBuilder(String path, LogicalWriteInfo info) {
      this.path = path;
      this.queryId = info.queryId();
    }

    @Override
    public WriteBuilder truncate() {
      this.needTruncate = true;
      return this;
    }

    @Override
    public Write build() {
      return new MyWrite(path, queryId, needTruncate);
    }
  }

  static class MyWrite implements Write {

    private final String path;
    private final String queryId;
    private final boolean needTruncate;

    MyWrite(String path, String queryId, boolean needTruncate) {
      this.path = path;
      this.queryId = queryId;
      this.needTruncate = needTruncate;
    }

    @Override
    public BatchWrite toBatch() {
      Path hadoopPath = new Path(path);
      Configuration hadoopConf = SparkHadoopUtil.get().conf();
      try {
        FileSystem fs = hadoopPath.getFileSystem(hadoopConf);
        if (needTruncate) {
          fs.delete(hadoopPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      String pathStr = hadoopPath.toUri().toString();
      return new MyBatchWrite(queryId, pathStr, hadoopConf);
    }
  }

  static class MyBatchWrite implements BatchWrite {

    private final String queryId;
    private final String path;
    private final Configuration conf;

    MyBatchWrite(String queryId, String path, Configuration conf) {
      this.queryId = queryId;
      this.path = path;
      this.conf = conf;
    }

    @Override
    public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
      SimpleCounter.resetCounter();
      return new JavaCSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf));
    }

    @Override
    public void onDataWriterCommit(WriterCommitMessage message) {
      SimpleCounter.increaseCounter();
    }

    @Override
    public void commit(WriterCommitMessage[] messages) {
      Path finalPath = new Path(this.path);
      Path jobPath = new Path(new Path(finalPath, "_temporary"), queryId);
      try {
        FileSystem fs = jobPath.getFileSystem(conf);
        FileStatus[] fileStatuses = fs.listStatus(jobPath);
        try {
          for (FileStatus status : fileStatuses) {
            Path file = status.getPath();
            Path dest = new Path(finalPath, file.getName());
            if (!fs.rename(file, dest)) {
              throw new IOException(String.format("failed to rename(%s, %s)", file, dest));
            }
          }
        } finally {
          fs.delete(jobPath, true);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public void abort(WriterCommitMessage[] messages) {
      try {
        Path jobPath = new Path(new Path(this.path, "_temporary"), queryId);
        FileSystem fs = jobPath.getFileSystem(conf);
        fs.delete(jobPath, true);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class MyTable extends JavaSimpleBatchTable implements SupportsWrite {

    private final String path;
    private final Configuration conf = SparkHadoopUtil.get().conf();

    MyTable(CaseInsensitiveStringMap options) {
      this.path = options.get("path");
    }

    @Override
    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
      return new MyScanBuilder(new Path(path).toUri().toString(), conf);
    }

    @Override
    public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
      return new MyWriteBuilder(path, info);
    }

    @Override
    public Set<TableCapability> capabilities() {
      return new HashSet<>(Arrays.asList(
          TableCapability.BATCH_READ,
          TableCapability.BATCH_WRITE,
          TableCapability.TRUNCATE));
    }
  }

  @Override
  public Table getTable(CaseInsensitiveStringMap options) {
    return new MyTable(options);
  }

  static class JavaCSVInputPartitionReader implements InputPartition {

    private String path;

    JavaCSVInputPartitionReader(String path) {
      this.path = path;
    }

    public String getPath() {
      return path;
    }

    public void setPath(String path) {
      this.path = path;
    }
  }

  static class JavaCSVReaderFactory implements PartitionReaderFactory {

    private final SerializableConfiguration conf;

    JavaCSVReaderFactory(SerializableConfiguration conf) {
      this.conf = conf;
    }

    @Override
    public PartitionReader<InternalRow> createReader(InputPartition partition) {
      String path = ((JavaCSVInputPartitionReader) partition).getPath();
      Path filePath = new Path(path);
      try {
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new PartitionReader<InternalRow>() {
          private final FSDataInputStream inputStream = fs.open(filePath);
          private final Iterator<String> lines =
              new BufferedReader(new InputStreamReader(inputStream)).lines().iterator();
          private String currentLine = "";

          @Override
          public boolean next() {
            if (lines.hasNext()) {
              currentLine = lines.next();
              return true;
            } else {
              return false;
            }
          }

          @Override
          public InternalRow get() {
            Object[] objects = Arrays.stream(currentLine.split(","))
                .map(String::trim)
                .map(Integer::parseInt)
                .toArray();
            return new GenericInternalRow(objects);
          }

          @Override
          public void close() throws IOException {
            inputStream.close();
          }
        };
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class JavaCSVDataWriterFactory implements DataWriterFactory {

    private final String path;
    private final String jobId;
    private final SerializableConfiguration conf;

    JavaCSVDataWriterFactory(String path, String jobId, SerializableConfiguration conf) {
      this.path = path;
      this.jobId = jobId;
      this.conf = conf;
    }

    @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
      try {
        Path jobPath = new Path(new Path(path, "_temporary"), jobId);
        Path filePath = new Path(jobPath, String.format("%s-%d-%d", jobId, partitionId, taskId));
        FileSystem fs = filePath.getFileSystem(conf.value());
        return new JavaCSVDataWriter(fs, filePath);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  static class JavaCSVDataWriter implements DataWriter<InternalRow> {

    private final FileSystem fs;
    private final Path file;
    private final FSDataOutputStream out;

    JavaCSVDataWriter(FileSystem fs, Path file) throws IOException {
      this.fs = fs;
      this.file = file;
      out = fs.create(file);
    }

    @Override
    public void write(InternalRow record) throws IOException {
      out.writeBytes(String.format("%d,%d\n", record.getInt(0), record.getInt(1)));
    }

    @Override
    public WriterCommitMessage commit() throws IOException {
      out.close();
      return null;
    }

    @Override
    public void abort() throws IOException {
      try {
        out.close();
      } finally {
        fs.delete(file, false);
      }
    }

    @Override
    public void close() {
    }
  }
}
```
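The staging layout described in the class Javadoc can be made concrete with a small path-composition sketch. The query id, partition id, and task id below are hypothetical; the composition mirrors `JavaCSVDataWriterFactory.createWriter` and `MyBatchWrite.commit` above.

```scala
import org.apache.hadoop.fs.Path

// Hypothetical values: query id "query-1", partition 0, task 0.
val target   = new Path("/tmp/target")                              // final output directory
val jobPath  = new Path(new Path(target, "_temporary"), "query-1")  // per-job staging directory
val taskFile = new Path(jobPath, "query-1-0-0")                     // one file per task attempt

// While tasks run, files live under jobPath. MyBatchWrite.commit() renames each
// staged file into target and then deletes jobPath; abort() simply deletes jobPath.
println(taskFile) // /tmp/target/_temporary/query-1/query-1-0-0
```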
DataSourceV2Suite.scala (modified):

```diff
@@ -228,8 +228,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
   }
 
   test("simple writable data source") {
-    // TODO: java implementation.
-    Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
+    Seq(classOf[SimpleWritableDataSource], classOf[JavaSimpleWritableDataSource]).foreach { cls =>
       withTempPath { file =>
         val path = file.getCanonicalPath
         assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
```