# Source code for sparkdl.transformers.keras_image

# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import keras.backend as K
from keras.models import load_model

from pyspark.ml import Transformer
from pyspark.ml.param import Params, TypeConverters

import sparkdl.graph.utils as tfx
from sparkdl.transformers.keras_utils import KSessionWrap
from sparkdl.param import (
    keyword_only, HasInputCol, HasOutputCol,
    CanLoadImage, HasKerasModel, HasOutputMode)
from sparkdl.transformers.tf_image import TFImageTransformer


class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol, CanLoadImage,
                                HasKerasModel, HasOutputMode):
    """
    Applies the Tensorflow-backed Keras model (specified by a file name) to
    images (specified by the URI in the inputCol column) in the DataFrame.

    Restrictions of the current API:

      * see TFImageTransformer.
      * Only supports Tensorflow-backed Keras models (no Theano).
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                 outputMode="vector"):
        """
        __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                 outputMode="vector")
        """
        super(KerasImageFileTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
        # Names of the Keras model's input / output tensors. They are filled in
        # by _loadTFGraph() and read by _transform() to configure the
        # underlying TFImageTransformer.
        self._inputTensor = None
        self._outputTensor = None

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                  outputMode="vector"):
        """
        setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                  outputMode="vector")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _transform(self, dataset):
        """
        Run the Keras model over the images referenced by ``inputCol``.

        The model file is converted to a TensorFlow graph, the images are loaded
        into a temporary column, and the work is delegated to a
        TFImageTransformer. The temporary image column is dropped from the
        result.

        :param dataset: DataFrame containing the ``inputCol`` column.
        :return: DataFrame with the model output written to ``outputCol``.
        """
        # _loadTFGraph() must run before the TFImageTransformer is built:
        # it records the input/output tensor names as a side effect.
        graph = self._loadTFGraph()
        image_df = self.loadImagesInternal(dataset, self.getInputCol())

        # Internal invariants (always established by _loadTFGraph above).
        assert self._inputTensor is not None, "self._inputTensor must be set"
        assert self._outputTensor is not None, "self._outputTensor must be set"

        transformer = TFImageTransformer(inputCol=self._loadedImageCol(),
                                         outputCol=self.getOutputCol(),
                                         graph=graph,
                                         inputTensor=self._inputTensor,
                                         outputTensor=self._outputTensor,
                                         outputMode=self.getOrDefault(self.outputMode))
        return transformer.transform(image_df).drop(self._loadedImageCol())

    def _loadTFGraph(self):
        """
        Load the Keras model from ``modelFile`` and return it as a TF graph.

        The model is loaded inside the session/graph pair yielded by
        ``KSessionWrap``. As a side effect, the names of the model's input and
        output tensors are stored on ``self._inputTensor`` and
        ``self._outputTensor``.

        :return: the graph produced by ``tfx.strip_and_freeze_until`` for the
                 model's output op (presumably frozen and pruned up to that op;
                 TODO confirm against sparkdl.graph.utils).
        :raises AssertionError: if the Keras backend is not tensorflow.
        """
        with KSessionWrap() as (sess, g):
            # Explicit check instead of a bare ``assert`` so this configuration
            # check is not stripped under ``python -O``. AssertionError is kept
            # so existing callers see the same exception type as before.
            if K.backend() != "tensorflow":
                raise AssertionError(
                    "Keras backend is not tensorflow but KerasImageFileTransformer only supports "
                    "tensorflow-backed Keras models.")
            with g.as_default():
                K.set_learning_phase(0)  # Testing phase
                model = load_model(self.getModelFile())
                out_op_name = tfx.op_name(g, model.output)
                self._inputTensor = model.input.name
                self._outputTensor = model.output.name
                return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True)