Initial weights in Scala are ones; do that too. Also fix some errors.
This commit is contained in:
parent
4e821390bc
commit
02208a175c
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape
|
||||
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
|
||||
from pyspark import SparkContext
|
||||
|
||||
# Double vector format:
|
||||
|
@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs):
|
|||
elif (type(x) == RDD):
|
||||
raise RuntimeError("Bulk predict not yet supported.")
|
||||
else:
|
||||
raise TypeError("Argument of type " + type(x) + " unsupported")
|
||||
raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
|
||||
|
||||
def _get_unmangled_rdd(data, serializer):
|
||||
dataBytes = data.map(serializer)
|
||||
|
@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data):
|
|||
initial_weights = data.first()
|
||||
if type(initial_weights) != ndarray:
|
||||
raise TypeError("At least one data element has type "
|
||||
+ type(initial_weights) + " which is not ndarray")
|
||||
+ type(initial_weights).__name__ + " which is not ndarray")
|
||||
if initial_weights.ndim != 1:
|
||||
raise TypeError("At least one data element has "
|
||||
+ initial_weights.ndim + " dimensions, which is not 1")
|
||||
initial_weights = zeros([initial_weights.shape[0] - 1])
|
||||
initial_weights = ones([initial_weights.shape[0] - 1])
|
||||
return initial_weights
|
||||
|
||||
# train_func should take two parameters, namely data and initial_weights, and
|
||||
|
@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
|
|||
raise RuntimeError("JVM call result had unexpected length")
|
||||
elif type(ans[0]) != bytearray:
|
||||
raise RuntimeError("JVM call result had first element of type "
|
||||
+ type(ans[0]) + " which is not bytearray")
|
||||
+ type(ans[0]).__name__ + " which is not bytearray")
|
||||
elif type(ans[1]) != float:
|
||||
raise RuntimeError("JVM call result had second element of type "
|
||||
+ type(ans[0]) + " which is not float")
|
||||
+ type(ans[0]).__name__ + " which is not float")
|
||||
return klass(_deserialize_double_vector(ans[0]), ans[1])
|
||||
|
||||
def _serialize_rating(r):
|
||||
|
|
Loading…
Reference in a new issue