[SPARK-21866][ML][PYTHON][FOLLOWUP] Few cleanups and fix image test failure in Python 3.6.0 / NumPy 1.13.3
## What changes were proposed in this pull request?
The image test seems to fail in Python 3.6.0 / NumPy 1.13.3. I manually tested as below:
```
======================================================================
ERROR: test_read_images (pyspark.ml.tests.ImageReaderTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/.../spark/python/pyspark/ml/tests.py", line 1831, in test_read_images
self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
File "/.../spark/python/pyspark/ml/image.py", line 149, in toImage
data = bytearray(array.astype(dtype=np.uint8).ravel())
TypeError: only integer scalar arrays can be converted to a scalar index
----------------------------------------------------------------------
Ran 1 test in 7.606s
```
To be clear, I think the error comes from NumPy - 75b2d5d427/numpy/core/src/multiarray/number.c (L947)
For a smaller scope:
```python
>>> import numpy as np
>>> bytearray(np.array([1]).astype(dtype=np.uint8))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: only integer scalar arrays can be converted to a scalar index
```
In Python 2.7 / NumPy 1.13.1, it prints:
```
bytearray(b'\x01')
```
So, here, I simply worked around it by explicitly converting the array to bytes as below:
```python
>>> bytearray(np.array([1]).astype(dtype=np.uint8).tobytes())
bytearray(b'\x01')
```
Also, while looking into it again, I realised that a few arguments could be quite confusing — for example, a `Row` that needs some specific attributes, and `numpy.ndarray`. I added some type checking and added tests accordingly. So, it now shows an error message as below:
```
TypeError: array argument should be numpy.ndarray; however, it got [<class 'str'>].
```
## How was this patch tested?
Manually tested with `./python/run-tests`.
And also:
```
PYSPARK_PYTHON=python3 SPARK_TESTING=1 bin/pyspark pyspark.ml.tests ImageReaderTest
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #19835 from HyukjinKwon/SPARK-21866-followup.
This commit is contained in:
parent
ab6f60c4d6
commit
92cfbeeb5c
|
@ -108,12 +108,23 @@ class _ImageSchema(object):
|
|||
"""
|
||||
Converts an image to an array with metadata.
|
||||
|
||||
:param image: The image to be converted.
|
||||
:param `Row` image: A row that contains the image to be converted. It should
|
||||
have the attributes specified in `ImageSchema.imageSchema`.
|
||||
:return: a `numpy.ndarray` that is an image.
|
||||
|
||||
.. versionadded:: 2.3.0
|
||||
"""
|
||||
|
||||
if not isinstance(image, Row):
|
||||
raise TypeError(
|
||||
"image argument should be pyspark.sql.types.Row; however, "
|
||||
"it got [%s]." % type(image))
|
||||
|
||||
if any(not hasattr(image, f) for f in self.imageFields):
|
||||
raise ValueError(
|
||||
"image argument should have attributes specified in "
|
||||
"ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
|
||||
|
||||
height = image.height
|
||||
width = image.width
|
||||
nChannels = image.nChannels
|
||||
|
@ -127,15 +138,20 @@ class _ImageSchema(object):
|
|||
"""
|
||||
Converts an array with metadata to a two-dimensional image.
|
||||
|
||||
:param array array: The array to convert to image.
|
||||
:param `numpy.ndarray` array: The array to convert to image.
|
||||
:param str origin: Path to the image, optional.
|
||||
:return: a :class:`Row` that is a two dimensional image.
|
||||
|
||||
.. versionadded:: 2.3.0
|
||||
"""
|
||||
|
||||
if not isinstance(array, np.ndarray):
|
||||
raise TypeError(
|
||||
"array argument should be numpy.ndarray; however, it got [%s]." % type(array))
|
||||
|
||||
if array.ndim != 3:
|
||||
raise ValueError("Invalid array shape")
|
||||
|
||||
height, width, nChannels = array.shape
|
||||
ocvTypes = ImageSchema.ocvTypes
|
||||
if nChannels == 1:
|
||||
|
@ -146,7 +162,12 @@ class _ImageSchema(object):
|
|||
mode = ocvTypes["CV_8UC4"]
|
||||
else:
|
||||
raise ValueError("Invalid number of channels")
|
||||
data = bytearray(array.astype(dtype=np.uint8).ravel())
|
||||
|
||||
# Running `bytearray(numpy.array([1]))` fails in specific Python versions
|
||||
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
|
||||
# Here, it avoids it by converting it to bytes.
|
||||
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
|
||||
|
||||
# Creating new Row with _create_row(), because Row(name = value, ... )
|
||||
# orders fields by name, which conflicts with expected schema order
|
||||
# when the new DataFrame is created by UDF
|
||||
|
|
|
@ -71,7 +71,7 @@ from pyspark.sql import DataFrame, Row, SparkSession
|
|||
from pyspark.sql.functions import rand
|
||||
from pyspark.sql.types import DoubleType, IntegerType
|
||||
from pyspark.storagelevel import *
|
||||
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
|
||||
from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase
|
||||
|
||||
ser = PickleSerializer()
|
||||
|
||||
|
@ -1836,6 +1836,24 @@ class ImageReaderTest(SparkSessionTestCase):
|
|||
self.assertEqual(ImageSchema.imageFields, expected)
|
||||
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
"image argument should be pyspark.sql.types.Row; however",
|
||||
lambda: ImageSchema.toNDArray("a"))
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"image argument should have attributes specified in",
|
||||
lambda: ImageSchema.toNDArray(Row(a=1)))
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
"array argument should be numpy.ndarray; however, it got",
|
||||
lambda: ImageSchema.toImage("a"))
|
||||
|
||||
|
||||
class ALSTest(SparkSessionTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue