[SPARK-15738][PYSPARK][ML] Adding PySpark ML RFormula __str__ method similar to the Scala API
## What changes were proposed in this pull request?

Adds `__str__` to the PySpark `RFormula` and `RFormulaModel` classes so that they show the formula param that was set and the resolved formula, respectively. This behavior is already present in the Scala API (via `toString`) and was found missing in PySpark during the Spark 2.0 coverage review.

## How was this patch tested?

Ran the pyspark-ml tests locally.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #13481 from BryanCutler/pyspark-ml-rformula_str-SPARK-15738.
This commit is contained in:
parent
254bc8c34e
commit
7d7a0a5e07
|
@ -182,7 +182,7 @@ class RFormula(override val uid: String)
|
|||
|
||||
override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
|
||||
|
||||
override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
|
||||
override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
|
|
@ -126,7 +126,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
|
|||
* @param hasIntercept whether the formula specifies fitting with an intercept.
|
||||
*/
|
||||
private[ml] case class ResolvedRFormula(
|
||||
label: String, terms: Seq[Seq[String]], hasIntercept: Boolean)
|
||||
label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) {
|
||||
|
||||
override def toString: String = {
|
||||
val ts = terms.map {
|
||||
case t if t.length > 1 =>
|
||||
s"${t.mkString("{", ",", "}")}"
|
||||
case t =>
|
||||
t.mkString
|
||||
}
|
||||
val termStr = ts.mkString("[", ",", "]")
|
||||
s"ResolvedRFormula(label=$label, terms=$termStr, hasIntercept=$hasIntercept)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* R formula terms. See the R formula docs here for more information:
|
||||
|
|
|
@ -2528,6 +2528,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|
|||
True
|
||||
>>> loadedRF.getLabelCol() == rf.getLabelCol()
|
||||
True
|
||||
>>> str(loadedRF)
|
||||
'RFormula(y ~ x + s) (uid=...)'
|
||||
>>> modelPath = temp_path + "/rFormulaModel"
|
||||
>>> model.save(modelPath)
|
||||
>>> loadedModel = RFormulaModel.load(modelPath)
|
||||
|
@ -2542,6 +2544,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|
|||
|0.0|0.0| a|[0.0,1.0]| 0.0|
|
||||
+---+---+---+---------+-----+
|
||||
...
|
||||
>>> str(loadedModel)
|
||||
'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)'
|
||||
|
||||
.. versionadded:: 1.5.0
|
||||
"""
|
||||
|
@ -2586,6 +2590,10 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
|
|||
def _create_model(self, java_model):
|
||||
return RFormulaModel(java_model)
|
||||
|
||||
def __str__(self):
|
||||
formulaStr = self.getFormula() if self.isDefined(self.formula) else ""
|
||||
return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid)
|
||||
|
||||
|
||||
class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
|
||||
"""
|
||||
|
@ -2597,6 +2605,10 @@ class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
|
|||
.. versionadded:: 1.5.0
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
resolvedFormula = self._call_java("resolvedFormula")
|
||||
return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable,
|
||||
|
|
Loading…
Reference in a new issue