[SPARK-15618][SQL][MLLIB] Use SparkSession.builder.sparkContext if applicable.
## What changes were proposed in this pull request? This PR changes function `SparkSession.builder.sparkContext(..)` from **private[sql]** into **private[spark]**, and uses it if applicable like the followings. ``` - val spark = SparkSession.builder().config(sc.getConf).getOrCreate() + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() ``` ## How was this patch tested? Pass the existing Jenkins tests. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13365 from dongjoon-hyun/SPARK-15618.
This commit is contained in:
parent
93e97147eb
commit
85d6b0db9f
|
@ -29,13 +29,10 @@ object BroadcastTest {
|
|||
|
||||
val blockSize = if (args.length > 2) args(2) else "4096"
|
||||
|
||||
val sparkConf = new SparkConf()
|
||||
.set("spark.broadcast.blockSize", blockSize)
|
||||
|
||||
val spark = SparkSession
|
||||
.builder
|
||||
.config(sparkConf)
|
||||
.builder()
|
||||
.appName("Broadcast Test")
|
||||
.config("spark.broadcast.blockSize", blockSize)
|
||||
.getOrCreate()
|
||||
|
||||
val sc = spark.sparkContext
|
||||
|
|
|
@ -191,6 +191,7 @@ object LDAExample {
|
|||
|
||||
val spark = SparkSession
|
||||
.builder
|
||||
.sparkContext(sc)
|
||||
.getOrCreate()
|
||||
import spark.implicits._
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.io.File
|
|||
|
||||
import com.google.common.io.{ByteStreams, Files}
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.sql._
|
||||
|
||||
object HiveFromSpark {
|
||||
|
@ -35,8 +34,6 @@ object HiveFromSpark {
|
|||
ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File))
|
||||
|
||||
def main(args: Array[String]) {
|
||||
val sparkConf = new SparkConf().setAppName("HiveFromSpark")
|
||||
|
||||
// When working with Hive, one must instantiate `SparkSession` with Hive support, including
|
||||
// connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined
|
||||
// functions. Users who do not have an existing Hive deployment can still enable Hive support.
|
||||
|
@ -45,7 +42,7 @@ object HiveFromSpark {
|
|||
// which defaults to the directory `spark-warehouse` in the current directory that the spark
|
||||
// application is started.
|
||||
val spark = SparkSession.builder
|
||||
.config(sparkConf)
|
||||
.appName("HiveFromSpark")
|
||||
.enableHiveSupport()
|
||||
.getOrCreate()
|
||||
val sc = spark.sparkContext
|
||||
|
|
|
@ -1177,7 +1177,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
// We use DataFrames for serialization of IndexedRows to Python,
|
||||
// so return a DataFrame.
|
||||
val sc = indexedRowMatrix.rows.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
spark.createDataFrame(indexedRowMatrix.rows)
|
||||
}
|
||||
|
||||
|
@ -1188,7 +1188,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
// We use DataFrames for serialization of MatrixEntry entries to
|
||||
// Python, so return a DataFrame.
|
||||
val sc = coordinateMatrix.entries.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
spark.createDataFrame(coordinateMatrix.entries)
|
||||
}
|
||||
|
||||
|
@ -1199,7 +1199,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
// We use DataFrames for serialization of sub-matrix blocks to
|
||||
// Python, so return a DataFrame.
|
||||
val sc = blockMatrix.blocks.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
spark.createDataFrame(blockMatrix.blocks)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -437,7 +437,7 @@ class LogisticRegressionWithLBFGS
|
|||
lr.setMaxIter(optimizer.getNumIterations())
|
||||
lr.setTol(optimizer.getConvergenceTol())
|
||||
// Convert our input into a DataFrame
|
||||
val spark = SparkSession.builder().config(input.context.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
|
||||
val df = spark.createDataFrame(input.map(_.asML))
|
||||
// Determine if we should cache the DF
|
||||
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
|
||||
|
|
|
@ -193,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
modelType: String)
|
||||
|
||||
def save(sc: SparkContext, path: String, data: Data): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// Create JSON metadata.
|
||||
val metadata = compact(render(
|
||||
|
@ -207,7 +207,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
|
||||
@Since("1.3.0")
|
||||
def load(sc: SparkContext, path: String): NaiveBayesModel = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
// Load Parquet data.
|
||||
val dataRDD = spark.read.parquet(dataPath(path))
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
@ -238,7 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
theta: Array[Array[Double]])
|
||||
|
||||
def save(sc: SparkContext, path: String, data: Data): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// Create JSON metadata.
|
||||
val metadata = compact(render(
|
||||
|
@ -251,7 +251,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
}
|
||||
|
||||
def load(sc: SparkContext, path: String): NaiveBayesModel = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
// Load Parquet data.
|
||||
val dataRDD = spark.read.parquet(dataPath(path))
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
|
|
@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
|
|||
weights: Vector,
|
||||
intercept: Double,
|
||||
threshold: Option[Double]): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// Create JSON metadata.
|
||||
val metadata = compact(render(
|
||||
|
@ -73,7 +73,7 @@ private[classification] object GLMClassificationModel {
|
|||
*/
|
||||
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataRDD = spark.read.parquet(dataPath)
|
||||
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
|
||||
assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
|
||||
|
|
|
@ -144,7 +144,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
|
|||
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
|
||||
|
||||
def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
|
||||
~ ("rootId" -> model.root.index)))
|
||||
|
@ -165,7 +165,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
|
|||
}
|
||||
|
||||
def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val rows = spark.read.parquet(Loader.dataPath(path))
|
||||
Loader.checkSchema[Data](rows.schema)
|
||||
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
|
||||
|
|
|
@ -143,7 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
|
|||
path: String,
|
||||
weights: Array[Double],
|
||||
gaussians: Array[MultivariateGaussian]): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// Create JSON metadata.
|
||||
val metadata = compact(render
|
||||
|
@ -159,7 +159,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataFrame = spark.read.parquet(dataPath)
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
|
|
@ -123,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] {
|
|||
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
|
||||
|
||||
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
|
||||
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
|
||||
|
@ -135,7 +135,7 @@ object KMeansModel extends Loader[KMeansModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): KMeansModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -446,7 +446,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
|||
docConcentration: Vector,
|
||||
topicConcentration: Double,
|
||||
gammaShape: Double): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val k = topicsMatrix.numCols
|
||||
val metadata = compact(render
|
||||
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||
|
@ -470,7 +470,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
|
|||
topicConcentration: Double,
|
||||
gammaShape: Double): LocalLDAModel = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataFrame = spark.read.parquet(dataPath)
|
||||
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
@ -851,7 +851,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
|||
topicConcentration: Double,
|
||||
iterationTimes: Array[Double],
|
||||
gammaShape: Double): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render
|
||||
(("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||
|
@ -887,7 +887,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
|
|||
val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
|
||||
val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
|
||||
val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataFrame = spark.read.parquet(dataPath)
|
||||
val vertexDataFrame = spark.read.parquet(vertexDataPath)
|
||||
val edgeDataFrame = spark.read.parquet(edgeDataPath)
|
||||
|
|
|
@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
|
|||
|
||||
@Since("1.4.0")
|
||||
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
|
||||
|
@ -82,7 +82,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
|
|||
@Since("1.4.0")
|
||||
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
|
|
|
@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
|
||||
|
||||
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
|
||||
|
@ -149,7 +149,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -609,7 +609,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
|
|||
case class Data(word: String, vector: Array[Float])
|
||||
|
||||
def load(sc: SparkContext, path: String): Word2VecModel = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataFrame = spark.read.parquet(Loader.dataPath(path))
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
@ -620,7 +620,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
|
|||
}
|
||||
|
||||
def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val vectorSize = model.values.head.length
|
||||
val numWords = model.size
|
||||
|
|
|
@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
|
|||
|
||||
def save(model: FPGrowthModel[_], path: String): Unit = {
|
||||
val sc = model.freqItemsets.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
|
||||
|
@ -123,7 +123,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
|
|||
|
||||
def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
|
|
|
@ -616,7 +616,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
|
|||
|
||||
def save(model: PrefixSpanModel[_], path: String): Unit = {
|
||||
val sc = model.freqSequences.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
|
||||
|
@ -640,7 +640,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
|
|||
|
||||
def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
|
|
|
@ -354,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
|
|||
*/
|
||||
def save(model: MatrixFactorizationModel, path: String): Unit = {
|
||||
val sc = model.userFeatures.sparkContext
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
import spark.implicits._
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
|
||||
|
@ -365,7 +365,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val (className, formatVersion, metadata) = loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
|
|||
boundaries: Array[Double],
|
||||
predictions: Array[Double],
|
||||
isotonic: Boolean): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||
|
@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
|
|||
}
|
||||
|
||||
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataRDD = spark.read.parquet(dataPath(path))
|
||||
|
||||
checkSchema[Data](dataRDD.schema)
|
||||
|
|
|
@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
|
|||
modelClass: String,
|
||||
weights: Vector,
|
||||
intercept: Double): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// Create JSON metadata.
|
||||
val metadata = compact(render(
|
||||
|
@ -68,7 +68,7 @@ private[regression] object GLMRegressionModel {
|
|||
*/
|
||||
def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataRDD = spark.read.parquet(dataPath)
|
||||
val dataArray = dataRDD.select("weights", "intercept").take(1)
|
||||
assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath")
|
||||
|
|
|
@ -233,13 +233,13 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
|
|||
// Create Parquet data.
|
||||
val nodes = model.topNode.subtreeIterator.toSeq
|
||||
val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _))
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
|
||||
}
|
||||
|
||||
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
|
||||
// Load Parquet data.
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val dataRDD = spark.read.parquet(dataPath)
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
|
|
@ -413,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging {
|
|||
case class EnsembleNodeData(treeId: Int, node: NodeData)
|
||||
|
||||
def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
|
||||
// SPARK-6120: We do a hacky check here so users understand why save() is failing
|
||||
// when they run the ML guide example.
|
||||
|
@ -471,7 +471,7 @@ private[tree] object TreeEnsembleModel extends Logging {
|
|||
sc: SparkContext,
|
||||
path: String,
|
||||
treeAlgo: String): Array[DecisionTreeModel] = {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
|
||||
import spark.implicits._
|
||||
val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply)
|
||||
val trees = constructTrees(nodes.rdd)
|
||||
|
|
|
@ -23,18 +23,14 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
|||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.feature
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.Row
|
||||
|
||||
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||
with DefaultReadWriteTest {
|
||||
|
||||
test("Test Chi-Square selector") {
|
||||
val spark = SparkSession.builder
|
||||
.master("local[2]")
|
||||
.appName("ChiSqSelectorSuite")
|
||||
.getOrCreate()
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
||||
val data = Seq(
|
||||
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
|
||||
LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
|
||||
|
|
|
@ -27,7 +27,7 @@ class QuantileDiscretizerSuite
|
|||
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
|
||||
|
||||
test("Test observed number of buckets and their sizes match expected values") {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
||||
val datasetSize = 100000
|
||||
|
@ -53,7 +53,7 @@ class QuantileDiscretizerSuite
|
|||
}
|
||||
|
||||
test("Test transform method on unseen data") {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
||||
val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
|
||||
|
@ -82,7 +82,7 @@ class QuantileDiscretizerSuite
|
|||
}
|
||||
|
||||
test("Verify resulting model has parent") {
|
||||
val spark = SparkSession.builder().config(sc.getConf).getOrCreate()
|
||||
val spark = this.spark
|
||||
import spark.implicits._
|
||||
|
||||
val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
|
||||
|
|
|
@ -591,6 +591,7 @@ class ALSCleanerSuite extends SparkFunSuite {
|
|||
val spark = SparkSession.builder
|
||||
.master("local[2]")
|
||||
.appName("ALSCleanerSuite")
|
||||
.sparkContext(sc)
|
||||
.getOrCreate()
|
||||
import spark.implicits._
|
||||
val als = new ALS()
|
||||
|
@ -606,7 +607,7 @@ class ALSCleanerSuite extends SparkFunSuite {
|
|||
val pattern = "shuffle_(\\d+)_.+\\.data".r
|
||||
val rddIds = resultingFiles.flatMap { f =>
|
||||
pattern.findAllIn(f.getName()).matchData.map { _.group(1) } }
|
||||
assert(rddIds.toSet.size === 4)
|
||||
assert(rddIds.size === 4)
|
||||
} finally {
|
||||
sc.stop()
|
||||
}
|
||||
|
|
|
@ -42,9 +42,10 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
data: RDD[LabeledPoint],
|
||||
categoricalFeatures: Map[Int, Int],
|
||||
numClasses: Int): DataFrame = {
|
||||
val spark = SparkSession.builder
|
||||
val spark = SparkSession.builder()
|
||||
.master("local[2]")
|
||||
.appName("TreeTests")
|
||||
.sparkContext(data.sparkContext)
|
||||
.getOrCreate()
|
||||
import spark.implicits._
|
||||
|
||||
|
|
|
@ -694,7 +694,7 @@ object SparkSession {
|
|||
|
||||
private[this] var userSuppliedContext: Option[SparkContext] = None
|
||||
|
||||
private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
|
||||
private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized {
|
||||
userSuppliedContext = Option(sparkContext)
|
||||
this
|
||||
}
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
|
||||
import org.apache.spark.AccumulatorSuite
|
||||
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
|
||||
import org.apache.spark.sql.execution.exchange.EnsureRequirements
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
|
@ -44,11 +44,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
*/
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
val conf = new SparkConf()
|
||||
.setMaster("local-cluster[2,1,1024]")
|
||||
.setAppName("testing")
|
||||
val sc = new SparkContext(conf)
|
||||
spark = SparkSession.builder.getOrCreate()
|
||||
spark = SparkSession.builder()
|
||||
.master("local-cluster[2,1,1024]")
|
||||
.appName("testing")
|
||||
.getOrCreate()
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._
|
|||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
|
||||
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
|
||||
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource}
|
||||
import org.apache.spark.sql.expressions.Window
|
||||
|
@ -282,15 +282,12 @@ object SetWarehouseLocationTest extends Logging {
|
|||
val hiveWarehouseLocation = Utils.createTempDir()
|
||||
hiveWarehouseLocation.delete()
|
||||
|
||||
val conf = new SparkConf()
|
||||
conf.set("spark.ui.enabled", "false")
|
||||
// We will use the value of spark.sql.warehouse.dir override the
|
||||
// value of hive.metastore.warehouse.dir.
|
||||
conf.set("spark.sql.warehouse.dir", warehouseLocation.toString)
|
||||
conf.set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
|
||||
|
||||
val sparkSession = SparkSession.builder
|
||||
.config(conf)
|
||||
val sparkSession = SparkSession.builder()
|
||||
.config("spark.ui.enabled", "false")
|
||||
.config("spark.sql.warehouse.dir", warehouseLocation.toString)
|
||||
.config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString)
|
||||
.enableHiveSupport()
|
||||
.getOrCreate()
|
||||
|
||||
|
|
Loading…
Reference in a new issue