6月5日 18:34

What Are Callbacks in TensorFlow, and How Do You Create Custom Callbacks?

Callbacks are powerful tools in TensorFlow for executing custom operations during the training process. They allow you to monitor, control, and modify the training process at different stages.

Built-in Callbacks

1. ModelCheckpoint - Save Model Checkpoints

python
from tensorflow.keras.callbacks import ModelCheckpoint

# Save only the best model, judged by the lowest validation loss.
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1
)

# Save the model at every epoch; {epoch:02d} is replaced by the epoch number
# so each epoch is written to its own file.
checkpoint_epoch = ModelCheckpoint(
    filepath='model_{epoch:02d}.h5',
    save_freq='epoch',
    verbose=1
)

# Save only the model weights. Keras 3 requires weights-only checkpoints to
# end in '.weights.h5'; the name still ends in '.h5', so older TensorFlow 2.x
# releases keep writing HDF5 as before.
checkpoint_weights = ModelCheckpoint(
    filepath='weights_{epoch:02d}.weights.h5',
    save_weights_only=True,
    verbose=1
)

2. EarlyStopping - Early Stopping

python
from tensorflow.keras.callbacks import EarlyStopping # Early stopping based on validation loss early_stop = EarlyStopping( monitor='val_loss', patience=5, mode='min', restore_best_weights=True, verbose=1 ) # Early stopping based on validation accuracy early_stop_acc = EarlyStopping( monitor='val_accuracy', patience=3, mode='max', verbose=1 )

3. ReduceLROnPlateau - Reduce the Learning Rate When a Metric Plateaus

python
from tensorflow.keras.callbacks import ReduceLROnPlateau # Reduce learning rate when validation loss stops improving reduce_lr = ReduceLROnPlateau( monitor='val_loss', factor=0.1, # Multiply learning rate by 0.1 patience=3, mode='min', min_lr=1e-7, verbose=1 ) # Adjust learning rate based on accuracy reduce_lr_acc = ReduceLROnPlateau( monitor='val_accuracy', factor=0.5, patience=2, mode='max', verbose=1 )

4. TensorBoard - TensorBoard Logging

python
import datetime

from tensorflow.keras.callbacks import TensorBoard

# Give every run its own log directory, named after the start time, so
# successive runs do not write into the same folder.
run_stamp = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
log_dir = "logs/fit/" + run_stamp

tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='epoch'
)

5. LearningRateScheduler - Learning Rate Scheduling

python
import math

from tensorflow.keras.callbacks import LearningRateScheduler


def lr_scheduler(epoch, lr):
    """Keep `lr` unchanged for epochs 0-9, then scale it by exp(-0.1) each epoch.

    Args:
        epoch: Zero-based epoch index supplied by the scheduler callback.
        lr: Learning rate currently in effect.

    Returns:
        The learning rate to use for this epoch.
    """
    return lr if epoch < 10 else lr * math.exp(-0.1)


lr_schedule = LearningRateScheduler(lr_scheduler, verbose=1)


def step_decay(epoch):
    """Return a step-wise decayed rate: 0.001 halved once every 10 epochs.

    Args:
        epoch: Zero-based epoch index.

    Returns:
        0.001 * 0.5 ** floor((1 + epoch) / 10).
    """
    initial_lr, drop, epochs_drop = 0.001, 0.5, 10.0
    completed_drops = math.floor((1 + epoch) / epochs_drop)
    return initial_lr * math.pow(drop, completed_drops)


lr_step = LearningRateScheduler(step_decay, verbose=1)

6. CSVLogger - CSV Logging

python
from tensorflow.keras.callbacks import CSVLogger

# Write results to 'training.log' as comma-separated values.
# append=False is passed here; set it to True to keep an existing log
# (presumably appending to it — confirm against the CSVLogger docs).
csv_logger = CSVLogger(
    'training.log',
    separator=',',
    append=False
)

7. ProgbarLogger - Progress Bar Logging

python
from tensorflow.keras.callbacks import ProgbarLogger

# Progress-bar logger configured to count in steps.
# NOTE(review): the accepted constructor arguments differ between Keras
# versions — confirm count_mode/stateful_metrics against the installed release.
progbar = ProgbarLogger(
    count_mode='steps',
    stateful_metrics=['loss', 'accuracy']
)

8. LambdaCallback - Custom Callback

python
from tensorflow.keras.callbacks import LambdaCallback # Simple custom callback lambda_callback = LambdaCallback( on_epoch_begin=lambda epoch, logs: print(f"Epoch {epoch} started"), on_epoch_end=lambda epoch, logs: print(f"Epoch {epoch} ended, Loss: {logs['loss']:.4f}"), on_batch_begin=lambda batch, logs: None, on_batch_end=lambda batch, logs: None, on_train_begin=lambda logs: print("Training started"), on_train_end=lambda logs: print("Training ended") )

9. RemoteMonitor - Remote Monitoring

python
from tensorflow.keras.callbacks import RemoteMonitor

# Send training events to a server at root + path.
# send_as_json=False is passed, so the payload is presumably sent as form
# data under the key given by `field` — confirm against the RemoteMonitor docs.
remote_monitor = RemoteMonitor(
    root='http://localhost:9000',
    path='/publish/epoch/end/',
    field='data',
    headers=None,
    send_as_json=False
)

10. BackupAndRestore - Backup and Restore

python
from tensorflow.keras.callbacks import BackupAndRestore

# Keep a training backup in the 'backup' directory, written once per epoch.
# delete_checkpoint=True is passed; presumably the backup is removed after
# training finishes — confirm against the BackupAndRestore docs.
backup_restore = BackupAndRestore(
    backup_dir='backup',
    save_freq='epoch',
    delete_checkpoint=True
)

Custom Callbacks

Basic Custom Callback

python
from tensorflow.keras.callbacks import Callback


class CustomCallback(Callback):
    """Print a progress message at each stage of training.

    Metric values are read defensively: `logs` may be None, and a metric
    such as 'accuracy' is only present if the model was compiled with it.
    A missing value is printed as "n/a" instead of raising.
    """

    @staticmethod
    def _fmt(logs, key):
        """Return logs[key] formatted to 4 decimals, or "n/a" if unavailable."""
        value = (logs or {}).get(key)
        return "n/a" if value is None else f"{value:.4f}"

    def on_train_begin(self, logs=None):
        print("Training started")

    def on_train_end(self, logs=None):
        print("Training ended")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Epoch {epoch} started")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch} ended")
        print(f"Loss: {self._fmt(logs, 'loss')}")
        print(f"Accuracy: {self._fmt(logs, 'accuracy')}")

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        # Report only every 100th batch to keep the output readable.
        if batch % 100 == 0:
            print(f"Batch {batch}, Loss: {self._fmt(logs, 'loss')}")

Using Custom Callback

python
# Create custom callback instance custom_callback = CustomCallback() # Use during training model.fit( x_train, y_train, epochs=10, validation_data=(x_val, y_val), callbacks=[custom_callback] )

Advanced Custom Callback Examples

1. Learning Rate Recorder Callback

python
import tensorflow as tf


class LearningRateRecorder(Callback):
    """Record the optimizer's learning rate at the end of every epoch."""

    def __init__(self):
        super().__init__()
        # One float per completed epoch, in epoch order.
        self.lr_history = []

    def on_epoch_end(self, epoch, logs=None):
        # Get current learning rate
        lr = self.model.optimizer.learning_rate
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            # A schedule is a callable of the step count, not a value.
            lr = lr(self.model.optimizer.iterations)
        # Convert once to a Python float: tensors/variables do not accept a
        # ':.6f' format spec, so formatting the raw object would raise.
        lr_value = float(lr)
        self.lr_history.append(lr_value)
        print(f"Epoch {epoch}: Learning Rate = {lr_value:.6f}")

    def get_lr_history(self):
        """Return the list of recorded learning rates."""
        return self.lr_history

2. Gradient Monitor Callback

python
import tensorflow as tf


class GradientMonitor(Callback):
    """Log the norm of each trainable variable's gradient after every epoch.

    Args:
        log_dir: Directory the summary writer logs to.
        x_sample: Optional input batch used to compute the gradients.
        y_sample: Optional targets matching `x_sample`.

    When no sample is supplied, the callback falls back to the first example
    of the script-level `x_train` / `y_train`, as the original version did.
    """

    def __init__(self, log_dir='logs/gradients', x_sample=None, y_sample=None):
        super().__init__()
        self.log_dir = log_dir
        self.x_sample = x_sample
        self.y_sample = y_sample
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        x = self.x_sample if self.x_sample is not None else x_train[:1]
        y = self.y_sample if self.y_sample is not None else y_train[:1]

        # Calculate gradients
        with tf.GradientTape() as tape:
            predictions = self.model(x)
            # NOTE(review): compiled_loss is the TF 2.x attribute; newer Keras
            # versions expose compute_loss instead — confirm for your version.
            loss = self.model.compiled_loss(y, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)

        # Log gradient norms
        with self.writer.as_default():
            for i, grad in enumerate(gradients):
                # Variables that did not contribute to the loss have no gradient.
                if grad is not None:
                    tf.summary.scalar(f'gradient_norm_{i}', tf.norm(grad), step=epoch)
        # Flush so the values are on disk even if training is interrupted.
        self.writer.flush()

3. Model Weight Monitor Callback

python
import tensorflow as tf


class WeightMonitor(Callback):
    """Log the mean and standard deviation of every layer weight after each epoch.

    Args:
        log_dir: Directory the summary writer logs to.
    """

    def __init__(self, log_dir='logs/weights'):
        super().__init__()
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        with self.writer.as_default():
            for i, layer in enumerate(self.model.layers):
                # Layers without weights return an empty list, so the inner
                # loop simply does not run for them.
                for j, w in enumerate(layer.get_weights()):
                    tf.summary.scalar(
                        f'layer_{i}_weight_{j}_mean', tf.reduce_mean(w), step=epoch)
                    tf.summary.scalar(
                        f'layer_{i}_weight_{j}_std', tf.math.reduce_std(w), step=epoch)
        # Flush so the values are on disk even if training is interrupted.
        self.writer.flush()

4. Custom Early Stopping Callback

python
class CustomEarlyStopping(Callback):
    """Stop training when the monitored metric stops improving.

    Metrics whose name contains 'loss' are minimised; all others are
    maximised. The same rule is used both to initialise `best` and to judge
    improvement, so `monitor='loss'` behaves like `monitor='val_loss'`.

    Args:
        monitor: Key in `logs` to watch.
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum change that counts as an improvement.
    """

    def __init__(self, monitor='val_loss', patience=5, min_delta=0):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best = None
        self.stopped_epoch = 0

    def _minimize(self):
        """Return True when a lower value of the monitored metric is better."""
        return 'loss' in self.monitor

    def _is_improvement(self, current):
        """Return True when `current` beats `self.best` by more than min_delta."""
        if self._minimize():
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def on_train_begin(self, logs=None):
        # Reset state so the same instance can be reused across fit() calls.
        self.wait = 0
        self.best = float('inf') if self._minimize() else -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Metric not reported this epoch (e.g. no validation data).
            return

        if self._is_improvement(current):
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            print(f"Early stopping at epoch {epoch}")

5. Mixed Precision Training Callback

python
class MixedPrecisionCallback(Callback):
    """Halve a tracked loss scale when a batch reports a NaN or Inf loss."""

    def __init__(self):
        super(MixedPrecisionCallback, self).__init__()
        # Tracked on this object only; nothing in this class applies it to the
        # model or optimizer.
        self.loss_scale = 1.0

    def on_batch_end(self, batch, logs=None):
        # Check if loss is NaN or Inf
        if logs is not None and 'loss' in logs:
            if tf.math.is_nan(logs['loss']) or tf.math.is_inf(logs['loss']):
                print(f"NaN/Inf detected at batch {batch}, reducing loss scale")
                self.loss_scale /= 2.0
                # Reset optimizer state
                # NOTE(review): this halves every optimizer weight rather than
                # resetting it, and optimizer get_weights/set_weights may not
                # exist in newer Keras versions — confirm the intended behavior.
                # NOTE(review): `tf` is not imported in this snippet.
                self.model.optimizer.set_weights([
                    w / 2.0 if w is not None else None
                    for w in self.model.optimizer.get_weights()
                ])

6. Data Augmentation Callback

python
class DataAugmentationCallback(Callback):
    """Skeleton callback that stores an augmentation function.

    The batch hook is a placeholder: it does not call `augmentation_fn` yet.
    """

    def __init__(self, augmentation_fn):
        super(DataAugmentationCallback, self).__init__()
        self.augmentation_fn = augmentation_fn

    def on_batch_begin(self, batch, logs=None):
        # Apply data augmentation during training
        if self.model.trainable:
            # Here you can access current batch data
            # Real application requires more complex implementation
            pass

7. Model Ensemble Callback

python
import numpy as np
import tensorflow as tf


class ModelEnsembleCallback(Callback):
    """Collect model snapshots during training and average their predictions.

    Args:
        ensemble_size: Maximum number of snapshots to keep.
        snapshot_interval: Take a snapshot on every epoch whose index is a
            multiple of this value (default 5, matching the original behavior).
    """

    def __init__(self, ensemble_size=5, snapshot_interval=5):
        super().__init__()
        self.ensemble_size = ensemble_size
        self.snapshot_interval = snapshot_interval
        self.models = []

    def on_epoch_end(self, epoch, logs=None):
        # Save model snapshot
        if epoch % self.snapshot_interval == 0 and len(self.models) < self.ensemble_size:
            # clone_model copies only the architecture, so the current
            # weights are copied over explicitly.
            model_copy = tf.keras.models.clone_model(self.model)
            model_copy.set_weights(self.model.get_weights())
            self.models.append(model_copy)
            print(f"Saved model snapshot at epoch {epoch}")

    def predict_ensemble(self, x):
        """Return the element-wise mean of all snapshots' predictions for `x`.

        Raises:
            ValueError: If no snapshot has been saved yet.
        """
        if not self.models:
            # Averaging an empty list would silently yield NaN.
            raise ValueError("No model snapshots have been saved yet")
        predictions = [model.predict(x) for model in self.models]
        return np.mean(predictions, axis=0)

Combining Callbacks

python
# Combine multiple callbacks callbacks = [ # Model checkpoint ModelCheckpoint( 'best_model.h5', monitor='val_loss', save_best_only=True, verbose=1 ), # Early stopping EarlyStopping( monitor='val_loss', patience=5, restore_best_weights=True, verbose=1 ), # Learning rate decay ReduceLROnPlateau( monitor='val_loss', factor=0.1, patience=3, verbose=1 ), # TensorBoard TensorBoard(log_dir='logs/fit'), # Custom callback CustomCallback() ] # Train model model.fit( x_train, y_train, epochs=100, validation_data=(x_val, y_val), callbacks=callbacks )

Callback Execution Order

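The callback hooks fire in the following order during training. Steps 2–5 repeat once per epoch, and steps 3–4 repeat once per batch within each epoch: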

  1. on_train_begin
  2. on_epoch_begin
  3. on_batch_begin
  4. on_batch_end
  5. on_epoch_end
  6. on_train_end

Callback Best Practices

1. Choose Appropriate Monitoring Metrics

python
# Choose appropriate monitoring metrics based on task early_stop = EarlyStopping( monitor='val_accuracy' if classification else 'val_loss', patience=5, verbose=1 )

2. Save Best Model

python
# Always save the best model checkpoint = ModelCheckpoint( 'best_model.h5', monitor='val_loss', save_best_only=True, mode='min' )

3. Use Learning Rate Scheduling

python
# Combine learning rate scheduling and early stopping callbacks = [ ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3), EarlyStopping(monitor='val_loss', patience=10) ]

4. Monitor Training Process

python
# Use TensorBoard to monitor training tensorboard = TensorBoard( log_dir='logs/fit', histogram_freq=1, write_graph=True )

5. Avoid Excessive Logging

python
# Don't log information too frequently
class EfficientCallback(Callback):
    """Print the training loss on every fifth epoch only."""

    def on_epoch_end(self, epoch, logs=None):
        # Skip every epoch that is not a multiple of 5.
        if epoch % 5 != 0:
            return
        loss = logs['loss']
        print(f"Epoch {epoch}: Loss = {loss:.4f}")

6. Handle Exceptions

python
class RobustCallback(Callback):
    """Callback whose per-batch hook reports its own errors instead of raising."""

    def on_batch_end(self, batch, logs=None):
        try:
            # Processing logic
            pass
        except Exception as e:
            # Deliberately broad: a failure here is reported and swallowed.
            print(f"Error in callback: {e}")
            # Don't interrupt training

Callback Application Scenarios

1. Long Training

python
# Use checkpoints and backup restore callbacks = [ ModelCheckpoint('checkpoint.h5', save_freq='epoch'), BackupAndRestore(backup_dir='backup') ]

2. Hyperparameter Tuning

python
# Use early stopping and learning rate scheduling callbacks = [ EarlyStopping(monitor='val_loss', patience=5), ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2) ]

3. Experiment Tracking

python
# Use TensorBoard and CSVLogger callbacks = [ TensorBoard(log_dir='logs/experiment_1'), CSVLogger('experiment_1.csv') ]

4. Production Deployment

python
# Save best model and monitor performance callbacks = [ ModelCheckpoint('production_model.h5', save_best_only=True), CustomMonitoringCallback() ]

Summary

TensorFlow's callbacks provide powerful training control capabilities:

  • Built-in callbacks: Provide common training control functions
  • Custom callbacks: Implement specific training logic
  • Flexible combination: Can combine multiple callbacks
  • Execution order: Understand when callbacks execute
  • Best practices: Use callbacks judiciously to improve training efficiency

Mastering callbacks will help you better control and monitor the model training process.

Tags: TensorFlow