[SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples
We used the name `Dataset` to refer to `SchemaRDD` in Spark 1.2 when this example file was created for ML pipelines. Since `Dataset` has a new meaning in Spark 1.6, we rename the example to avoid confusion. This PR also removes support for the dense input format to simplify the example code.

cc: yinxusen

Author: Xiangrui Meng <meng@databricks.com>

Closes #9873 from mengxr/SPARK-11895.
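For context, the core API change in this refactor is replacing the `MLUtils`-based loading path (an `RDD[LabeledPoint]` converted to a `DataFrame` via implicits) with the `libsvm` data source added in Spark 1.6. A minimal sketch of the two paths, assuming a Spark 1.6 `SQLContext`; the object name `LoadSketch` and the hard-coded path are illustrative only, not part of this commit:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.sql.{DataFrame, SQLContext}

    object LoadSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("LoadSketch"))
        val sqlContext = new SQLContext(sc)
        val path = "data/mllib/sample_libsvm_data.txt"

        // Old path (pre-1.6): load an RDD[LabeledPoint], then convert via implicits.
        import sqlContext.implicits._
        val oldDF: DataFrame = MLUtils.loadLibSVMFile(sc, path).toDF()
        oldDF.printSchema()

        // New path (1.6+): read LIBSVM data directly as a DataFrame
        // with a label column and a Vector-UDT features column.
        val newDF: DataFrame = sqlContext.read.format("libsvm").load(path)
        newDF.printSchema()

        sc.stop()
      }
    }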
commit fe89c1817d (parent 426004a9c9)
examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala}

@@ -16,7 +16,7 @@
  */
 
 // scalastyle:off println
-package org.apache.spark.examples.mllib
+package org.apache.spark.examples.ml
 
 import java.io.File
@@ -24,25 +24,22 @@ import com.google.common.io.Files
 import scopt.OptionParser
 
 import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
- * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
  * {{{
- * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * ./bin/run-example ml.DataFrameExample [options]
  * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
-object DatasetExample {
+object DataFrameExample {
 
-  case class Params(
-      input: String = "data/mllib/sample_libsvm_data.txt",
-      dataFormat: String = "libsvm") extends AbstractParams[Params]
+  case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
+    extends AbstractParams[Params]
 
   def main(args: Array[String]) {
     val defaultParams = Params()
@@ -52,9 +49,6 @@ object DatasetExample {
       opt[String]("input")
         .text(s"input path to dataset")
         .action((x, c) => c.copy(input = x))
-      opt[String]("dataFormat")
-        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
-        .action((x, c) => c.copy(input = x))
       checkConfig { params =>
         success
       }
@@ -69,55 +63,42 @@ object DatasetExample {
 
   def run(params: Params) {
-
-    val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+    val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
     val sc = new SparkContext(conf)
     val sqlContext = new SQLContext(sc)
-    import sqlContext.implicits._  // for implicit conversions
 
     // Load input data
-    val origData: RDD[LabeledPoint] = params.dataFormat match {
-      case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
-      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
-    }
-    println(s"Loaded ${origData.count()} instances from file: ${params.input}")
-
-    // Convert input data to DataFrame explicitly.
-    val df: DataFrame = origData.toDF()
-    println(s"Inferred schema:\n${df.schema.prettyJson}")
-    println(s"Converted to DataFrame with ${df.count()} records")
+    println(s"Loading LIBSVM file with UDT from ${params.input}.")
+    val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
+    println("Schema from LIBSVM:")
+    df.printSchema()
+    println(s"Loaded training data as a DataFrame with ${df.count()} records.")
 
-    // Select columns
-    val labelsDf: DataFrame = df.select("label")
-    val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
-    val numLabels = labels.count()
-    val meanLabel = labels.fold(0.0)(_ + _) / numLabels
-    println(s"Selected label column with average value $meanLabel")
+    // Show statistical summary of labels.
+    val labelSummary = df.describe("label")
+    labelSummary.show()
 
-    val featuresDf: DataFrame = df.select("features")
-    val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
+    // Convert features column to an RDD of vectors.
+    val features = df.select("features").map { case Row(v: Vector) => v }
     val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
       (summary, feat) => summary.add(feat),
       (sum1, sum2) => sum1.merge(sum2))
     println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
 
+    // Save the records in a parquet file.
     val tmpDir = Files.createTempDir()
     tmpDir.deleteOnExit()
     val outputDir = new File(tmpDir, "dataset").toString
     println(s"Saving to $outputDir as Parquet file.")
     df.write.parquet(outputDir)
 
+    // Load the records back.
     println(s"Loading Parquet file with UDT from $outputDir.")
-    val newDataset = sqlContext.read.parquet(outputDir)
-
-    println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
-    val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
-    val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
-      (summary, feat) => summary.add(feat),
-      (sum1, sum2) => sum1.merge(sum2))
-    println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+    val newDF = sqlContext.read.parquet(outputDir)
+    println(s"Schema from Parquet:")
+    newDF.printSchema()
 
     sc.stop()
   }
 
 }
 // scalastyle:on println
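As the updated scaladoc notes, the renamed example is launched through the `run-example` script (use `spark-submit` if you adapt it into your own app). A usage sketch, assuming a Spark 1.6 checkout built with the examples; the `--input` value shown is just the default from `Params`:

    ./bin/run-example ml.DataFrameExample --input data/mllib/sample_libsvm_data.txt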