What Are the Methods for Saving and Loading Models in TensorFlow and How to Deploy Models

February 18, 18:00

TensorFlow provides multiple methods for saving and loading models, as well as flexible model deployment options. Mastering these skills is crucial for deep learning applications in production environments.

Model Saving Formats

TensorFlow supports multiple model saving formats:

  1. SavedModel format: Recommended format for TensorFlow 2.x
  2. Keras H5 format: Traditional Keras model format
  3. TensorFlow Lite format: For mobile and embedded devices
  4. TensorFlow.js format: For web browsers

SavedModel Format

Saving Complete Model

python
import tensorflow as tf
from tensorflow.keras import layers, models

# Build model
model = models.Sequential([
    layers.Dense(64, activation='relu', input_shape=(10,)),
    layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Save in SavedModel format
model.save('saved_model/my_model')

# SavedModel directory structure:
# saved_model/
# ├── saved_model.pb
# ├── variables/
# └── assets/

Loading SavedModel

python
# Load model
loaded_model = tf.keras.models.load_model('saved_model/my_model')

# Use the model
predictions = loaded_model.predict(x_test)

Saving Specific Version

python
import tensorflow as tf

# Save the model under a version number
model.save('saved_model/my_model/1')

# Save additional versions alongside it
model.save('saved_model/my_model/2')

Keras H5 Format

Saving Complete Model

python
# Save in H5 format (inferred from the .h5 extension)
model.save('my_model.h5')

# Explicitly specify the H5 format (optimizer state is included by default)
model.save('my_model_with_optimizer.h5', save_format='h5')

Loading H5 Model

python
# Load model
loaded_model = tf.keras.models.load_model('my_model.h5')

# Load and continue training
loaded_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
loaded_model.fit(x_train, y_train, epochs=5)

Saving Only Model Architecture

python
# Save the model architecture as JSON
model_json = model.to_json()
with open('model_architecture.json', 'w') as json_file:
    json_file.write(model_json)

# Load the architecture from JSON
with open('model_architecture.json', 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = tf.keras.models.model_from_json(loaded_model_json)

# Load weights (saved separately with model.save_weights)
loaded_model.load_weights('model_weights.h5')

Saving Only Model Weights

python
# Save weights only
model.save_weights('model_weights.h5')

# Load weights back into the same model
model.load_weights('model_weights.h5')

# Load into a different model with the same architecture
new_model = create_model()
new_model.load_weights('model_weights.h5')

Checkpoints

Saving Checkpoints

python
from tensorflow.keras.callbacks import ModelCheckpoint

# Create checkpoint callback
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/model_{epoch:02d}.h5',
    save_weights_only=False,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1
)

# Save checkpoints during training
model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint_callback]
)

Manually Saving Checkpoints

python
# Manually save weights as a checkpoint
model.save_weights('checkpoints/ckpt')

# Save model and optimizer state together
optimizer_state = tf.train.Checkpoint(optimizer=optimizer, model=model)
optimizer_state.save('checkpoints/optimizer')

Restoring Checkpoints

python
# Restore weights
model.load_weights('checkpoints/ckpt')

# Restore model and optimizer state; tf.train.Checkpoint.save appends a step
# number to the prefix, so restore from the latest checkpoint in the directory
optimizer_state = tf.train.Checkpoint(optimizer=optimizer, model=model)
optimizer_state.restore(tf.train.latest_checkpoint('checkpoints'))

TensorFlow Lite Deployment

Converting to TFLite Model

python
import tensorflow as tf

# Convert the Keras model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the TFLite model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

Optimizing TFLite Model

python
# Quantize the model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

# Save the quantized model
with open('model_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)

Running TFLite Model in Python

python
import tensorflow as tf
import numpy as np

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

# Get input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Prepare input data
input_data = np.array(
    np.random.random_sample(input_details[0]['shape']), dtype=np.float32)

# Set input
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference
interpreter.invoke()

# Get output
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

Deploying on Mobile Devices

Android Deployment

java
import org.tensorflow.lite.Interpreter;

// Load the model (loadModelFile() is assumed to return a MappedByteBuffer
// of model.tflite from the app's assets)
Interpreter interpreter = new Interpreter(loadModelFile());

// Prepare input
float[][] input = new float[1][10];

// Run inference; the output array is filled in place
float[][] output = new float[1][10];
interpreter.run(input, output);

iOS Deployment

swift
import TensorFlowLite

// Load the model
guard let interpreter = try? Interpreter(modelPath: "model.tflite") else {
    fatalError("Failed to load model")
}
try interpreter.allocateTensors()

// Prepare input: copy the float array into the input tensor as raw bytes
var input: [Float] = Array(repeating: 0.0, count: 10)
let inputData = input.withUnsafeBufferPointer { Data(buffer: $0) }
try interpreter.copy(inputData, toInputAt: 0)

// Run inference
try interpreter.invoke()

// Read the output tensor back into a float array
let outputTensor = try interpreter.output(at: 0)
let output = outputTensor.data.withUnsafeBytes { Array($0.bindMemory(to: Float.self)) }

TensorFlow.js Deployment

Converting to TensorFlow.js Model

bash
# Install the tensorflowjs converter
pip install tensorflowjs

# Convert the model
tensorflowjs_converter --input_format keras \
    my_model.h5 \
    tfjs_model

Using in Browser

html
<!DOCTYPE html>
<html>
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
</head>
<body>
<script>
  // Load the model
  async function loadModel() {
    const model = await tf.loadLayersModel('tfjs_model/model.json');
    return model;
  }

  // Run inference
  async function predict() {
    const model = await loadModel();
    const input = tf.randomNormal([1, 10]);
    const output = model.predict(input);
    output.print();
  }

  predict();
</script>
</body>
</html>

TensorFlow Serving Deployment

Exporting Model

python
import tensorflow as tf

# Export the model in SavedModel format; the version subdirectory ("1")
# is required by TensorFlow Serving
model.save('serving_model/1')

Deploying with Docker

bash
# Pull the TensorFlow Serving image
docker pull tensorflow/serving

# Run TensorFlow Serving
docker run -p 8501:8501 \
    --mount type=bind,source=$(pwd)/serving_model,target=/models/my_model \
    -e MODEL_NAME=my_model \
    -t tensorflow/serving &

Calling with REST API

python
import requests
import numpy as np

# Prepare input data
input_data = np.random.random((1, 10)).tolist()

# Send request
response = requests.post(
    'http://localhost:8501/v1/models/my_model:predict',
    json={'instances': input_data}
)

# Get prediction results
predictions = response.json()['predictions']
print(predictions)

Calling with gRPC

python
import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Create gRPC connection (TensorFlow Serving's gRPC port defaults to 8500)
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Create prediction request
request = predict_pb2.PredictRequest()
request.model_spec.name = 'my_model'
request.model_spec.signature_name = 'serving_default'

# Set input data; the input name must match the model's signature
input_data = np.random.random((1, 10)).astype(np.float32)
request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_data))

# Send request
result = stub.Predict(request, timeout=10.0)
print(result)

Cloud Platform Deployment

Google Cloud AI Platform

python
from google.cloud import aiplatform

# Upload the model
model = aiplatform.Model.upload(
    display_name='my_model',
    artifact_uri='gs://my-bucket/model',
    serving_container_image_uri='us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest'
)

# Deploy the model
endpoint = model.deploy(
    machine_type='n1-standard-4',
    min_replica_count=1,
    max_replica_count=5
)

AWS SageMaker

python
import sagemaker
from sagemaker.tensorflow import TensorFlowModel

# Create the model
model = TensorFlowModel(
    model_data='s3://my-bucket/model.tar.gz',
    role='arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole',
    framework_version='2.6.0'
)

# Deploy the model
predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge'
)

# Make predictions
predictions = predictor.predict(input_data)

Model Version Management

Saving Multiple Versions

python
# Save different model versions
version = 1
model.save(f'saved_model/my_model/{version}')

# Update version
version += 1
model.save(f'saved_model/my_model/{version}')

Loading Specific Version

python
import os

# Determine and load the latest version by listing the version subdirectories
versions = [int(v) for v in os.listdir('saved_model/my_model') if v.isdigit()]
latest_model = tf.keras.models.load_model(f'saved_model/my_model/{max(versions)}')

# Load a specific version
version_1_model = tf.keras.models.load_model('saved_model/my_model/1')
version_2_model = tf.keras.models.load_model('saved_model/my_model/2')

Model Optimization

Model Pruning

python
import tensorflow_model_optimization as tfmot

# Pruning configuration: hold 50% sparsity throughout training
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0)
}

# Apply pruning wrappers to the model
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Train the pruned model; UpdatePruningStep must be passed as a callback
model_for_pruning.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]
)

# Strip the pruning wrappers and export
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.save('pruned_model')

Model Quantization

python
# Post-training quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

# Save the quantized model
with open('quantized_model.tflite', 'wb') as f:
    f.write(quantized_model)

Knowledge Distillation

python
# Define teacher and student models (create_teacher_model / create_student_model
# are assumed helpers that build networks outputting raw logits)
teacher_model = create_teacher_model()
student_model = create_student_model()

# Standard distillation loss: hard-label cross-entropy plus a soft-label
# KL term that matches the teacher's temperature-softened distribution
def distillation_loss(y_true, student_logits, teacher_logits,
                      temperature=3.0, alpha=0.1):
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, tf.nn.softmax(student_logits))
    soft_student = tf.nn.softmax(student_logits / temperature)
    soft_teacher = tf.nn.softmax(teacher_logits / temperature)
    soft_loss = tf.keras.losses.KLDivergence()(soft_teacher, soft_student)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    return alpha * tf.reduce_mean(hard_loss) + (1 - alpha) * soft_loss * temperature ** 2

# Train the student model
for x_batch, y_batch in train_dataset:
    with tf.GradientTape() as tape:
        teacher_pred = teacher_model(x_batch, training=False)
        student_pred = student_model(x_batch, training=True)
        loss = distillation_loss(y_batch, student_pred, teacher_pred)
    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

Best Practices

  1. Use SavedModel format: Recommended format for TensorFlow 2.x
  2. Version control: Create separate directories for each model version
  3. Model signatures: Define clear input and output signatures for models (see the sketch after this list)
  4. Test deployment: Thoroughly test models before deployment
  5. Monitor performance: Monitor model performance after deployment
  6. Security considerations: Protect model files and API endpoints
  7. Documentation: Record model usage methods and dependencies
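
For item 3, a serving signature can be attached explicitly when exporting a SavedModel. The following is a minimal sketch: the function name serve, the input name inputs, and the input shape (None, 10) are illustrative assumptions, not part of the article's model.

python
import tensorflow as tf

# Hypothetical example: wrap the model's forward pass in a tf.function
# with an explicit input spec, and export it as the serving signature
@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, 10), dtype=tf.float32, name='inputs')])
def serve(inputs):
    return {'predictions': model(inputs)}

# Keras model.save forwards `signatures` to tf.saved_model.save
model.save('saved_model/my_model/1', signatures={'serving_default': serve})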

Summary

TensorFlow provides a complete solution for model saving, loading, and deployment:

  • SavedModel: Recommended format for production environments
  • Keras H5: Quick prototyping and development
  • TensorFlow Lite: Mobile and embedded devices
  • TensorFlow.js: Web browser deployment
  • TensorFlow Serving: Production environment serving

Mastering these technologies will help you successfully deploy deep learning models from development to production environments.

Tags: TensorFlow