[SPARK-6120] [mllib] Warnings about memory in tree, ensemble model save

Issue: When the Python DecisionTree example in the programming guide is run, it runs out of Java heap space when using the default memory settings for the Spark shell.

This patch prints a warning when DecisionTreeModel.save() or TreeEnsembleModel.save() is called with too little driver or executor memory.
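For reference, a sketch of the failing scenario in spark-shell (this mirrors the programming-guide example; the output path is illustrative):

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Assumes the `sc` provided by spark-shell.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val model = DecisionTree.trainClassifier(data, numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini",
  maxDepth = 5, maxBins = 32)
// With the default shell heap, save() can throw
// java.lang.OutOfMemoryError: Java heap space; after this patch, a warning is
// logged first. Starting the shell with more memory avoids both:
//   ./bin/spark-shell --driver-memory 1g
model.save(sc, "target/tmp/myDecisionTreeModel")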

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4864 from jkbradley/dt-save-heap and squashes the following commits:

02e8daf [Joseph K. Bradley] fixed based on code review
7ecb1ed [Joseph K. Bradley] Added warnings about memory when calling tree and ensemble model save with too small a Java heap size
Authored by Joseph K. Bradley on 2015-03-02 22:33:51 -08:00; committed by Xiangrui Meng
parent 7e53a79c30
commit c2fe3a6ff1
2 changed files with 50 additions and 4 deletions

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

@@ -23,7 +23,7 @@ import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -115,7 +116,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
   override protected def formatVersion: String = "1.0"
 }
 
-object DecisionTreeModel extends Loader[DecisionTreeModel] {
+object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
 
   private[tree] object SaveLoadV1_0 {
 
@@ -187,6 +188,28 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
 
+      // SPARK-6120: We do a hacky check here so users understand why save() is failing
+      //             when they run the ML guide example.
+      // TODO: Fix this issue for real.
+      val memThreshold = 768
+      if (sc.isLocal) {
+        val driverMemory = sc.getConf.getOption("spark.driver.memory")
+          .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
+          .map(Utils.memoryStringToMb)
+          .getOrElse(512)
+        if (driverMemory <= memThreshold) {
+          logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
+            s" driver memory (${driverMemory}m)." +
+            s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
+        }
+      } else {
+        if (sc.executorMemory <= memThreshold) {
+          logWarning(s"$thisClassName.save() was called, but it may fail because of too little" +
+            s" executor memory (${sc.executorMemory}m)." +
+            s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
+        }
+      }
+
       // Create JSON metadata.
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

@@ -24,7 +24,7 @@ import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -250,7 +251,7 @@ private[tree] sealed class TreeEnsembleModel(
   def totalNumNodes: Int = trees.map(_.numNodes).sum
 }
 
-private[tree] object TreeEnsembleModel {
+private[tree] object TreeEnsembleModel extends Logging {
 
   object SaveLoadV1_0 {
 
@@ -277,6 +278,28 @@ private[tree] object TreeEnsembleModel {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
 
+      // SPARK-6120: We do a hacky check here so users understand why save() is failing
+      //             when they run the ML guide example.
+      // TODO: Fix this issue for real.
+      val memThreshold = 768
+      if (sc.isLocal) {
+        val driverMemory = sc.getConf.getOption("spark.driver.memory")
+          .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY")))
+          .map(Utils.memoryStringToMb)
+          .getOrElse(512)
+        if (driverMemory <= memThreshold) {
+          logWarning(s"$className.save() was called, but it may fail because of too little" +
+            s" driver memory (${driverMemory}m)." +
+            s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).")
+        }
+      } else {
+        if (sc.executorMemory <= memThreshold) {
+          logWarning(s"$className.save() was called, but it may fail because of too little" +
+            s" executor memory (${sc.executorMemory}m)." +
+            s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).")
+        }
+      }
+
       // Create JSON metadata.
       implicit val format = DefaultFormats
       val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
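To summarize the guard's runtime behavior (threshold and default taken from the patch; the launch flags below are standard Spark options, values illustrative):

// Worked check: the guard warns whenever the resolved memory, in MB, is at or
// below the 768 MB threshold.
val memThreshold = 768
val defaultDriverMb = 512                 // fallback when spark.driver.memory is unset
assert(defaultDriverMb <= memThreshold)   // default local shell => warning is logged
assert(!(1024 <= memThreshold))           // --driver-memory 1g => no warning
// Relaunching with more memory avoids the warning and the underlying OOM:
//   ./bin/spark-shell --driver-memory 1g         (local mode)
//   ./bin/spark-submit --executor-memory 1g ...  (cluster mode)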