
What callbacks does TensorFlow provide, and how do you write a custom callback?

February 18, 18:03

Callbacks are a powerful tool in TensorFlow for running custom logic during training. They let you monitor, control, and modify the training process at its different stages.

Built-in Callbacks

1. ModelCheckpoint - Save model checkpoints

```python
from tensorflow.keras.callbacks import ModelCheckpoint

# Save only the best model
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1
)

# Save the model after every epoch
checkpoint_epoch = ModelCheckpoint(
    filepath='model_{epoch:02d}.h5',
    save_freq='epoch',
    verbose=1
)

# Save only the model weights
checkpoint_weights = ModelCheckpoint(
    filepath='weights_{epoch:02d}.h5',
    save_weights_only=True,
    verbose=1
)
```

2. EarlyStopping - Early stopping

```python
from tensorflow.keras.callbacks import EarlyStopping

# Stop early based on validation loss
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    restore_best_weights=True,
    verbose=1
)

# Stop early based on validation accuracy
early_stop_acc = EarlyStopping(
    monitor='val_accuracy',
    patience=3,
    mode='max',
    verbose=1
)
```

3. ReduceLROnPlateau - Learning rate reduction

```python
from tensorflow.keras.callbacks import ReduceLROnPlateau

# Reduce the learning rate when validation loss stops improving
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,      # multiply the learning rate by 0.1
    patience=3,
    mode='min',
    min_lr=1e-7,
    verbose=1
)

# Adjust the learning rate based on accuracy
reduce_lr_acc = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=2,
    mode='max',
    verbose=1
)
```

4. TensorBoard - TensorBoard logging

```python
from tensorflow.keras.callbacks import TensorBoard
import datetime

# Create a timestamped log directory
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='epoch'
)
```

5. LearningRateScheduler - Learning rate scheduling

```python
from tensorflow.keras.callbacks import LearningRateScheduler
import math

# Define a learning rate schedule function
def lr_scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * math.exp(-0.1)

lr_schedule = LearningRateScheduler(lr_scheduler, verbose=1)

# A predefined step-decay schedule
def step_decay(epoch):
    initial_lr = 0.001
    drop = 0.5
    epochs_drop = 10.0
    lrate = initial_lr * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lrate

lr_step = LearningRateScheduler(step_decay, verbose=1)
```

6. CSVLogger - CSV logging

```python
from tensorflow.keras.callbacks import CSVLogger

csv_logger = CSVLogger(
    'training.log',
    separator=',',
    append=False
)
```

7. ProgbarLogger - Progress bar logging

```python
from tensorflow.keras.callbacks import ProgbarLogger

progbar = ProgbarLogger(
    count_mode='steps',
    stateful_metrics=['loss', 'accuracy']
)
```

8. LambdaCallback - Lightweight custom callbacks

```python
from tensorflow.keras.callbacks import LambdaCallback

# A simple custom callback
lambda_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(f"Epoch {epoch} started"),
    on_epoch_end=lambda epoch, logs: print(f"Epoch {epoch} finished, Loss: {logs['loss']:.4f}"),
    on_batch_begin=lambda batch, logs: None,
    on_batch_end=lambda batch, logs: None,
    on_train_begin=lambda logs: print("Training started"),
    on_train_end=lambda logs: print("Training finished")
)
```

9. RemoteMonitor - Remote monitoring

```python
from tensorflow.keras.callbacks import RemoteMonitor

remote_monitor = RemoteMonitor(
    root='http://localhost:9000',
    path='/publish/epoch/end/',
    field='data',
    headers=None,
    send_as_json=False
)
```

10. BackupAndRestore - Backup and restore

```python
from tensorflow.keras.callbacks import BackupAndRestore

backup_restore = BackupAndRestore(
    backup_dir='backup',
    save_freq='epoch',
    delete_checkpoint=True
)
```

Custom Callbacks

A basic custom callback

```python
from tensorflow.keras.callbacks import Callback

class CustomCallback(Callback):
    def on_train_begin(self, logs=None):
        print("Training started")

    def on_train_end(self, logs=None):
        print("Training finished")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Epoch {epoch} started")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch} finished")
        print(f"Loss: {logs['loss']:.4f}")
        print(f"Accuracy: {logs['accuracy']:.4f}")

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        if batch % 100 == 0:
            print(f"Batch {batch}, Loss: {logs['loss']:.4f}")
```

Using a custom callback

```python
# Create an instance of the custom callback
custom_callback = CustomCallback()

# Pass it to model.fit
model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[custom_callback]
)
```

Advanced custom callback examples

1. Learning rate recorder

```python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class LearningRateRecorder(Callback):
    def __init__(self):
        super(LearningRateRecorder, self).__init__()
        self.lr_history = []

    def on_epoch_end(self, epoch, logs=None):
        # Read the current learning rate from the optimizer
        lr = self.model.optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(self.model.optimizer.iterations)
        self.lr_history.append(float(lr))
        print(f"Epoch {epoch}: Learning Rate = {float(lr):.6f}")

    def get_lr_history(self):
        return self.lr_history
```

2. Gradient monitoring

```python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class GradientMonitor(Callback):
    def __init__(self, x_sample, y_sample, log_dir='logs/gradients'):
        super(GradientMonitor, self).__init__()
        # A small probe batch, passed in explicitly so the callback
        # does not depend on global variables
        self.x_sample = x_sample
        self.y_sample = y_sample
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        # Compute gradients on the probe batch
        with tf.GradientTape() as tape:
            predictions = self.model(self.x_sample, training=True)
            loss = self.model.compiled_loss(self.y_sample, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Log the gradient norms
        with self.writer.as_default():
            for i, grad in enumerate(gradients):
                if grad is not None:
                    tf.summary.scalar(f'gradient_norm_{i}', tf.norm(grad), step=epoch)
```

3. Weight monitoring

```python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class WeightMonitor(Callback):
    def __init__(self, log_dir='logs/weights'):
        super(WeightMonitor, self).__init__()
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        with self.writer.as_default():
            for i, layer in enumerate(self.model.layers):
                for j, w in enumerate(layer.get_weights()):
                    w = tf.convert_to_tensor(w)
                    tf.summary.scalar(f'layer_{i}_weight_{j}_mean',
                                      tf.reduce_mean(w), step=epoch)
                    tf.summary.scalar(f'layer_{i}_weight_{j}_std',
                                      tf.math.reduce_std(w), step=epoch)
```

4. Custom early stopping

```python
from tensorflow.keras.callbacks import Callback

class CustomEarlyStopping(Callback):
    def __init__(self, monitor='val_loss', patience=5, min_delta=0):
        super(CustomEarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        # Loss-like metrics should decrease; everything else should increase
        self.minimize = 'loss' in self.monitor
        self.wait = 0
        self.best = None
        self.stopped_epoch = 0

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.best = float('inf') if self.minimize else -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return
        if self.minimize:
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta
        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print(f"Early stopping at epoch {epoch}")
```

5. Mixed-precision training callback

```python
import math
from tensorflow.keras.callbacks import Callback

class MixedPrecisionCallback(Callback):
    # Watches for NaN/Inf losses, which usually mean the loss scale is
    # too high. In practice, prefer
    # tf.keras.mixed_precision.LossScaleOptimizer, which performs
    # dynamic loss scaling automatically.
    def __init__(self):
        super(MixedPrecisionCallback, self).__init__()
        self.loss_scale = 1.0

    def on_batch_end(self, batch, logs=None):
        # Check whether the loss became NaN or Inf
        loss = (logs or {}).get('loss')
        if loss is not None and (math.isnan(loss) or math.isinf(loss)):
            print(f"NaN/Inf detected at batch {batch}, reducing loss scale")
            self.loss_scale /= 2.0
```

6. Data augmentation callback

```python
from tensorflow.keras.callbacks import Callback

class DataAugmentationCallback(Callback):
    def __init__(self, augmentation_fn):
        super(DataAugmentationCallback, self).__init__()
        self.augmentation_fn = augmentation_fn

    def on_batch_begin(self, batch, logs=None):
        # Apply data augmentation during training.
        # Callbacks do not receive the batch data itself, so a real
        # implementation would augment inside the input pipeline
        # (e.g. tf.data.Dataset.map) rather than here.
        pass
```

7. Model ensemble callback

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class ModelEnsembleCallback(Callback):
    def __init__(self, ensemble_size=5):
        super(ModelEnsembleCallback, self).__init__()
        self.ensemble_size = ensemble_size
        self.models = []

    def on_epoch_end(self, epoch, logs=None):
        # Save a snapshot of the model every 5 epochs
        if epoch % 5 == 0 and len(self.models) < self.ensemble_size:
            model_copy = tf.keras.models.clone_model(self.model)
            model_copy.set_weights(self.model.get_weights())
            self.models.append(model_copy)
            print(f"Saved model snapshot at epoch {epoch}")

    def predict_ensemble(self, x):
        # Average the predictions of all saved snapshots
        predictions = [model.predict(x) for model in self.models]
        return np.mean(predictions, axis=0)
```

Combining Callbacks

```python
# Combine several callbacks
callbacks = [
    # Model checkpointing
    ModelCheckpoint(
        'best_model.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    # Early stopping
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    # Learning rate reduction
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=3,
        verbose=1
    ),
    # TensorBoard
    TensorBoard(log_dir='logs/fit'),
    # Custom callback
    CustomCallback()
]

# Train the model
model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callbacks
)
```

Callback Execution Order

Callback hooks fire in the following order (the epoch hooks repeat for every epoch, and the batch hooks for every batch within it):

  1. on_train_begin
  2. on_epoch_begin
  3. on_batch_begin
  4. on_batch_end
  5. on_epoch_end
  6. on_train_end
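
The dispatch order above can be sketched in plain Python; the `fit` function below is a simplified stand-in for the loop Keras runs inside `model.fit`, not the real implementation:

```python
# Pure-Python sketch of the order in which callback hooks fire.
calls = []

class RecordingCallback:
    # Each hook just records that it was called
    def on_train_begin(self):        calls.append("on_train_begin")
    def on_epoch_begin(self, epoch): calls.append(f"on_epoch_begin({epoch})")
    def on_batch_begin(self, batch): calls.append(f"on_batch_begin({batch})")
    def on_batch_end(self, batch):   calls.append(f"on_batch_end({batch})")
    def on_epoch_end(self, epoch):   calls.append(f"on_epoch_end({epoch})")
    def on_train_end(self):          calls.append("on_train_end")

def fit(cb, epochs=1, batches=2):
    # Simplified stand-in for the Keras training loop
    cb.on_train_begin()
    for epoch in range(epochs):
        cb.on_epoch_begin(epoch)
        for batch in range(batches):
            cb.on_batch_begin(batch)
            cb.on_batch_end(batch)
        cb.on_epoch_end(epoch)
    cb.on_train_end()

fit(RecordingCallback())
print(calls)
```

Running this with one epoch of two batches shows the batch hooks nested inside the epoch hooks, which are in turn nested inside the train hooks.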

Callback Best Practices

1. Choose the right metric to monitor

```python
# Pick a monitoring metric that fits the task
# (`classification` here is a boolean flag you define)
early_stop = EarlyStopping(
    monitor='val_accuracy' if classification else 'val_loss',
    patience=5,
    verbose=1
)
```

2. Save the best model

```python
# Always keep a copy of the best model
checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min'
)
```

3. Use learning rate scheduling

```python
# Combine learning rate scheduling with early stopping
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    EarlyStopping(monitor='val_loss', patience=10)
]
```

4. Monitor the training process

```python
# Monitor training with TensorBoard
tensorboard = TensorBoard(
    log_dir='logs/fit',
    histogram_freq=1,
    write_graph=True
)
```

5. Avoid excessive logging

```python
# Don't log too frequently
class EfficientCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 == 0:  # log once every 5 epochs
            print(f"Epoch {epoch}: Loss = {logs['loss']:.4f}")
```

6. Handle exceptions

```python
class RobustCallback(Callback):
    def on_batch_end(self, batch, logs=None):
        try:
            # Callback logic goes here
            pass
        except Exception as e:
            # Log the error, but don't interrupt training
            print(f"Error in callback: {e}")
```

Callback Use Cases

1. Long-running training

```python
# Use checkpoints plus backup-and-restore
callbacks = [
    ModelCheckpoint('checkpoint.h5', save_freq='epoch'),
    BackupAndRestore(backup_dir='backup')
]
```

2. Hyperparameter tuning

```python
# Use early stopping and learning rate scheduling
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
]
```

3. Experiment tracking

```python
# Use TensorBoard and CSVLogger
callbacks = [
    TensorBoard(log_dir='logs/experiment_1'),
    CSVLogger('experiment_1.csv')
]
```

4. Production deployment

```python
# Save the best model and monitor performance
callbacks = [
    ModelCheckpoint('production_model.h5', save_best_only=True),
    CustomMonitoringCallback()  # a user-defined monitoring callback
]
```

Summary

TensorFlow's callbacks provide powerful control over training:

  • Built-in callbacks: cover the common training-control needs
  • Custom callbacks: implement task-specific training logic
  • Flexible composition: multiple callbacks can be combined freely
  • Execution order: know when each hook fires
  • Best practices: use callbacks judiciously to train more efficiently

Mastering callbacks will help you better control and monitor the model training process.

Tags: Tensorflow