diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 13a13f0a7e..2e9b6be9a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.source.libsvm
 
 import java.io.IOException
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
@@ -26,12 +27,16 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -110,13 +115,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   @Since("1.6.0")
   override def shortName(): String = "libsvm"
 
+  override def toString: String = "LibSVM"
+
   private def verifySchema(dataSchema: StructType): Unit = {
     if (dataSchema.size != 2 ||
       (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
         || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
-      throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+      throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }
+
   override def inferSchema(
       sqlContext: SQLContext,
       options: Map[String, String],
@@ -127,6 +135,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       StructField("features", new VectorUDT(), nullable = false) :: Nil))
   }
 
+  override def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = {
+    def computeNumFeatures(): Int = {
+      val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+      val path = if (dataFiles.length == 1) {
+        dataFiles.head.getPath.toUri.toString
+      } else if (dataFiles.isEmpty) {
+        throw new IOException("No input path specified for libsvm data")
+      } else {
+        throw new IOException("Multiple input paths are not supported for libsvm data.")
+      }
+
+      val sc = sqlContext.sparkContext
+      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+      MLUtils.computeNumFeatures(parsed)
+    }
+
+    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+      computeNumFeatures()
+    }
+
+    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+  }
+
   override def prepareWrite(
       sqlContext: SQLContext,
       job: Job,
@@ -158,7 +192,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     verifySchema(dataSchema)
     val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
 
-    val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+    val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString
     else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
     else throw new IOException("Multiple input paths are not supported for libsvm data.")
 
@@ -176,4 +210,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       externalRows.map(converter.toRow)
     }
   }
+
+  override def buildReader(
+      sqlContext: SQLContext,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
+    val numFeatures = options("numFeatures").toInt
+    assert(numFeatures > 0)
+
+    val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+    val broadcastedConf = sqlContext.sparkContext.broadcast(
+      new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+    )
+
+    (file: PartitionedFile) => {
+      val points =
+        new HadoopFileLinesReader(file, broadcastedConf.value.value)
+          .map(_.toString.trim)
+          .filterNot(line => line.isEmpty || line.startsWith("#"))
+          .map { line =>
+            val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+            LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+          }
+
+      val converter = RowEncoder(requiredSchema)
+
+      val unsafeRowIterator = points.map { pt =>
+        val features = if (sparse) pt.features.toSparse else pt.features.toDense
+        converter.toRow(Row(pt.label, features))
+      }
+
+      def toAttribute(f: StructField): AttributeReference =
+        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+      // Appends partition values
+      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val joinedRow = new JoinedRow()
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+      unsafeRowIterator.map { dataRow =>
+        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+      }
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5cdd7..4b9d77949f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
       path: String,
       numFeatures: Int,
       minPartitions: Int): RDD[LabeledPoint] = {
-    val parsed = sc.textFile(path, minPartitions)
-      .map(_.trim)
-      .filter(line => !(line.isEmpty || line.startsWith("#")))
-      .map { line =>
-        val items = line.split(' ')
-        val label = items.head.toDouble
-        val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
-          val indexAndValue = item.split(':')
-          val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
-          val value = indexAndValue(1).toDouble
-          (index, value)
-        }.unzip
-
-        // check if indices are one-based and in ascending order
-        var previous = -1
-        var i = 0
-        val indicesLength = indices.length
-        while (i < indicesLength) {
-          val current = indices(i)
-          require(current > previous, s"indices should be one-based and in ascending order;"
-            + " found current=$current, previous=$previous; line=\"$line\"")
-          previous = current
-          i += 1
-        }
-
-        (label, indices.toArray, values.toArray)
-      }
+    val parsed = parseLibSVMFile(sc, path, minPartitions)
 
     // Determine number of features.
     val d = if (numFeatures > 0) {
       numFeatures
     } else {
       parsed.persist(StorageLevel.MEMORY_ONLY)
-      parsed.map { case (label, indices, values) =>
-        indices.lastOption.getOrElse(0)
-      }.reduce(math.max) + 1
+      computeNumFeatures(parsed)
     }
 
     parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
     }
   }
 
+  private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+    rdd.map { case (label, indices, values) =>
+      indices.lastOption.getOrElse(0)
+    }.reduce(math.max) + 1
+  }
+
+  private[spark] def parseLibSVMFile(
+      sc: SparkContext,
+      path: String,
+      minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+    sc.textFile(path, minPartitions)
+      .map(_.trim)
+      .filter(line => !(line.isEmpty || line.startsWith("#")))
+      .map(parseLibSVMRecord)
+  }
+
+  private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+    val items = line.split(' ')
+    val label = items.head.toDouble
+    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+      val indexAndValue = item.split(':')
+      val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+      val value = indexAndValue(1).toDouble
+      (index, value)
+    }.unzip
+
+    // check if indices are one-based and in ascending order
+    var previous = -1
+    var i = 0
+    val indicesLength = indices.length
+    while (i < indicesLength) {
+      val current = indices(i)
+      require(current > previous, s"indices should be one-based and in ascending order;"
+        + s" found current=$current, previous=$previous; line=\"$line\"")
+      previous = current
+      i += 1
+    }
+
+    (label, indices, values)
+  }
+
   /**
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
    * partitions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index c66921f485..1850810270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -299,6 +299,9 @@ case class DataSource(
             "It must be specified manually")
         }
 
+        val enrichedOptions =
+          format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+
         HadoopFsRelation(
           sqlContext,
           fileCatalog,
@@ -306,7 +309,7 @@ case class DataSource(
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          options)
+          enrichedOptions)
 
       case _ =>
         throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 554298772a..a143ac6aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -59,6 +59,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       if (files.fileFormat.toString == "TestFileFormat" ||
         files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
         files.fileFormat.toString == "ORC" ||
+        files.fileFormat.toString == "LibSVM" ||
        files.fileFormat.isInstanceOf[csv.DefaultSource] ||
        files.fileFormat.isInstanceOf[text.DefaultSource] ||
        files.fileFormat.isInstanceOf[json.DefaultSource]) &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 6b95a3d25b..e8834d052c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -438,6 +438,15 @@ trait FileFormat {
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType]
 
+  /**
+   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
+   * can be useful for collecting necessary global information for scanning input data.
+   */
+  def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = options
+
   /**
    * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
    * be put here. For example, user defined output committer can be configured here
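
For reference, a minimal usage sketch of the code paths this diff touches (the input path and feature count are illustrative, not part of the change): supplying `numFeatures` up front lets `prepareRead` skip the extra pass it would otherwise run over the data to infer the feature dimension, and `vectorType` selects sparse (the default) or dense feature vectors in the rows produced by `buildReader`.

```scala
import org.apache.spark.sql.SQLContext

// Sketch only: assumes an existing SQLContext and a libsvm-formatted file on disk.
def readLibSVM(sqlContext: SQLContext): Unit = {
  val df = sqlContext.read
    .format("libsvm")
    .option("numFeatures", "780")   // known dimension; skips the inference pass in prepareRead
    .option("vectorType", "sparse") // or "dense"
    .load("data/mllib/sample_libsvm_data.txt")

  // Schema is (label: Double, features: Vector), as enforced by verifySchema.
  df.select("label", "features").show(5)
}
```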