# Source code for sparkdl.transformers.keras_image
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import keras.backend as K
from keras.models import load_model
from pyspark.ml import Transformer
from pyspark.ml.param import Params, TypeConverters
import sparkdl.graph.utils as tfx
from sparkdl.transformers.keras_utils import KSessionWrap
from sparkdl.param import (
keyword_only, HasInputCol, HasOutputCol,
CanLoadImage, HasKerasModel, HasOutputMode)
from sparkdl.transformers.tf_image import TFImageTransformer
class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol,
                                CanLoadImage, HasKerasModel, HasOutputMode):
    """
    Runs a Tensorflow-backed Keras model, given by file name, over the images
    whose URIs are stored in the inputCol column of a DataFrame.

    Restrictions of the current API:

    * see TFImageTransformer.
    * Only supports Tensorflow-backed Keras models (no Theano).
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                 outputMode="vector"):
        """
        __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                 outputMode="vector")
        """
        super(KerasImageFileTransformer, self).__init__()
        # Apply whatever keyword arguments were recorded in self._input_kwargs.
        self.setParams(**self._input_kwargs)
        # Tensor names stay unset until _loadTFGraph() fills them in.
        self._inputTensor = self._outputTensor = None

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                  outputMode="vector"):
        """
        setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None,
                  outputMode="vector")
        """
        self._set(**self._input_kwargs)
        return self

    def _transform(self, dataset):
        """
        Load the model graph and the images, then delegate to a TFImageTransformer.

        :param dataset: DataFrame whose inputCol column is handed to
                        loadImagesInternal.
        :return: the result of the TFImageTransformer with the intermediate
                 loaded-image column dropped.
        """
        # Loading the graph first also records the input/output tensor names.
        frozen_graph = self._loadTFGraph()
        loaded_images = self.loadImagesInternal(dataset, self.getInputCol())

        assert self._inputTensor is not None, "self._inputTensor must be set"
        assert self._outputTensor is not None, "self._outputTensor must be set"

        tf_params = dict(
            inputCol=self._loadedImageCol(),
            outputCol=self.getOutputCol(),
            graph=frozen_graph,
            inputTensor=self._inputTensor,
            outputTensor=self._outputTensor,
            outputMode=self.getOrDefault(self.outputMode))
        scored = TFImageTransformer(**tf_params).transform(loaded_images)
        return scored.drop(self._loadedImageCol())

    def _loadTFGraph(self):
        """
        Load the Keras model from getModelFile() and return it as a frozen graph.

        As a side effect, stores the names of the model's input and output
        tensors in self._inputTensor and self._outputTensor.

        :return: the value of tfx.strip_and_freeze_until(..., return_graph=True)
                 for the model's output op.
        """
        with KSessionWrap() as (session, graph):
            assert K.backend() == "tensorflow", (
                "Keras backend is not tensorflow but KerasImageTransformer only supports "
                "tensorflow-backed Keras models.")
            with graph.as_default():
                # Learning phase 0 is the testing phase.
                K.set_learning_phase(0)
                keras_model = load_model(self.getModelFile())
                output_op = tfx.op_name(graph, keras_model.output)
                self._inputTensor = keras_model.input.name
                self._outputTensor = keras_model.output.name
                return tfx.strip_and_freeze_until(
                    [output_op], graph, session, return_graph=True)