乐闻世界logo
搜索文章和话题

TensorFlow 中的模型保存和加载有哪些方法,如何进行模型部署

2月18日 18:00

TensorFlow 提供了多种模型保存和加载的方法,以及灵活的模型部署选项。掌握这些技能对于生产环境中的深度学习应用至关重要。

模型保存格式

TensorFlow 支持多种模型保存格式:

  1. SavedModel 格式:TensorFlow 2.x 推荐的格式
  2. Keras H5 格式:传统的 Keras 模型格式
  3. TensorFlow Lite 格式:用于移动设备和嵌入式设备
  4. TensorFlow.js 格式:用于 Web 浏览器

SavedModel 格式

保存完整模型

python
import tensorflow as tf from tensorflow.keras import layers, models # 构建模型 model = models.Sequential([ layers.Dense(64, activation='relu', input_shape=(10,)), layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') # 保存为 SavedModel 格式 model.save('saved_model/my_model') # SavedModel 目录结构: # saved_model/ # ├── saved_model.pb # ├── variables/ # └── assets/

加载 SavedModel

python
# 加载模型 loaded_model = tf.keras.models.load_model('saved_model/my_model') # 使用模型 predictions = loaded_model.predict(x_test)

保存特定版本

python
import tensorflow as tf # 保存模型并指定版本 model.save('saved_model/my_model/1') # 保存多个版本 model.save('saved_model/my_model/2')

Keras H5 格式

保存完整模型

python
# 保存为 H5 格式 model.save('my_model.h5') # 保存时包含优化器状态 model.save('my_model_with_optimizer.h5', save_format='h5')

加载 H5 模型

python
# 加载模型 loaded_model = tf.keras.models.load_model('my_model.h5') # 加载并继续训练 loaded_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') loaded_model.fit(x_train, y_train, epochs=5)

只保存模型架构

python
# 保存模型架构为 JSON model_json = model.to_json() with open('model_architecture.json', 'w') as json_file: json_file.write(model_json) # 从 JSON 加载架构 with open('model_architecture.json', 'r') as json_file: loaded_model_json = json_file.read() loaded_model = tf.keras.models.model_from_json(loaded_model_json) # 加载权重 loaded_model.load_weights('model_weights.h5')

只保存模型权重

python
# 保存权重 model.save_weights('model_weights.h5') # 加载权重 model.load_weights('model_weights.h5') # 加载到不同的模型 new_model = create_model() new_model.load_weights('model_weights.h5')

检查点(Checkpoint)

保存检查点

python
from tensorflow.keras.callbacks import ModelCheckpoint # 创建检查点回调 checkpoint_callback = ModelCheckpoint( filepath='checkpoints/model_{epoch:02d}.h5', save_weights_only=False, save_best_only=True, monitor='val_loss', mode='min', verbose=1 ) # 训练时保存检查点 model.fit( x_train, y_train, epochs=10, validation_data=(x_val, y_val), callbacks=[checkpoint_callback] )

手动保存检查点

python
# 手动保存检查点 model.save_weights('checkpoints/ckpt') # 保存优化器状态 optimizer_state = tf.train.Checkpoint(optimizer=optimizer, model=model) optimizer_state.save('checkpoints/optimizer')

恢复检查点

python
# 恢复检查点 model.load_weights('checkpoints/ckpt') # 恢复优化器状态 optimizer_state = tf.train.Checkpoint(optimizer=optimizer, model=model) optimizer_state.restore('checkpoints/optimizer')

TensorFlow Lite 部署

转换为 TFLite 模型

python
import tensorflow as tf # 转换模型 converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() # 保存 TFLite 模型 with open('model.tflite', 'wb') as f: f.write(tflite_model)

优化 TFLite 模型

python
# 量化模型 converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_quant_model = converter.convert() # 保存量化模型 with open('model_quant.tflite', 'wb') as f: f.write(tflite_quant_model)

在 Python 中运行 TFLite 模型

python
import tensorflow as tf import numpy as np # 加载 TFLite 模型 interpreter = tf.lite.Interpreter(model_path='model.tflite') interpreter.allocate_tensors() # 获取输入输出张量 input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # 准备输入数据 input_data = np.array(np.random.random_sample(input_details[0]['shape']), dtype=np.float32) # 设置输入 interpreter.set_tensor(input_details[0]['index'], input_data) # 运行推理 interpreter.invoke() # 获取输出 output_data = interpreter.get_tensor(output_details[0]['index']) print(output_data)

在移动设备上部署

Android 部署

java
import org.tensorflow.lite.Interpreter; // 加载模型 Interpreter interpreter = new Interpreter(loadModelFile()); // 准备输入 float[][] input = new float[1][10]; // 运行推理 float[][] output = new float[1][10]; interpreter.run(input, output);

iOS 部署

swift
import TensorFlowLite // 加载模型 guard let interpreter = try? Interpreter(modelPath: "model.tflite") else { fatalError("Failed to load model") } // 准备输入 var input: [Float] = Array(repeating: 0.0, count: 10) // 运行推理 var output: [Float] = Array(repeating: 0.0, count: 10) try interpreter.copy(input, toInputAt: 0) try interpreter.invoke() try interpreter.copy(&output, fromOutputAt: 0)

TensorFlow.js 部署

转换为 TensorFlow.js 模型

bash
# 安装 tensorflowjs_converter pip install tensorflowjs # 转换模型 tensorflowjs_converter --input_format keras \ my_model.h5 \ tfjs_model

在浏览器中使用

html
<!DOCTYPE html> <html> <head> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script> </head> <body> <script> // 加载模型 async function loadModel() { const model = await tf.loadLayersModel('tfjs_model/model.json'); return model; } // 运行推理 async function predict() { const model = await loadModel(); const input = tf.randomNormal([1, 10]); const output = model.predict(input); output.print(); } predict(); </script> </body> </html>

TensorFlow Serving 部署

导出模型

python
import tensorflow as tf # 导出模型为 SavedModel 格式 model.save('serving_model/1')

使用 Docker 部署

bash
# 拉取 TensorFlow Serving 镜像 docker pull tensorflow/serving # 运行 TensorFlow Serving docker run -p 8501:8501 \ --mount type=bind,source=$(pwd)/serving_model,target=/models/my_model \ -e MODEL_NAME=my_model \ -t tensorflow/serving &

使用 REST API 调用

python
import requests import json import numpy as np # 准备输入数据 input_data = np.random.random((1, 10)).tolist() # 发送请求 response = requests.post( 'http://localhost:8501/v1/models/my_model:predict', json={'instances': input_data} ) # 获取预测结果 predictions = response.json()['predictions'] print(predictions)

使用 gRPC 调用

python
import grpc from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc import numpy as np # 创建 gRPC 连接 channel = grpc.insecure_channel('localhost:8500') stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) # 创建预测请求 request = predict_pb2.PredictRequest() request.model_spec.name = 'my_model' request.model_spec.signature_name = 'serving_default' # 设置输入数据 input_data = np.random.random((1, 10)).astype(np.float32) request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(input_data)) # 发送请求 result = stub.Predict(request, timeout=10.0) print(result)

云平台部署

Google Cloud AI Platform

python
from google.cloud import aiplatform # 上传模型 model = aiplatform.Model.upload( display_name='my_model', artifact_uri='gs://my-bucket/model', serving_container_image_uri='us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest' ) # 部署模型 endpoint = model.deploy( machine_type='n1-standard-4', min_replica_count=1, max_replica_count=5 )

AWS SageMaker

python
import sagemaker from sagemaker.tensorflow import TensorFlowModel # 创建模型 model = TensorFlowModel( model_data='s3://my-bucket/model.tar.gz', role='arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole', framework_version='2.6.0' ) # 部署模型 predictor = model.deploy( initial_instance_count=1, instance_type='ml.m5.xlarge' ) # 进行预测 predictions = predictor.predict(input_data)

模型版本管理

保存多个版本

python
import os # 保存不同版本的模型 version = 1 model.save(f'saved_model/my_model/{version}') # 更新版本 version += 1 model.save(f'saved_model/my_model/{version}')

加载特定版本

python
# 加载最新版本 latest_model = tf.keras.models.load_model('saved_model/my_model') # 加载特定版本 version_1_model = tf.keras.models.load_model('saved_model/my_model/1') version_2_model = tf.keras.models.load_model('saved_model/my_model/2')

模型优化

模型剪枝

python
import tensorflow_model_optimization as tfmot # 定义剪枝模型 prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude # 应用剪枝 model_for_pruning = prune_low_magnitude(model, pruning_params) # 训练剪枝模型 model_for_pruning.fit(x_train, y_train, epochs=10) # 导出剪枝后的模型 model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning) model_for_export.save('pruned_model')

模型量化

python
# 训练后量化 converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] quantized_model = converter.convert() # 保存量化模型 with open('quantized_model.tflite', 'wb') as f: f.write(quantized_model)

知识蒸馏

python
# 定义教师模型和学生模型 teacher_model = create_teacher_model() student_model = create_student_model() # 定义蒸馏损失 def distillation_loss(y_true, y_pred, teacher_pred, temperature=3): y_true_soft = tf.nn.softmax(y_true / temperature) y_pred_soft = tf.nn.softmax(y_pred / temperature) teacher_pred_soft = tf.nn.softmax(teacher_pred / temperature) loss = tf.keras.losses.KLDivergence()(y_true_soft, y_pred_soft) loss += tf.keras.losses.KLDivergence()(teacher_pred_soft, y_pred_soft) return loss # 训练学生模型 for x_batch, y_batch in train_dataset: with tf.GradientTape() as tape: teacher_pred = teacher_model(x_batch, training=False) student_pred = student_model(x_batch, training=True) loss = distillation_loss(y_batch, student_pred, teacher_pred) gradients = tape.gradient(loss, student_model.trainable_variables) optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

最佳实践

  1. 使用 SavedModel 格式:TensorFlow 2.x 推荐的格式
  2. 版本控制:为每个模型版本创建单独的目录
  3. 模型签名:为模型定义清晰的输入输出签名
  4. 测试部署:在部署前充分测试模型
  5. 监控性能:监控部署后的模型性能
  6. 安全考虑:保护模型文件和 API 端点
  7. 文档记录:记录模型的使用方法和依赖项

总结

TensorFlow 提供了完整的模型保存、加载和部署解决方案:

  • SavedModel:生产环境推荐格式
  • Keras H5:快速原型开发
  • TensorFlow Lite:移动和嵌入式设备
  • TensorFlow.js:Web 浏览器部署
  • TensorFlow Serving:生产环境服务

掌握这些技术将帮助你将深度学习模型从开发环境成功部署到生产环境。

标签:Tensorflow