Source code for sparkdl.image.imageIO

# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from io import BytesIO
from collections import namedtuple
from warnings import warn

# 3rd party
import numpy as np
from PIL import Image

# pyspark
from pyspark import Row
from pyspark import SparkContext
from pyspark.sql.types import (BinaryType, IntegerType, StringType, StructField, StructType)
from pyspark.sql.functions import udf


imageSchema = StructType([StructField("mode", StringType(), False),
                          StructField("height", IntegerType(), False),
                          StructField("width", IntegerType(), False),
                          StructField("nChannels", IntegerType(), False),
                          StructField("data", BinaryType(), False)])


# ImageType class for holding metadata about images stored in DataFrames.
# fields:
#   nChannels - number of channels in the image
#   dtype - data type of the image's "data" Column, stored as a numpy compatible string.
#   channelContent - info about the contents of each channel; currently only "I" (intensity) and
#     "RGB" are supported, for 1 and 3 channel data respectively.
#   pilMode - The mode that should be used to convert to a PIL image.
#   sparkMode - Unique identifier string used in spark image representation.
ImageType = namedtuple("ImageType", ["nChannels",
                                     "dtype",
                                     "channelContent",
                                     "pilMode",
                                     "sparkMode",
                                     ])
class SparkMode(object):
    RGB = "RGB"
    FLOAT32 = "float32"
    RGB_FLOAT32 = "RGB-float32"

supportedImageTypes = [
    ImageType(3, "uint8", "RGB", "RGB", SparkMode.RGB),
    ImageType(1, "float32", "I", "F", SparkMode.FLOAT32),
    ImageType(3, "float32", "RGB", None, SparkMode.RGB_FLOAT32),
]
pilModeLookup = {t.pilMode: t for t in supportedImageTypes
                 if t.pilMode is not None}
sparkModeLookup = {t.sparkMode: t for t in supportedImageTypes}
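
# For example, the lookup tables map mode strings back to ImageType metadata
# (values follow from supportedImageTypes above):
#
#   sparkModeLookup[SparkMode.RGB].dtype    # 'uint8'
#   pilModeLookup["F"].sparkMode            # 'float32'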


def imageArrayToStruct(imgArray, sparkMode=None):
    """
    Create a row representation of an image from an image array and an (optional) sparkMode.

    Example usage:
        to_image_udf = udf(imageArrayToStruct, imageSchema)
        df.withColumn("output_img", to_image_udf(df["np_arr_col"]))

    :param imgArray: ndarray, image data.
    :param sparkMode: spark mode, type information for the image; will be inferred from the array
        if the mode is not provided. See SparkMode for valid modes.
    :return: Row, image as a DataFrame Row.
    """
    # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
    if len(imgArray.shape) == 4:
        if imgArray.shape[0] != 1:
            raise ValueError("The first dimension of a 4-d image array is expected to be 1.")
        imgArray = imgArray.reshape(imgArray.shape[1:])

    if sparkMode is None:
        sparkMode = _arrayToSparkMode(imgArray)
    imageType = sparkModeLookup[sparkMode]

    height, width, nChannels = imgArray.shape
    if imageType.nChannels != nChannels:
        msg = "Image of type {} should have {} channels, but array has {} channels."
        raise ValueError(msg.format(sparkMode, imageType.nChannels, nChannels))

    # Convert the array to match the image type.
    if not np.can_cast(imgArray, imageType.dtype, 'same_kind'):
        msg = "Array of type {} cannot safely be cast to image type {}."
        raise ValueError(msg.format(imgArray.dtype, imageType.dtype))
    imgArray = np.array(imgArray, dtype=imageType.dtype, copy=False)

    data = bytearray(imgArray.tobytes())
    return Row(mode=sparkMode, height=height, width=width, nChannels=nChannels, data=data)
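
# A minimal sketch of calling imageArrayToStruct directly on a hypothetical
# 4x3 RGB array; with no sparkMode given, the mode is inferred as SparkMode.RGB:
#
#   arr = np.zeros((4, 3, 3), dtype=np.uint8)
#   row = imageArrayToStruct(arr)
#   assert row.mode == SparkMode.RGB and row.height == 4 and row.width == 3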


def imageType(imageRow):
    """
    Get type information about the image.

    :param imageRow: spark image row.
    :return: ImageType
    """
    return sparkModeLookup[imageRow.mode]

def imageStructToArray(imageRow):
    """
    Convert an image to a numpy array.

    :param imageRow: Row, must use imageSchema.
    :return: ndarray, image data.
    """
    imType = imageType(imageRow)
    shape = (imageRow.height, imageRow.width, imageRow.nChannels)
    return np.ndarray(shape, imType.dtype, imageRow.data)


def _arrayToSparkMode(arr):
    assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(arr.shape)
    num_channels = arr.shape[2]
    if num_channels == 1:
        if arr.dtype not in [np.float16, np.float32, np.float64]:
            raise ValueError("incompatible dtype (%s) for numpy array for float32 mode" %
                             arr.dtype)
        return SparkMode.FLOAT32
    elif num_channels != 3:
        raise ValueError("number of channels of the input array (%d) is not supported" %
                         num_channels)
    elif arr.dtype == np.uint8:
        return SparkMode.RGB
    elif arr.dtype in [np.float16, np.float32, np.float64]:
        return SparkMode.RGB_FLOAT32
    else:
        raise ValueError("did not find a sparkMode for the given array with num_channels = %d "
                         "and dtype %s" % (num_channels, arr.dtype))


def _resizeFunction(size):
    """
    Create a resize function.

    :param size: tuple, size of the new image: (height, width).
    :return: function: image => image, a function that converts an input image to an image of
        `size`.
    """
    if len(size) != 2:
        raise ValueError("New image size should be in the form (height, width) but got {}"
                         .format(size))

    def resizeImageAsRow(imgAsRow):
        imgAsArray = imageStructToArray(imgAsRow)
        imgType = imageType(imgAsRow)
        imgAsPil = Image.fromarray(imgAsArray, imgType.pilMode)
        imgAsPil = imgAsPil.resize(size[::-1])
        imgAsArray = np.array(imgAsPil)
        return imageArrayToStruct(imgAsArray, imgType.sparkMode)

    return resizeImageAsRow


def resizeImage(size):
    """
    Create a udf for resizing an image.

    Example usage:
        dataFrame.select(resizeImage((height, width))('imageColumn'))

    :param size: tuple, target size of the new image in the form (height, width).
    :return: udf, a udf for resizing an image column to `size`.
    """
    return udf(_resizeFunction(size), imageSchema)


def _decodeImage(imageData):
    """
    Decode compressed image data into a DataFrame image row.

    :param imageData: (bytes, bytearray) compressed image data in a PIL compatible format.
    :return: Row, decoded image.
    """
    try:
        img = Image.open(BytesIO(imageData))
    except IOError:
        return None

    if img.mode in pilModeLookup:
        mode = pilModeLookup[img.mode]
    else:
        msg = "We don't currently support images with mode: {mode}"
        warn(msg.format(mode=img.mode))
        return None
    imgArray = np.asarray(img)
    image = imageArrayToStruct(imgArray, mode.sparkMode)
    return image

# Creating a UDF on import can cause SparkContext issues sometimes.
# decodeImage = udf(_decodeImage, imageSchema)


def filesToDF(sc, path, numPartitions=None):
    """
    Read files from a directory to a DataFrame.

    :param sc: SparkContext.
    :param path: str, path to files.
    :param numPartitions: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, fileData: BinaryType).
    """
    numPartitions = numPartitions or sc.defaultParallelism
    schema = StructType([StructField("filePath", StringType(), False),
                         StructField("fileData", BinaryType(), False)])
    rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
    rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
    return rdd.toDF(schema)
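
# A minimal sketch of decoding raw bytes with filesToDF and _decodeImage; the
# path is hypothetical, and the udf is created lazily to avoid the import-time
# SparkContext issue noted above. This is essentially what readImages below
# does internally:
#
#   sc = SparkContext.getOrCreate()
#   rawDF = filesToDF(sc, "/tmp/images")    # columns: (filePath, fileData)
#   decodeUDF = udf(_decodeImage, imageSchema)
#   imagesDF = rawDF.select("filePath", decodeUDF("fileData").alias("image"))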

def readImages(imageDirectory, numPartition=None):
    """
    Read a directory of images (or a single image) into a DataFrame.

    :param imageDirectory: str, file path.
    :param numPartition: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, image: imageSchema).
    """
    return _readImages(imageDirectory, numPartition, SparkContext.getOrCreate())

def _readImages(imageDirectory, numPartition, sc):
    decodeImage = udf(_decodeImage, imageSchema)
    imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition)
    return imageData.select("filePath", decodeImage("fileData").alias("image"))
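
# A minimal sketch of end-to-end usage, assuming a hypothetical directory of
# JPEG files; files that PIL cannot decode come back with a null "image" column:
#
#   df = readImages("/data/jpegs", numPartition=16)
#   resized = df.select("filePath",
#                       resizeImage((224, 224))("image").alias("image"))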