[SPARK-15031][EXAMPLE] Use SparkSession in examples

## What changes were proposed in this pull request?
Use `SparkSession` in the examples, as proposed in [SPARK-15031](https://issues.apache.org/jira/browse/SPARK-15031).
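
For reference (not part of the diff), a minimal sketch of the pattern applied throughout this PR: build a `SparkSession` instead of a `SparkConf`/`SparkContext`, derive the `SparkContext` from it where RDD APIs are still needed, and read text input through `spark.read.text`. The object name, app name, and I/O below are placeholders, not taken from any of the changed files.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical example; names are illustrative only.
object SessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // Before this PR the examples did roughly:
    //   val sc = new SparkContext(new SparkConf().setAppName("ExampleApp"))
    //   val lines = sc.textFile(args(0))

    // After: a single SparkSession is the entry point.
    val spark = SparkSession
      .builder
      .appName("ExampleApp")
      .getOrCreate()

    val sc = spark.sparkContext            // still available for RDD-only code
    val nums = sc.parallelize(1 to 10)     // RDD API keeps working

    // Text input goes through the session's reader instead of sc.textFile.
    val lines = spark.read.text(args(0)).rdd.map(_.getString(0))
    println(s"Read ${lines.count()} lines; sum = ${nums.reduce(_ + _)}")

    spark.stop()                           // replaces sc.stop()
  }
}
```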

Using `MLlib` is no longer recommended, so the `MLlib` examples are skipped in this PR.
A `StreamingContext` cannot be obtained directly from a `SparkSession`, so the `Streaming` examples are skipped as well.
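
For context, a short sketch (spark-shell style, not from this PR) of why the streaming examples are left out: a `StreamingContext` still has to be constructed from the underlying `SparkContext` (or a `SparkConf`) rather than from the `SparkSession` itself. "StreamingSketch" is a placeholder app name.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Seconds, StreamingContext}

val spark = SparkSession.builder.appName("StreamingSketch").getOrCreate()

// There is no SparkSession-based constructor; the StreamingContext is
// built from the session's SparkContext plus a batch interval.
val ssc = new StreamingContext(spark.sparkContext, Seconds(1))
```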

cc andrewor14

## How was this patch tested?
Manual tests with `spark-submit`.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #13164 from zhengruifeng/use_sparksession_ii.
Zheng RuiFeng 2016-05-20 16:40:33 -07:00 committed by Andrew Or
parent 06c9f52071
commit 127bf1bb07
34 changed files with 279 additions and 146 deletions


@ -17,11 +17,10 @@
package org.apache.spark.examples;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
import java.util.Arrays;
@ -122,9 +121,12 @@ public final class JavaHdfsLR {
showWarning();
SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<String> lines = sc.textFile(args[0]);
SparkSession spark = SparkSession
.builder()
.appName("JavaHdfsLR")
.getOrCreate();
JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
int ITERATIONS = Integer.parseInt(args[1]);
@ -152,6 +154,6 @@ public final class JavaHdfsLR {
System.out.print("Final w: ");
printWeights(w);
sc.stop();
spark.stop();
}
}


@ -20,12 +20,13 @@ package org.apache.spark.examples;
import com.google.common.collect.Lists;
import scala.Tuple2;
import scala.Tuple3;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
import java.util.List;
@ -99,9 +100,12 @@ public final class JavaLogQuery {
}
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("JavaLogQuery")
.getOrCreate();
SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
JavaRDD<String> dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs);
@ -123,6 +127,6 @@ public final class JavaLogQuery {
for (Tuple2<?,?> t : output) {
System.out.println(t._1() + "\t" + t._2());
}
jsc.stop();
spark.stop();
}
}


@ -26,14 +26,13 @@ import scala.Tuple2;
import com.google.common.collect.Iterables;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
/**
* Computes the PageRank of URLs from an input file. Input file should
@ -73,15 +72,17 @@ public final class JavaPageRank {
showWarning();
SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
SparkSession spark = SparkSession
.builder()
.appName("JavaPageRank")
.getOrCreate();
// Loads in input file. It should be in format of:
// URL neighbor URL
// URL neighbor URL
// URL neighbor URL
// ...
JavaRDD<String> lines = ctx.textFile(args[0], 1);
JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
// Loads all URLs from input file and initialize their neighbors.
JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(
@ -132,6 +133,6 @@ public final class JavaPageRank {
System.out.println(tuple._1() + " has rank: " + tuple._2() + ".");
}
ctx.stop();
spark.stop();
}
}


@ -17,11 +17,11 @@
package org.apache.spark.examples;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.SparkSession;
import java.util.ArrayList;
import java.util.List;
@ -33,8 +33,12 @@ import java.util.List;
public final class JavaSparkPi {
public static void main(String[] args) throws Exception {
SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
SparkSession spark = SparkSession
.builder()
.appName("JavaSparkPi")
.getOrCreate();
JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2;
int n = 100000 * slices;
@ -61,6 +65,6 @@ public final class JavaSparkPi {
System.out.println("Pi is roughly " + 4.0 * count / n);
jsc.stop();
spark.stop();
}
}


@ -17,13 +17,14 @@
package org.apache.spark.examples;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkJobInfo;
import org.apache.spark.SparkStageInfo;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
import java.util.List;
@ -44,11 +45,15 @@ public final class JavaStatusTrackerDemo {
}
public static void main(String[] args) throws Exception {
SparkConf sparkConf = new SparkConf().setAppName(APP_NAME);
final JavaSparkContext sc = new JavaSparkContext(sparkConf);
SparkSession spark = SparkSession
.builder()
.appName(APP_NAME)
.getOrCreate();
final JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
// Example of implementing a progress reporter for a simple job.
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
JavaRDD<Integer> rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
new IdentityWithDelay<Integer>());
JavaFutureAction<List<Integer>> jobFuture = rdd.collectAsync();
while (!jobFuture.isDone()) {
@ -58,13 +63,13 @@ public final class JavaStatusTrackerDemo {
continue;
}
int currentJobId = jobIds.get(jobIds.size() - 1);
SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId);
SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]);
SparkJobInfo jobInfo = jsc.statusTracker().getJobInfo(currentJobId);
SparkStageInfo stageInfo = jsc.statusTracker().getStageInfo(jobInfo.stageIds()[0]);
System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() +
" active, " + stageInfo.numCompletedTasks() + " complete");
}
System.out.println("Job results are: " + jobFuture.get());
sc.stop();
spark.stop();
}
}


@ -25,10 +25,10 @@ import java.util.Set;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
/**
* Transitive closure on a graph, implemented in Java.
@ -64,10 +64,15 @@ public final class JavaTC {
}
public static void main(String[] args) {
SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
SparkSession spark = SparkSession
.builder()
.appName("JavaTC")
.getOrCreate();
JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
Integer slices = (args.length > 0) ? Integer.parseInt(args[0]): 2;
JavaPairRDD<Integer, Integer> tc = sc.parallelizePairs(generateGraph(), slices).cache();
JavaPairRDD<Integer, Integer> tc = jsc.parallelizePairs(generateGraph(), slices).cache();
// Linear transitive closure: each round grows paths by one edge,
// by joining the graph's edges with the already-discovered paths.
@ -94,6 +99,6 @@ public final class JavaTC {
} while (nextCount != oldCount);
System.out.println("TC has " + tc.count() + " edges.");
sc.stop();
spark.stop();
}
}


@ -18,13 +18,13 @@
package org.apache.spark.examples;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
import java.util.Iterator;
@ -41,9 +41,12 @@ public final class JavaWordCount {
System.exit(1);
}
SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
JavaRDD<String> lines = ctx.textFile(args[0], 1);
SparkSession spark = SparkSession
.builder()
.appName("JavaWordCount")
.getOrCreate();
JavaRDD<String> lines = spark.read().text(args[0]).javaRDD();
JavaRDD<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
@ -72,6 +75,6 @@ public final class JavaWordCount {
for (Tuple2<?,?> tuple : output) {
System.out.println(tuple._1() + ": " + tuple._2());
}
ctx.stop();
spark.stop();
}
}


@ -28,7 +28,7 @@ import sys
import numpy as np
from numpy.random import rand
from numpy import matrix
from pyspark import SparkContext
from pyspark.sql import SparkSession
LAMBDA = 0.01 # regularization
np.random.seed(42)
@ -62,7 +62,13 @@ if __name__ == "__main__":
example. Please use pyspark.ml.recommendation.ALS for more
conventional use.""", file=sys.stderr)
sc = SparkContext(appName="PythonALS")
spark = SparkSession\
.builder\
.appName("PythonALS")\
.getOrCreate()
sc = spark._sc
M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
@ -99,4 +105,4 @@ if __name__ == "__main__":
print("Iteration %d:" % i)
print("\nRMSE: %5.4f\n" % error)
sc.stop()
spark.stop()


@ -19,8 +19,8 @@ from __future__ import print_function
import sys
from pyspark import SparkContext
from functools import reduce
from pyspark.sql import SparkSession
"""
Read data file users.avro in local Spark distro:
@ -64,7 +64,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
sc = SparkContext(appName="AvroKeyInputFormat")
spark = SparkSession\
.builder\
.appName("AvroKeyInputFormat")\
.getOrCreate()
sc = spark._sc
conf = None
if len(sys.argv) == 3:
@ -82,4 +88,4 @@ if __name__ == "__main__":
for k in output:
print(k)
sc.stop()
spark.stop()


@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
from pyspark import SparkContext
from pyspark.sql import SparkSession
def parseVector(line):
@ -55,8 +55,12 @@ if __name__ == "__main__":
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
example on how to use ML's KMeans implementation.""", file=sys.stderr)
sc = SparkContext(appName="PythonKMeans")
lines = sc.textFile(sys.argv[1])
spark = SparkSession\
.builder\
.appName("PythonKMeans")\
.getOrCreate()
lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
data = lines.map(parseVector).cache()
K = int(sys.argv[2])
convergeDist = float(sys.argv[3])
@ -79,4 +83,4 @@ if __name__ == "__main__":
print("Final centers: " + str(kPoints))
sc.stop()
spark.stop()


@ -27,7 +27,7 @@ from __future__ import print_function
import sys
import numpy as np
from pyspark import SparkContext
from pyspark.sql import SparkSession
D = 10 # Number of dimensions
@ -55,8 +55,13 @@ if __name__ == "__main__":
Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
to see how ML's implementation is used.""", file=sys.stderr)
sc = SparkContext(appName="PythonLR")
points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache()
spark = SparkSession\
.builder\
.appName("PythonLR")\
.getOrCreate()
points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\
.mapPartitions(readPointBatch).cache()
iterations = int(sys.argv[2])
# Initialize w to a random value
@ -80,4 +85,4 @@ if __name__ == "__main__":
print("Final w: " + str(w))
sc.stop()
spark.stop()


@ -25,7 +25,7 @@ import re
import sys
from operator import add
from pyspark import SparkContext
from pyspark.sql import SparkSession
def computeContribs(urls, rank):
@ -51,14 +51,17 @@ if __name__ == "__main__":
file=sys.stderr)
# Initialize the spark context.
sc = SparkContext(appName="PythonPageRank")
spark = SparkSession\
.builder\
.appName("PythonPageRank")\
.getOrCreate()
# Loads in input file. It should be in format of:
# URL neighbor URL
# URL neighbor URL
# URL neighbor URL
# ...
lines = sc.textFile(sys.argv[1], 1)
lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
# Loads all URLs from input file and initialize their neighbors.
links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache()
@ -79,4 +82,4 @@ if __name__ == "__main__":
for (link, rank) in ranks.collect():
print("%s has rank: %s." % (link, rank))
sc.stop()
spark.stop()


@ -18,7 +18,7 @@ from __future__ import print_function
import sys
from pyspark import SparkContext
from pyspark.sql import SparkSession
"""
Read data file users.parquet in local Spark distro:
@ -47,7 +47,13 @@ if __name__ == "__main__":
exit(-1)
path = sys.argv[1]
sc = SparkContext(appName="ParquetInputFormat")
spark = SparkSession\
.builder\
.appName("ParquetInputFormat")\
.getOrCreate()
sc = spark._sc
parquet_rdd = sc.newAPIHadoopFile(
path,
@ -59,4 +65,4 @@ if __name__ == "__main__":
for k in output:
print(k)
sc.stop()
spark.stop()


@ -20,14 +20,20 @@ import sys
from random import random
from operator import add
from pyspark import SparkContext
from pyspark.sql import SparkSession
if __name__ == "__main__":
"""
Usage: pi [partitions]
"""
sc = SparkContext(appName="PythonPi")
spark = SparkSession\
.builder\
.appName("PythonPi")\
.getOrCreate()
sc = spark._sc
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
n = 100000 * partitions
@ -39,4 +45,4 @@ if __name__ == "__main__":
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
print("Pi is roughly %f" % (4.0 * count / n))
sc.stop()
spark.stop()


@ -19,15 +19,20 @@ from __future__ import print_function
import sys
from pyspark import SparkContext
from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: sort <file>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonSort")
lines = sc.textFile(sys.argv[1], 1)
spark = SparkSession\
.builder\
.appName("PythonSort")\
.getOrCreate()
lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
sortedCount = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (int(x), 1)) \
.sortByKey()
@ -37,4 +42,4 @@ if __name__ == "__main__":
for (num, unitcount) in output:
print(num)
sc.stop()
spark.stop()


@ -20,7 +20,7 @@ from __future__ import print_function
import sys
from random import Random
from pyspark import SparkContext
from pyspark.sql import SparkSession
numEdges = 200
numVertices = 100
@ -41,7 +41,13 @@ if __name__ == "__main__":
"""
Usage: transitive_closure [partitions]
"""
sc = SparkContext(appName="PythonTransitiveClosure")
spark = SparkSession\
.builder\
.appName("PythonTransitiveClosure")\
.getOrCreate()
sc = spark._sc
partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2
tc = sc.parallelize(generateGraph(), partitions).cache()
@ -67,4 +73,4 @@ if __name__ == "__main__":
print("TC has %i edges" % tc.count())
sc.stop()
spark.stop()


@ -20,15 +20,20 @@ from __future__ import print_function
import sys
from operator import add
from pyspark import SparkContext
from pyspark.sql import SparkSession
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: wordcount <file>", file=sys.stderr)
exit(-1)
sc = SparkContext(appName="PythonWordCount")
lines = sc.textFile(sys.argv[1], 1)
spark = SparkSession\
.builder\
.appName("PythonWordCount")\
.getOrCreate()
lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
counts = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (x, 1)) \
.reduceByKey(add)
@ -36,4 +41,4 @@ if __name__ == "__main__":
for (word, count) in output:
print("%s: %i" % (word, count))
sc.stop()
spark.stop()


@ -18,7 +18,8 @@
// scalastyle:off println
package org.apache.spark.examples
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* Usage: BroadcastTest [slices] [numElem] [blockSize]
@ -28,9 +29,16 @@ object BroadcastTest {
val blockSize = if (args.length > 2) args(2) else "4096"
val sparkConf = new SparkConf().setAppName("Broadcast Test")
val sparkConf = new SparkConf()
.set("spark.broadcast.blockSize", blockSize)
val sc = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.config(sparkConf)
.appName("Broadcast Test")
.getOrCreate()
val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val num = if (args.length > 1) args(1).toInt else 1000000
@ -48,7 +56,7 @@ object BroadcastTest {
println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -22,7 +22,7 @@ import java.io.File
import scala.io.Source._
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Simple test for reading and writing to a distributed
@ -101,11 +101,14 @@ object DFSReadWriteTest {
val fileContents = readFile(localFilePath.toString())
val localWordCount = runLocalWordCount(fileContents)
println("Creating SparkConf")
val conf = new SparkConf().setAppName("DFS Read Write Test")
println("Creating SparkSession")
val spark = SparkSession
.builder
.appName("DFS Read Write Test")
.getOrCreate()
println("Creating SparkContext")
val sc = new SparkContext(conf)
val sc = spark.sparkContext
println("Writing local file to DFS")
val dfsFilename = dfsDirPath + "/dfs_read_write_test"
@ -124,7 +127,7 @@ object DFSReadWriteTest {
.values
.sum
sc.stop()
spark.stop()
if (localWordCount == dfsWordCount) {
println(s"Success! Local Word Count ($localWordCount) " +


@ -17,18 +17,22 @@
package org.apache.spark.examples
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
object ExceptionHandlingTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest")
val sc = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.appName("ExceptionHandlingTest")
.getOrCreate()
val sc = spark.sparkContext
sc.parallelize(0 until sc.defaultParallelism).foreach { i =>
if (math.random > 0.75) {
throw new Exception("Testing exception handling")
}
}
sc.stop()
spark.stop()
}
}


@ -20,20 +20,24 @@ package org.apache.spark.examples
import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object GroupByTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("GroupBy Test")
val spark = SparkSession
.builder
.appName("GroupBy Test")
.getOrCreate()
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
val sc = new SparkContext(sparkConf)
val sc = spark.sparkContext
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
@ -50,7 +54,7 @@ object GroupByTest {
println(pairs1.groupByKey(numReducers).count())
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -18,7 +18,7 @@
// scalastyle:off println
package org.apache.spark.examples
import org.apache.spark._
import org.apache.spark.sql.SparkSession
object HdfsTest {
@ -29,9 +29,11 @@ object HdfsTest {
System.err.println("Usage: HdfsTest <file>")
System.exit(1)
}
val sparkConf = new SparkConf().setAppName("HdfsTest")
val sc = new SparkContext(sparkConf)
val file = sc.textFile(args(0))
val spark = SparkSession
.builder
.appName("HdfsTest")
.getOrCreate()
val file = spark.read.text(args(0)).rdd
val mapped = file.map(s => s.length).cache()
for (iter <- 1 to 10) {
val start = System.currentTimeMillis()
@ -39,7 +41,7 @@ object HdfsTest {
val end = System.currentTimeMillis()
println("Iteration " + iter + " took " + (end-start) + " ms")
}
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -18,8 +18,9 @@
// scalastyle:off println
package org.apache.spark.examples
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
/**
* Usage: MultiBroadcastTest [slices] [numElem]
@ -27,8 +28,12 @@ import org.apache.spark.rdd.RDD
object MultiBroadcastTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test")
val sc = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.appName("Multi-Broadcast Test")
.getOrCreate()
val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val num = if (args.length > 1) args(1).toInt else 1000000
@ -51,7 +56,7 @@ object MultiBroadcastTest {
// Collect the small RDD so we can print the observed sizes locally.
observedSizes.collect().foreach(i => println(i))
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -20,23 +20,26 @@ package org.apache.spark.examples
import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
*/
object SimpleSkewedGroupByTest {
def main(args: Array[String]) {
val spark = SparkSession
.builder
.appName("SimpleSkewedGroupByTest")
.getOrCreate()
val sc = spark.sparkContext
val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest")
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
var ratio = if (args.length > 4) args(4).toInt else 5.0
val sc = new SparkContext(sparkConf)
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var result = new Array[(Int, Array[Byte])](numKVPairs)
@ -64,7 +67,7 @@ object SimpleSkewedGroupByTest {
// .map{case (k,v) => (k, v.size)}
// .collectAsMap)
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -20,20 +20,25 @@ package org.apache.spark.examples
import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
*/
object SkewedGroupByTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("GroupBy Test")
val spark = SparkSession
.builder
.appName("GroupBy Test")
.getOrCreate()
val sc = spark.sparkContext
var numMappers = if (args.length > 0) args(0).toInt else 2
var numKVPairs = if (args.length > 1) args(1).toInt else 1000
var valSize = if (args.length > 2) args(2).toInt else 1000
var numReducers = if (args.length > 3) args(3).toInt else numMappers
val sc = new SparkContext(sparkConf)
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
@ -54,7 +59,7 @@ object SkewedGroupByTest {
println(pairs1.groupByKey(numReducers).count())
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -20,7 +20,7 @@ package org.apache.spark.examples
import org.apache.commons.math3.linear._
import org.apache.spark._
import org.apache.spark.sql.SparkSession
/**
* Alternating least squares matrix factorization.
@ -108,8 +108,12 @@ object SparkALS {
println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
val sparkConf = new SparkConf().setAppName("SparkALS")
val sc = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.appName("SparkALS")
.getOrCreate()
val sc = spark.sparkContext
val R = generateR()
@ -135,7 +139,7 @@ object SparkALS {
println()
}
sc.stop()
spark.stop()
}
private def randomVector(n: Int): RealVector =


@ -23,9 +23,8 @@ import java.util.Random
import scala.math.exp
import breeze.linalg.{DenseVector, Vector}
import org.apache.hadoop.conf.Configuration
import org.apache.spark._
import org.apache.spark.sql.SparkSession
/**
* Logistic regression based classification.
@ -67,11 +66,14 @@ object SparkHdfsLR {
showWarning()
val sparkConf = new SparkConf().setAppName("SparkHdfsLR")
val spark = SparkSession
.builder
.appName("SparkHdfsLR")
.getOrCreate()
val inputPath = args(0)
val conf = new Configuration()
val sc = new SparkContext(sparkConf)
val lines = sc.textFile(inputPath)
val lines = spark.read.text(inputPath).rdd
val points = lines.map(parsePoint).cache()
val ITERATIONS = args(1).toInt
@ -88,7 +90,7 @@ object SparkHdfsLR {
}
println("Final w: " + w)
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -20,7 +20,7 @@ package org.apache.spark.examples
import breeze.linalg.{squaredDistance, DenseVector, Vector}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* K-means clustering.
@ -66,14 +66,17 @@ object SparkKMeans {
showWarning()
val sparkConf = new SparkConf().setAppName("SparkKMeans")
val sc = new SparkContext(sparkConf)
val lines = sc.textFile(args(0))
val spark = SparkSession
.builder
.appName("SparkKMeans")
.getOrCreate()
val lines = spark.read.text(args(0)).rdd
val data = lines.map(parseVector _).cache()
val K = args(1).toInt
val convergeDist = args(2).toDouble
val kPoints = data.takeSample(withReplacement = false, K, 42).toArray
val kPoints = data.takeSample(withReplacement = false, K, 42)
var tempDist = 1.0
while(tempDist > convergeDist) {
@ -97,7 +100,7 @@ object SparkKMeans {
println("Final centers:")
kPoints.foreach(println)
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -24,7 +24,7 @@ import scala.math.exp
import breeze.linalg.{DenseVector, Vector}
import org.apache.spark._
import org.apache.spark.sql.SparkSession
/**
* Logistic regression based classification.
@ -63,8 +63,13 @@ object SparkLR {
showWarning()
val sparkConf = new SparkConf().setAppName("SparkLR")
val sc = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.appName("SparkLR")
.getOrCreate()
val sc = spark.sparkContext
val numSlices = if (args.length > 0) args(0).toInt else 2
val points = sc.parallelize(generateData, numSlices).cache()
@ -82,7 +87,7 @@ object SparkLR {
println("Final w: " + w)
sc.stop()
spark.stop()
}
}
// scalastyle:on println


@ -18,7 +18,7 @@
// scalastyle:off println
package org.apache.spark.examples
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Computes the PageRank of URLs from an input file. Input file should
@ -50,10 +50,13 @@ object SparkPageRank {
showWarning()
val sparkConf = new SparkConf().setAppName("PageRank")
val spark = SparkSession
.builder
.appName("SparkPageRank")
.getOrCreate()
val iters = if (args.length > 1) args(1).toInt else 10
val ctx = new SparkContext(sparkConf)
val lines = ctx.textFile(args(0), 1)
val lines = spark.read.text(args(0)).rdd
val links = lines.map{ s =>
val parts = s.split("\\s+")
(parts(0), parts(1))
@ -71,7 +74,7 @@ object SparkPageRank {
val output = ranks.collect()
output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + "."))
ctx.stop()
spark.stop()
}
}
// scalastyle:on println


@ -20,16 +20,19 @@ package org.apache.spark.examples
import scala.math.random
import org.apache.spark._
import org.apache.spark.sql.SparkSession
/** Computes an approximation to pi */
object SparkPi {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("Spark Pi")
val spark = new SparkContext(conf)
val spark = SparkSession
.builder
.appName("Spark Pi")
.getOrCreate()
val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
val count = spark.parallelize(1 until n, slices).map { i =>
val count = sc.parallelize(1 until n, slices).map { i =>
val x = random * 2 - 1
val y = random * 2 - 1
if (x*x + y*y < 1) 1 else 0


@ -21,7 +21,7 @@ package org.apache.spark.examples
import scala.collection.mutable
import scala.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
/**
* Transitive closure on a graph.
@ -42,10 +42,13 @@ object SparkTC {
}
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("SparkTC")
val spark = new SparkContext(sparkConf)
val spark = SparkSession
.builder
.appName("SparkTC")
.getOrCreate()
val sc = spark.sparkContext
val slices = if (args.length > 0) args(0).toInt else 2
var tc = spark.parallelize(generateGraph, slices).cache()
var tc = sc.parallelize(generateGraph, slices).cache()
// Linear transitive closure: each round grows paths by one edge,
// by joining the graph's edges with the already-discovered paths.


@ -22,7 +22,7 @@ import java.io.File
import com.google.common.io.{ByteStreams, Files}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkConf
import org.apache.spark.sql._
object HiveFromSpark {


@ -168,8 +168,8 @@ public class JavaDataFrameSuite {
Assert.assertEquals(
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
schema.apply("d"));
Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()),
schema.apply("e"));
Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true,
Metadata.empty()), schema.apply("e"));
Row first = df.select("a", "b", "c", "d", "e").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
@ -189,7 +189,7 @@ public class JavaDataFrameSuite {
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
// Java.math.BigInteger is equavient to Spark Decimal(38,0)
// Java.math.BigInteger is equavient to Spark Decimal(38,0)
Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
}