Importing a Keras model into TensorFlow.js
Keras models (typically created via the Python API) may be saved in one of several formats. The "whole model" format can be converted to TensorFlow.js Layers format, which can be loaded directly into TensorFlow.js for inference or for further training.
The target TensorFlow.js Layers format is a directory containing a model.json file and a set of sharded weight files in binary format. The model.json file contains both the model topology (aka "architecture" or "graph": a description of the layers and how they are connected) and a manifest of the weight files.
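For example, a converted model directory might look like the following (the shard file names and count vary with the size of the model's weights):

```
tfjs_target_dir/
  model.json
  group1-shard1of2.bin
  group1-shard2of2.bin
```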
Importing a Keras model into TensorFlow.js is a two-step process. First, convert an existing Keras model to TF.js Layers format, and then load it into TensorFlow.js.
Step 1: Convert an existing Keras model to TF.js Layers format
Keras models are usually saved via model.save(filepath), which produces a single HDF5 (.h5) file containing both the model topology and the weights. To convert such a file to TF.js Layers format, run the following command, where path/to/my_model.h5 is the source Keras .h5 file and path/to/tfjs_target_dir is the target output directory for the TF.js files:
```bash
tensorflowjs_converter --input_format keras \
    path/to/my_model.h5 \
    path/to/tfjs_target_dir
```
Alternative: Use the Python API to export directly to TF.js Layers format
If you have a Keras model in Python, you can export it directly to the TensorFlow.js Layers format as follows:
```python
import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)
```
Step 2: Load the model into TensorFlow.js
Then load the model into TensorFlow.js by providing the URL to the model.json file:
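For example (the hostname below is a placeholder; substitute the URL where the converted tfjs_target_dir contents are actually served):

```js
import * as tf from '@tensorflow/tfjs';

// Placeholder URL: point this at wherever the converted artifacts are hosted.
const model = await tf.loadLayersModel('https://example.com/tfjs_target_dir/model.json');
```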
Now the model is ready for inference, evaluation, or re-training. For instance, the loaded model can be immediately used to make a prediction:
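A minimal sketch, assuming a model whose input is a batch of 28×28 grayscale images; adjust the input shape to match your own model:

```js
// Hypothetical input shape [batch, height, width, channels]; match it to
// what the original Keras model expects.
const input = tf.zeros([1, 28, 28, 1]);
const prediction = model.predict(input);
prediction.print();
```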
Many of the TensorFlow.js Examples take this approach, using pretrained models that have been converted and hosted on Google Cloud Storage.
Note that you refer to the entire model via the URL of the model.json file; tf.loadLayersModel fetches model.json first and then makes additional HTTP(S) requests to obtain the sharded weight files referenced in the model.json weight manifest. This approach allows all of these files to be cached by the browser (and perhaps by additional caching servers on the internet), because the model.json file and the weight shards are each smaller than the typical cache file size limit. Thus a model is likely to load more quickly on subsequent occasions.