
What Are the Callbacks in TensorFlow and How to Create Custom Callbacks

February 18, 18:03

Callbacks are powerful TensorFlow tools for executing custom operations during training. They let you monitor, control, and modify the training loop at different stages without changing the model itself.
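
To set the scene, here is a minimal sketch of how any callback is attached to training: you pass it to the callbacks argument of model.fit. The names model, x_train, y_train, x_val, and y_val are placeholders for your own model and data.

python
from tensorflow.keras.callbacks import EarlyStopping

# Every callback is wired in the same way: pass an instance (or a list of them)
# to the callbacks argument of fit(); Keras then invokes its hooks during training.
model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=20,
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)]
)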

Built-in Callbacks

1. ModelCheckpoint - Save Model Checkpoints

python
from tensorflow.keras.callbacks import ModelCheckpoint

# Save the best model
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1
)

# Save the model after each epoch
checkpoint_epoch = ModelCheckpoint(
    filepath='model_{epoch:02d}.h5',
    save_freq='epoch',
    verbose=1
)

# Save only the model weights
checkpoint_weights = ModelCheckpoint(
    filepath='weights_{epoch:02d}.h5',
    save_weights_only=True,
    verbose=1
)

2. EarlyStopping - Early Stopping

python
from tensorflow.keras.callbacks import EarlyStopping

# Stop early based on validation loss
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    restore_best_weights=True,
    verbose=1
)

# Stop early based on validation accuracy
early_stop_acc = EarlyStopping(
    monitor='val_accuracy',
    patience=3,
    mode='max',
    verbose=1
)

3. ReduceLROnPlateau - Learning Rate Decay

python
from tensorflow.keras.callbacks import ReduceLROnPlateau

# Reduce the learning rate when validation loss stops improving
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,      # Multiply the learning rate by 0.1
    patience=3,
    mode='min',
    min_lr=1e-7,
    verbose=1
)

# Adjust the learning rate based on validation accuracy
reduce_lr_acc = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.5,
    patience=2,
    mode='max',
    verbose=1
)

4. TensorBoard - TensorBoard Logging

python
from tensorflow.keras.callbacks import TensorBoard
import datetime

# Create a log directory with a timestamp
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='epoch'
)

5. LearningRateScheduler - Learning Rate Scheduling

python
from tensorflow.keras.callbacks import LearningRateScheduler
import math

# Define a learning rate scheduling function
def lr_scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * math.exp(-0.1)

lr_schedule = LearningRateScheduler(lr_scheduler, verbose=1)

# Use a predefined step-decay schedule
def step_decay(epoch):
    initial_lr = 0.001
    drop = 0.5
    epochs_drop = 10.0
    lrate = initial_lr * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lrate

lr_step = LearningRateScheduler(step_decay, verbose=1)

6. CSVLogger - CSV Logging

python
from tensorflow.keras.callbacks import CSVLogger

csv_logger = CSVLogger(
    'training.log',
    separator=',',
    append=False
)

7. ProgbarLogger - Progress Bar Logging

python
from tensorflow.keras.callbacks import ProgbarLogger

progbar = ProgbarLogger(
    count_mode='steps',
    stateful_metrics=['loss', 'accuracy']
)

8. LambdaCallback - Custom Callback

python
from tensorflow.keras.callbacks import LambdaCallback

# Simple custom callback built from lambdas
lambda_callback = LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(f"Epoch {epoch} started"),
    on_epoch_end=lambda epoch, logs: print(f"Epoch {epoch} ended, Loss: {logs['loss']:.4f}"),
    on_batch_begin=lambda batch, logs: None,
    on_batch_end=lambda batch, logs: None,
    on_train_begin=lambda logs: print("Training started"),
    on_train_end=lambda logs: print("Training ended")
)

9. RemoteMonitor - Remote Monitoring

python
from tensorflow.keras.callbacks import RemoteMonitor

remote_monitor = RemoteMonitor(
    root='http://localhost:9000',
    path='/publish/epoch/end/',
    field='data',
    headers=None,
    send_as_json=False
)

10. BackupAndRestore - Backup and Restore

python
from tensorflow.keras.callbacks import BackupAndRestore

backup_restore = BackupAndRestore(
    backup_dir='backup',
    save_freq='epoch',
    delete_checkpoint=True
)

Custom Callbacks

Basic Custom Callback

python
from tensorflow.keras.callbacks import Callback

class CustomCallback(Callback):
    def on_train_begin(self, logs=None):
        print("Training started")

    def on_train_end(self, logs=None):
        print("Training ended")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Epoch {epoch} started")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch} ended")
        print(f"Loss: {logs['loss']:.4f}")
        print(f"Accuracy: {logs['accuracy']:.4f}")

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        if batch % 100 == 0:
            print(f"Batch {batch}, Loss: {logs['loss']:.4f}")

Using a Custom Callback

python
# Create a custom callback instance
custom_callback = CustomCallback()

# Use it during training
model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[custom_callback]
)

Advanced Custom Callback Examples

1. Learning Rate Recorder Callback

python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class LearningRateRecorder(Callback):
    def __init__(self):
        super(LearningRateRecorder, self).__init__()
        self.lr_history = []

    def on_epoch_end(self, epoch, logs=None):
        # Get the current learning rate (resolve schedules to a concrete value)
        lr = self.model.optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(self.model.optimizer.iterations)
        self.lr_history.append(float(lr))
        print(f"Epoch {epoch}: Learning Rate = {self.lr_history[-1]:.6f}")

    def get_lr_history(self):
        return self.lr_history

2. Gradient Monitor Callback

python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class GradientMonitor(Callback):
    def __init__(self, log_dir='logs/gradients'):
        super(GradientMonitor, self).__init__()
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        # Compute gradients on a sample batch (x_train / y_train are assumed to be in scope)
        with tf.GradientTape() as tape:
            predictions = self.model(x_train[:1])
            loss = self.model.compiled_loss(y_train[:1], predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Log gradient norms to TensorBoard
        with self.writer.as_default():
            for i, grad in enumerate(gradients):
                if grad is not None:
                    grad_norm = tf.norm(grad)
                    tf.summary.scalar(f'gradient_norm_{i}', grad_norm, step=epoch)

3. Model Weight Monitor Callback

python
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class WeightMonitor(Callback):
    def __init__(self, log_dir='logs/weights'):
        super(WeightMonitor, self).__init__()
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        # Log the mean and standard deviation of every weight tensor
        with self.writer.as_default():
            for i, layer in enumerate(self.model.layers):
                if hasattr(layer, 'get_weights'):
                    weights = layer.get_weights()
                    for j, w in enumerate(weights):
                        w_mean = tf.reduce_mean(w)
                        w_std = tf.math.reduce_std(w)
                        tf.summary.scalar(f'layer_{i}_weight_{j}_mean', w_mean, step=epoch)
                        tf.summary.scalar(f'layer_{i}_weight_{j}_std', w_std, step=epoch)

4. Custom Early Stopping Callback

python
from tensorflow.keras.callbacks import Callback

class CustomEarlyStopping(Callback):
    def __init__(self, monitor='val_loss', patience=5, min_delta=0):
        super(CustomEarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best = None
        self.stopped_epoch = 0

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.best = float('inf') if 'loss' in self.monitor else -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if current is None:
            return

        # Loss-like metrics should decrease; everything else should increase
        if 'loss' in self.monitor:
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta

        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print(f"Early stopping at epoch {epoch}")

5. Mixed Precision Training Callback

python
import math
from tensorflow.keras.callbacks import Callback

class MixedPrecisionCallback(Callback):
    def __init__(self):
        super(MixedPrecisionCallback, self).__init__()
        self.loss_scale = 1.0

    def on_batch_end(self, batch, logs=None):
        # Check whether the loss has become NaN or Inf
        if logs is not None and 'loss' in logs:
            loss = float(logs['loss'])
            if math.isnan(loss) or math.isinf(loss):
                print(f"NaN/Inf detected at batch {batch}, reducing loss scale")
                self.loss_scale /= 2.0
                # Note: in practice, tf.keras.mixed_precision.LossScaleOptimizer
                # performs dynamic loss scaling automatically; this callback only
                # detects the overflow and tracks the scale it would apply.

6. Data Augmentation Callback

python
from tensorflow.keras.callbacks import Callback

class DataAugmentationCallback(Callback):
    def __init__(self, augmentation_fn):
        super(DataAugmentationCallback, self).__init__()
        self.augmentation_fn = augmentation_fn

    def on_batch_begin(self, batch, logs=None):
        # Apply data augmentation during training.
        # Note: callbacks do not receive the batch data itself, so a real
        # implementation would apply self.augmentation_fn inside the tf.data
        # input pipeline or a custom train_step rather than here.
        pass

7. Model Ensemble Callback

python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class ModelEnsembleCallback(Callback):
    def __init__(self, ensemble_size=5):
        super(ModelEnsembleCallback, self).__init__()
        self.ensemble_size = ensemble_size
        self.models = []

    def on_epoch_end(self, epoch, logs=None):
        # Save a model snapshot every 5 epochs
        if epoch % 5 == 0 and len(self.models) < self.ensemble_size:
            model_copy = tf.keras.models.clone_model(self.model)
            model_copy.set_weights(self.model.get_weights())
            self.models.append(model_copy)
            print(f"Saved model snapshot at epoch {epoch}")

    def predict_ensemble(self, x):
        # Ensemble prediction: average the predictions of all saved snapshots
        predictions = [model.predict(x) for model in self.models]
        return np.mean(predictions, axis=0)

Combining Callbacks

python
# Combine multiple callbacks
callbacks = [
    # Model checkpoint
    ModelCheckpoint(
        'best_model.h5',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    # Early stopping
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    # Learning rate decay
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=3,
        verbose=1
    ),
    # TensorBoard
    TensorBoard(log_dir='logs/fit'),
    # Custom callback
    CustomCallback()
]

# Train the model
model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callbacks
)

Callback Execution Order

During training, the callback hooks fire in the following order (the epoch hooks repeat for every epoch and the batch hooks for every batch within it); a small sketch that prints each hook follows the list:

  1. on_train_begin
  2. on_epoch_begin
  3. on_batch_begin
  4. on_batch_end
  5. on_epoch_end
  6. on_train_end
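
As an illustration (not from the original article), this self-contained sketch trains a tiny throwaway model on random data and prints each hook as it fires, so the order above can be observed directly.

python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class HookOrderLogger(Callback):
    # Print the name of each hook as Keras invokes it
    def on_train_begin(self, logs=None): print("on_train_begin")
    def on_epoch_begin(self, epoch, logs=None): print(f"on_epoch_begin (epoch {epoch})")
    def on_batch_begin(self, batch, logs=None): print(f"  on_batch_begin (batch {batch})")
    def on_batch_end(self, batch, logs=None): print(f"  on_batch_end (batch {batch})")
    def on_epoch_end(self, epoch, logs=None): print(f"on_epoch_end (epoch {epoch})")
    def on_train_end(self, logs=None): print("on_train_end")

# Tiny model and random data, just to drive the hooks
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')
x = np.random.rand(8, 4).astype('float32')
y = np.random.rand(8, 1).astype('float32')
model.fit(x, y, epochs=2, batch_size=4, verbose=0, callbacks=[HookOrderLogger()])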

Callback Best Practices

1. Choose Monitoring Metrics Appropriately

python
# Choose the monitoring metric based on the task
# ('classification' is a boolean flag you define for your own setup)
early_stop = EarlyStopping(
    monitor='val_accuracy' if classification else 'val_loss',
    patience=5,
    verbose=1
)

2. Save the Best Model

python
# Always save the best model
checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min'
)

3. Use Learning Rate Scheduling

python
# Combine learning rate scheduling with early stopping
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    EarlyStopping(monitor='val_loss', patience=10)
]

4. Monitor the Training Process

python
# Use TensorBoard to monitor training
tensorboard = TensorBoard(
    log_dir='logs/fit',
    histogram_freq=1,
    write_graph=True
)

5. Avoid Excessive Logging

python
# Don't log information too frequently
class EfficientCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 == 0:  # Log every 5 epochs
            print(f"Epoch {epoch}: Loss = {logs['loss']:.4f}")

6. Handle Exceptions

python
class RobustCallback(Callback):
    def on_batch_end(self, batch, logs=None):
        try:
            # Processing logic goes here
            pass
        except Exception as e:
            print(f"Error in callback: {e}")
            # Don't interrupt training

Callback Application Scenarios

1. Long Training Runs

python
# Use checkpoints together with backup/restore
callbacks = [
    ModelCheckpoint('checkpoint.h5', save_freq='epoch'),
    BackupAndRestore(backup_dir='backup')
]

2. Hyperparameter Tuning

python
# Use early stopping and learning rate scheduling
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
]

3. Experiment Tracking

python
# Use TensorBoard and CSVLogger
callbacks = [
    TensorBoard(log_dir='logs/experiment_1'),
    CSVLogger('experiment_1.csv')
]

4. Production Deployment

python
# Save the best model and monitor performance
callbacks = [
    ModelCheckpoint('production_model.h5', save_best_only=True),
    CustomMonitoringCallback()  # a user-defined monitoring callback (sketched below)
]
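
CustomMonitoringCallback above is not a built-in Keras class. A minimal sketch of what such a callback could look like follows; the threshold value and the print-based "alert" are purely illustrative assumptions, and a production version would push metrics to a real monitoring system instead.

python
from tensorflow.keras.callbacks import Callback

class CustomMonitoringCallback(Callback):
    # Hypothetical monitoring callback: flag epochs whose validation loss
    # exceeds a chosen threshold.
    def __init__(self, alert_threshold=1.0):  # illustrative threshold
        super().__init__()
        self.alert_threshold = alert_threshold

    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get('val_loss')
        if val_loss is not None and val_loss > self.alert_threshold:
            print(f"[ALERT] Epoch {epoch}: val_loss={val_loss:.4f} exceeds threshold")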

Summary

TensorFlow's callbacks provide powerful training control capabilities:

  • Built-in callbacks: cover common training-control needs such as checkpointing, early stopping, and learning rate scheduling
  • Custom callbacks: implement task-specific training logic
  • Flexible combination: multiple callbacks can be used together in a single fit() call
  • Execution order: know when each hook fires during training
  • Best practices: use callbacks judiciously to improve training efficiency

Mastering callbacks will help you better control and monitor the model training process.

Tags: TensorFlow