[SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib
Switched from using the SQLContext constructor to using SQLContext.getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib; there were no uses of the constructor in spark.ml. CC: mengxr, yhuai. Author: Joseph K. Bradley <joseph@databricks.com>. Closes #10161 from jkbradley/mllib-sqlcontext-fix.
This commit is contained in:
parent
36282f78b8
commit
3e7e05f5ee
|
@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
/**
 * Converts the rows of an [[IndexedRowMatrix]] into a DataFrame.
 *
 * We use DataFrames for serialization of IndexedRows to Python,
 * so return a DataFrame.
 *
 * @param indexedRowMatrix the distributed matrix whose rows are serialized
 * @return a DataFrame with one row per IndexedRow
 */
def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
  // SPARK-12160: reuse any existing SQLContext for this SparkContext rather
  // than constructing a fresh one on every call.
  val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext)
  sqlContext.createDataFrame(indexedRowMatrix.rows)
}
|
||||
|
||||
|
@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
/**
 * Converts the entries of a [[CoordinateMatrix]] into a DataFrame.
 *
 * We use DataFrames for serialization of MatrixEntry entries to
 * Python, so return a DataFrame.
 *
 * @param coordinateMatrix the distributed matrix whose entries are serialized
 * @return a DataFrame with one row per MatrixEntry
 */
def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
  // SPARK-12160: reuse any existing SQLContext for this SparkContext rather
  // than constructing a fresh one on every call.
  val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext)
  sqlContext.createDataFrame(coordinateMatrix.entries)
}
|
||||
|
||||
|
@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
/**
 * Converts the blocks of a [[BlockMatrix]] into a DataFrame.
 *
 * We use DataFrames for serialization of sub-matrix blocks to
 * Python, so return a DataFrame.
 *
 * @param blockMatrix the distributed block matrix whose blocks are serialized
 * @return a DataFrame with one row per sub-matrix block
 */
def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
  // SPARK-12160: reuse any existing SQLContext for this SparkContext rather
  // than constructing a fresh one on every call.
  val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext)
  sqlContext.createDataFrame(blockMatrix.blocks)
}
|
||||
}
|
||||
|
|
|
@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
modelType: String)
|
||||
|
||||
def save(sc: SparkContext, path: String, data: Data): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// Create JSON metadata.
|
||||
|
@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
|
||||
@Since("1.3.0")
|
||||
def load(sc: SparkContext, path: String): NaiveBayesModel = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
// Load Parquet data.
|
||||
val dataRDD = sqlContext.read.parquet(dataPath(path))
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
theta: Array[Array[Double]])
|
||||
|
||||
def save(sc: SparkContext, path: String, data: Data): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// Create JSON metadata.
|
||||
|
@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
|
|||
}
|
||||
|
||||
def load(sc: SparkContext, path: String): NaiveBayesModel = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
// Load Parquet data.
|
||||
val dataRDD = sqlContext.read.parquet(dataPath(path))
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
|
|
@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel {
|
|||
weights: Vector,
|
||||
intercept: Double,
|
||||
threshold: Option[Double]): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// Create JSON metadata.
|
||||
|
@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel {
|
|||
*/
|
||||
def loadData(sc: SparkContext, path: String, modelClass: String): Data = {
|
||||
val datapath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataRDD = sqlContext.read.parquet(datapath)
|
||||
val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1)
|
||||
assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
|
||||
|
|
|
@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
|
|||
weights: Array[Double],
|
||||
gaussians: Array[MultivariateGaussian]): Unit = {
|
||||
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// Create JSON metadata.
|
||||
|
@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataFrame = sqlContext.read.parquet(dataPath)
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
|
|
@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] {
|
|||
val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
|
||||
|
||||
def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
|
||||
|
@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): KMeansModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
|
|||
|
||||
@Since("1.4.0")
|
||||
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
val metadata = compact(render(
|
||||
|
@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode
|
|||
@Since("1.4.0")
|
||||
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
|
|
|
@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
|
||||
|
||||
def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
|
||||
|
@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): Word2VecModel = {
|
||||
val dataPath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataFrame = sqlContext.read.parquet(dataPath)
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
|
|||
|
||||
def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
|
||||
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
val vectorSize = model.values.head.size
|
||||
|
|
|
@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
|
|||
*/
|
||||
def save(model: MatrixFactorizationModel, path: String): Unit = {
|
||||
val sc = model.userFeatures.sparkContext
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
|
||||
|
@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
|
|||
|
||||
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
|
||||
implicit val formats = DefaultFormats
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val (className, formatVersion, metadata) = loadMetadata(sc, path)
|
||||
assert(className == thisClassName)
|
||||
assert(formatVersion == thisFormatVersion)
|
||||
|
|
|
@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
|
|||
boundaries: Array[Double],
|
||||
predictions: Array[Double],
|
||||
isotonic: Boolean): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
|
||||
val metadata = compact(render(
|
||||
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
|
||||
|
@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
|
|||
}
|
||||
|
||||
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataRDD = sqlContext.read.parquet(dataPath(path))
|
||||
|
||||
checkSchema[Data](dataRDD.schema)
|
||||
|
|
|
@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel {
|
|||
modelClass: String,
|
||||
weights: Vector,
|
||||
intercept: Double): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// Create JSON metadata.
|
||||
|
@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel {
|
|||
*/
|
||||
def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = {
|
||||
val datapath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val dataRDD = sqlContext.read.parquet(datapath)
|
||||
val dataArray = dataRDD.select("weights", "intercept").take(1)
|
||||
assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath")
|
||||
|
|
|
@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
|
|||
}
|
||||
|
||||
def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// SPARK-6120: We do a hacky check here so users understand why save() is failing
|
||||
|
@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
|
|||
|
||||
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
|
||||
val datapath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
// Load Parquet data.
|
||||
val dataRDD = sqlContext.read.parquet(datapath)
|
||||
// Check schema explicitly since erasure makes it hard to use match-case for checking.
|
||||
|
|
|
@ -408,7 +408,7 @@ private[tree] object TreeEnsembleModel extends Logging {
|
|||
case class EnsembleNodeData(treeId: Int, node: NodeData)
|
||||
|
||||
def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
import sqlContext.implicits._
|
||||
|
||||
// SPARK-6120: We do a hacky check here so users understand why save() is failing
|
||||
|
@ -468,7 +468,7 @@ private[tree] object TreeEnsembleModel extends Logging {
|
|||
path: String,
|
||||
treeAlgo: String): Array[DecisionTreeModel] = {
|
||||
val datapath = Loader.dataPath(path)
|
||||
val sqlContext = new SQLContext(sc)
|
||||
val sqlContext = SQLContext.getOrCreate(sc)
|
||||
val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)
|
||||
val trees = constructTrees(nodes)
|
||||
trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
|
||||
|
|
Loading…
Reference in a new issue