# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from io import BytesIO
from collections import namedtuple
from warnings import warn
# 3rd party
import numpy as np
from PIL import Image
# pyspark
from pyspark import Row
from pyspark import SparkContext
from pyspark.sql.types import (BinaryType, IntegerType, StringType, StructField, StructType)
from pyspark.sql.functions import udf
# Schema of an image struct column: a mode string, integer height/width/nChannels, and the
# raw pixel bytes. Every field is non-nullable.
imageSchema = StructType(list(
    StructField(fieldName, fieldType, nullable=False)
    for fieldName, fieldType in (
        ("mode", StringType()),
        ("height", IntegerType()),
        ("width", IntegerType()),
        ("nChannels", IntegerType()),
        ("data", BinaryType()),
    )
))
# ImageType class for holding metadata about images stored in DataFrames.
# fields:
#   nChannels - number of channels in the image.
#   dtype - data type of the image's "data" column, stored as a numpy-compatible string.
#   channelContent - info about the contents of each channel; currently only "I" (intensity)
#     and "RGB" are supported, for 1- and 3-channel data respectively.
#   pilMode - the mode that should be used to convert to a PIL image (None when there is no
#     matching PIL mode).
#   sparkMode - unique identifier string used in the Spark image representation.
ImageType = namedtuple("ImageType", "nChannels dtype channelContent pilMode sparkMode")
class SparkMode(object):
    """String identifiers for the image modes used in the spark image representation."""
    RGB = "RGB"
    FLOAT32 = "float32"
    # 3-channel RGB content stored as float32 values.
    RGB_FLOAT32 = RGB + "-" + FLOAT32
supportedImageTypes = [
    ImageType(nChannels=3, dtype="uint8", channelContent="RGB", pilMode="RGB",
              sparkMode=SparkMode.RGB),
    ImageType(nChannels=1, dtype="float32", channelContent="I", pilMode="F",
              sparkMode=SparkMode.FLOAT32),
    ImageType(nChannels=3, dtype="float32", channelContent="RGB", pilMode=None,
              sparkMode=SparkMode.RGB_FLOAT32),
]
# Lookup tables keyed by PIL mode and by spark mode. Types with no PIL equivalent
# (pilMode is None) are left out of the PIL table.
pilModeLookup = dict((entry.pilMode, entry) for entry in supportedImageTypes
                     if entry.pilMode is not None)
sparkModeLookup = dict((entry.sparkMode, entry) for entry in supportedImageTypes)
def imageArrayToStruct(imgArray, sparkMode=None):
    """
    Create a row representation of an image from an image array and (optional) spark mode.

    Example usage::

        to_image_udf = udf(imageArrayToStruct, imageSchema)
        df.withColumn("output_img", to_image_udf(df["np_arr_col"]))

    :param imgArray: ndarray, image data of shape (height, width, nChannels). A 2-d array is
        treated as a single-channel image; a 4-d array must have a leading dimension of 1,
        which is dropped.
    :param sparkMode: spark mode, type information for the image, will be inferred from the
        array if the mode is not provided. See SparkMode for valid modes.
    :return: Row, image as a DataFrame Row.
    :raises ValueError: if the array has an unsupported shape, its channel count does not
        match the mode, or its dtype cannot be cast to the mode's dtype.
    :raises KeyError: if `sparkMode` is not a known spark mode.
    """
    # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
    if imgArray.ndim == 4:
        if imgArray.shape[0] != 1:
            raise ValueError("The first dimension of a 4-d image array is expected to be 1.")
        imgArray = imgArray.reshape(imgArray.shape[1:])
    elif imgArray.ndim == 2:
        # A plain (height, width) array is a single-channel image.
        imgArray = imgArray.reshape(imgArray.shape + (1,))
    if imgArray.ndim != 3:
        raise ValueError("Image array should have 3 dimensions but has shape {}"
                         .format(imgArray.shape))
    if sparkMode is None:
        sparkMode = _arrayToSparkMode(imgArray)
    imgType = sparkModeLookup[sparkMode]
    height, width, nChannels = imgArray.shape
    if imgType.nChannels != nChannels:
        msg = "Image of type {} should have {} channels, but array has {} channels."
        raise ValueError(msg.format(sparkMode, imgType.nChannels, nChannels))
    # Convert the array to match the image type.
    if not np.can_cast(imgArray.dtype, imgType.dtype, 'same_kind'):
        msg = "Array of type {} cannot safely be cast to image type {}."
        raise ValueError(msg.format(imgArray.dtype, imgType.dtype))
    # np.asarray copies only when the dtype differs. np.array(..., copy=False) raises under
    # NumPy >= 2 whenever the cast needs a copy (e.g. float64 -> float32).
    imgArray = np.asarray(imgArray, dtype=imgType.dtype)
    data = bytearray(imgArray.tobytes())
    return Row(mode=sparkMode, height=height, width=width, nChannels=nChannels, data=data)
def imageType(imageRow):
    """
    Get type information about the image.

    :param imageRow: spark image row (uses imageSchema).
    :return: ImageType, the entry of `sparkModeLookup` for the row's mode.
    :raises KeyError: if the row's mode is not a supported spark mode.
    """
    return sparkModeLookup[imageRow.mode]
def imageStructToArray(imageRow):
    """
    Convert a spark image row into a numpy array.

    :param imageRow: Row, must use imageSchema.
    :return: ndarray of shape (height, width, nChannels), with the dtype of the row's image
        type, laid over the row's data buffer rather than copied from it.
    """
    dtype = imageType(imageRow).dtype
    dims = (imageRow.height, imageRow.width, imageRow.nChannels)
    return np.ndarray(shape=dims, dtype=dtype, buffer=imageRow.data)
def _arrayToSparkMode(arr):
    """
    Infer the spark mode of an image array from its channel count and dtype.

    :param arr: ndarray of shape (height, width, nChannels).
    :return: str, one of the SparkMode constants.
    :raises ValueError: if the array is not 3-d, or no supported mode matches its number of
        channels and dtype.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if len(arr.shape) != 3:
        raise ValueError("Array should have 3 dimensions but has shape {}".format(arr.shape))
    floatTypes = (np.float16, np.float32, np.float64)
    num_channels = arr.shape[2]
    if num_channels == 1:
        if arr.dtype not in floatTypes:
            raise ValueError("incompatible dtype (%s) for numpy array for float32 mode" %
                             arr.dtype)
        return SparkMode.FLOAT32
    elif num_channels != 3:
        raise ValueError("number of channels of the input array (%d) is not supported" %
                         num_channels)
    elif arr.dtype == np.uint8:
        return SparkMode.RGB
    elif arr.dtype in floatTypes:
        return SparkMode.RGB_FLOAT32
    else:
        # Adjacent literals are joined before `%` applies, so both values are formatted.
        raise ValueError("did not find a sparkMode for the given array with num_channels = %d "
                         "and dtype %s" % (num_channels, arr.dtype))
def _resizeFunction(size):
    """Create a function that resizes a spark image row.

    :param size: tuple, size of new image: (height, width).
    :return: function: image => image, a function that converts an input image row to an
        image row of the given `size`, keeping its spark mode.
    :raises ValueError: if `size` does not have exactly two elements. The returned function
        raises ValueError for image modes that have no PIL equivalent.
    """
    if len(size) != 2:
        raise ValueError("New image size should have format [height, width] but got {}"
                         .format(size))
    # PIL sizes are (width, height).
    pilSize = tuple(size[::-1])

    def resizeImageAsRow(imgAsRow):
        imgAsArray = imageStructToArray(imgAsRow)
        imgType = imageType(imgAsRow)
        if imgType.pilMode is None:
            raise ValueError("Resizing images of mode {} is not supported."
                             .format(imgType.sparkMode))
        # Single-channel PIL modes (e.g. "F") only accept 2-d arrays.
        if imgAsArray.shape[2] == 1:
            imgAsArray = imgAsArray[:, :, 0]
        imgAsPil = Image.fromarray(imgAsArray, imgType.pilMode)
        imgAsPil = imgAsPil.resize(pilSize)
        imgAsArray = np.array(imgAsPil)
        # Restore the channel axis that single-channel PIL images do not carry.
        if imgAsArray.ndim == 2:
            imgAsArray = imgAsArray[:, :, np.newaxis]
        return imageArrayToStruct(imgAsArray, imgType.sparkMode)
    return resizeImageAsRow
def resizeImage(size):
    """Create a udf that resizes an image column.

    Example usage::

        dataFrame.select(resizeImage((height, width))('imageColumn'))

    :param size: tuple, target size of the new image, in the form (height, width).
    :return: udf, mapping an image column (imageSchema) to an image column of size `size`.
    """
    resizer = _resizeFunction(size)
    return udf(resizer, imageSchema)
def _decodeImage(imageData):
    """
    Decode compressed image data into a DataFrame image row.

    :param imageData: (bytes, bytearray) compressed image data in PIL compatible format.
    :return: Row, decoded image; None if the data cannot be decoded or the image mode is
        not supported (a warning is emitted for unsupported modes).
    """
    try:
        img = Image.open(BytesIO(imageData))
    except IOError:
        return None
    if img.mode not in pilModeLookup:
        msg = "We don't currently support images with mode: {mode}"
        warn(msg.format(mode=img.mode))
        return None
    imgType = pilModeLookup[img.mode]
    try:
        # Image.open only parses the header. Decode the pixel data here so corrupt or
        # truncated payloads yield None, like unreadable ones.
        img.load()
    except IOError:
        return None
    imgArray = np.asarray(img)
    # Single-channel images (e.g. mode "F") come back as (height, width).
    if imgArray.ndim == 2:
        imgArray = imgArray[:, :, np.newaxis]
    return imageArrayToStruct(imgArray, imgType.sparkMode)
# Creating a UDF at import time can cause SparkContext issues, so _readImages builds the decode UDF on each call instead of:
# decodeImage = udf(_decodeImage, imageSchema)
def filesToDF(sc, path, numPartitions=None):
    """
    Read files from a directory into a DataFrame.

    :param sc: SparkContext.
    :param path: str, path to files.
    :param numPartitions: int, number of partitions to use for reading files. Falls back to
        ``sc.defaultParallelism`` when not given (or falsy).
    :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
    """
    partitions = numPartitions or sc.defaultParallelism
    fileSchema = StructType([
        StructField("filePath", StringType(), False),
        StructField("fileData", BinaryType(), False),
    ])
    pathsAndBytes = sc.binaryFiles(path, minPartitions=partitions)
    pathsAndBytes = pathsAndBytes.repartition(partitions)
    records = pathsAndBytes.map(lambda pair: (pair[0], bytearray(pair[1])))
    return records.toDF(fileSchema)
def readImages(imageDirectory, numPartition=None):
    """
    Read a directory of images (or a single image) into a DataFrame.

    Uses the active SparkContext, creating one if necessary.

    :param imageDirectory: str, file path.
    :param numPartition: int, number of partitions to use for reading files. Defaults to the
        SparkContext's default parallelism.
    :return: DataFrame, with columns: (filePath: str, image: imageSchema). The image is null
        for files that cannot be decoded or have an unsupported mode.
    """
    return _readImages(imageDirectory, numPartition, SparkContext.getOrCreate())
def _readImages(imageDirectory, numPartition, sc):
    """Load the files under `imageDirectory` and decode each into an "image" column."""
    # Built per call rather than at import time (see the note above filesToDF).
    decodeImage = udf(_decodeImage, imageSchema)
    files = filesToDF(sc, imageDirectory, numPartitions=numPartition)
    imageColumn = decodeImage("fileData").alias("image")
    return files.select("filePath", imageColumn)