[SPARK-10117] [MLLIB] Implement SQL data source API for reading LIBSVM data

It is convenient to implement data source API for LIBSVM format to have a better integration with DataFrames and ML pipeline API.

Two option is implemented.
* `numFeatures`: Specify the dimension of features vector
* `featuresType`: Specify the type of output vector. `sparse` is default.

Author: lewuathe <lewuathe@me.com>

Closes #8537 from Lewuathe/SPARK-10117 and squashes the following commits:

986999d [lewuathe] Change unit test phrase
11d513f [lewuathe] Fix some reviews
21600a4 [lewuathe] Merge branch 'master' into SPARK-10117
9ce63c7 [lewuathe] Rewrite service loader file
1fdd2df [lewuathe] Merge branch 'SPARK-10117' of github.com:Lewuathe/spark into SPARK-10117
ba3657c [lewuathe] Merge branch 'master' into SPARK-10117
0ea1c1c [lewuathe] LibSVMRelation is registered into META-INF
4f40891 [lewuathe] Improve test suites
5ab62ab [lewuathe] Merge branch 'master' into SPARK-10117
8660d0e [lewuathe] Fix Java unit test
b56a948 [lewuathe] Merge branch 'master' into SPARK-10117
2c12894 [lewuathe] Remove unnecessary tag
7d693c2 [lewuathe] Resolv conflict
62010af [lewuathe] Merge branch 'master' into SPARK-10117
a97ee97 [lewuathe] Fix some points
aef9564 [lewuathe] Fix
70ee4dd [lewuathe] Add Java test
3fd8dce [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data
40d3027 [lewuathe] Add Java test
7056d4a [lewuathe] Merge branch 'master' into SPARK-10117
99accaa [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data
This commit is contained in:
lewuathe 2015-09-09 09:29:10 -07:00 committed by Xiangrui Meng
parent c1bc4f439f
commit 2ddeb63126
4 changed files with 256 additions and 0 deletions

View file

@ -0,0 +1 @@
org.apache.spark.ml.source.libsvm.DefaultSource

View file

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.source.libsvm
import com.google.common.base.Objects
import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
* @param path File path of LibSVM format
* @param numFeatures The number of features
* @param vectorType The type of vector. It can be 'sparse' or 'dense'
* @param sqlContext The Spark SQLContext
*/
private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with Logging with Serializable {
override def schema: StructType = StructType(
StructField("label", DoubleType, nullable = false) ::
StructField("features", new VectorUDT(), nullable = false) :: Nil
)
override def buildScan(): RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
baseRdd.map { pt =>
val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
Row(pt.label, features)
}
}
override def hashCode(): Int = {
Objects.hashCode(path, schema)
}
override def equals(other: Any): Boolean = other match {
case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
case _ => false
}
}
/**
* This is used for creating DataFrame from LibSVM format file.
* The LibSVM file path must be specified to DefaultSource.
*/
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
private def checkPath(parameters: Map[String, String]): String = {
require(parameters.contains("path"), "'path' must be specified")
parameters.get("path").get
}
/**
* Returns a new base relation with the given parameters.
* Note: the parameters' keywords are case insensitive and this insensitivity is enforced
* by the Map that is passed to the function.
*/
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
: BaseRelation = {
val path = checkPath(parameters)
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
/**
* featuresType can be selected "dense" or "sparse".
* This parameter decides the type of returned feature vector.
*/
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
}

View file

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.source;
import java.io.File;
import java.io.IOException;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
/**
* Test LibSVMRelation in Java.
*/
public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
private transient DataFrame dataset;
private File tmpDir;
private File path;
@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
jsql = new SQLContext(jsc);
tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
path = new File(tmpDir.getPath(), "part-00000");
String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
Files.write(s, path, Charsets.US_ASCII);
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
Utils.deleteRecursively(tmpDir);
}
@Test
public void verifyLibSVMDF() {
dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath());
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
Row r = dataset.first();
Assert.assertEquals(1.0, r.getDouble(0), 1e-15);
DenseVector v = r.getAs(1);
Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v);
}
}

View file

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.source
import java.io.File
import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
var path: String = _
override def beforeAll(): Unit = {
super.beforeAll()
val lines =
"""
|1 1:1.0 3:2.0 5:3.0
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
path = tempDir.toURI.toString
}
test("select as sparse vector") {
val df = sqlContext.read.format("libsvm").load(path)
assert(df.columns(0) == "label")
assert(df.columns(1) == "features")
val row1 = df.first()
assert(row1.getDouble(0) == 1.0)
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
test("select as dense vector") {
val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense"))
.load(path)
assert(df.columns(0) == "label")
assert(df.columns(1) == "features")
assert(df.count() == 3)
val row1 = df.first()
assert(row1.getDouble(0) == 1.0)
val v = row1.getAs[DenseVector](1)
assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
}
test("select a vector with specifying the longer dimension") {
val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
.load(path)
val row1 = df.first()
val v = row1.getAs[SparseVector](1)
assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
}