6月5日 18:34
What Are Callbacks in TensorFlow and How to Create Custom Callbacks
Callbacks are powerful tools in TensorFlow for executing custom operations during the training process. They allow you to monitor, control, and modify the training process at different stages.
Built-in Callbacks
1. ModelCheckpoint - Save Model Checkpoints
from tensorflow.keras.callbacks import ModelCheckpoint

# Keep a single file holding the model with the lowest validation loss so far.
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Write a separate model file every epoch; {epoch:02d} is filled in with the
# epoch number.
checkpoint_epoch = ModelCheckpoint(
    filepath='model_{epoch:02d}.h5',
    save_freq='epoch',
    verbose=1,
)

# Save only the model's weights. Keras 3 rejects weights-only checkpoints
# whose name does not end in ".weights.h5"; older releases accept this name
# as well because it still ends in ".h5".
checkpoint_weights = ModelCheckpoint(
    filepath='weights_{epoch:02d}.weights.h5',
    save_weights_only=True,
    verbose=1,
)
2. EarlyStopping - Early Stopping
from tensorflow.keras.callbacks import EarlyStopping

# Watch validation loss (lower is better) with a patience of 5 epochs, and
# restore the best weights seen when training is stopped.
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

# Watch validation accuracy (higher is better) with a patience of 3 epochs.
early_stop_acc = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    patience=3,
    verbose=1,
)
3. ReduceLROnPlateau - Learning Rate Decay
from tensorflow.keras.callbacks import ReduceLROnPlateau

# Shrink the learning rate once validation loss plateaus.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    patience=3,
    factor=0.1,   # new_lr = lr * 0.1
    min_lr=1e-7,  # lower bound for the learning rate
    verbose=1,
)

# Same idea driven by validation accuracy, with a gentler reduction.
reduce_lr_acc = ReduceLROnPlateau(
    monitor='val_accuracy',
    mode='max',
    patience=2,
    factor=0.5,
    verbose=1,
)
4. TensorBoard - TensorBoard Logging
from tensorflow.keras.callbacks import TensorBoard
import datetime

# Timestamped log directory, e.g. "logs/fit/20240605-183400"; the literal
# prefix passes through strftime unchanged because it contains no directives.
log_dir = datetime.datetime.now().strftime("logs/fit/%Y%m%d-%H%M%S")

tensorboard = TensorBoard(
    log_dir=log_dir,
    update_freq='epoch',
    histogram_freq=1,
    write_graph=True,
    write_images=True,
)
5. LearningRateScheduler - Learning Rate Scheduling
from tensorflow.keras.callbacks import LearningRateScheduler
import math


def lr_scheduler(epoch, lr):
    """Return `lr` unchanged for epochs 0-9, then `lr * exp(-0.1)`."""
    in_warm_phase = epoch < 10
    return lr if in_warm_phase else lr * math.exp(-0.1)


lr_schedule = LearningRateScheduler(lr_scheduler, verbose=1)


def step_decay(epoch):
    """Return a step-decayed rate: 0.001 halved once per 10-epoch step."""
    initial_lr = 0.001
    drop = 0.5
    epochs_drop = 10.0
    # Number of decay steps applied so far (epoch is zero-based, hence 1 + epoch).
    steps_taken = math.floor((1 + epoch) / epochs_drop)
    return initial_lr * math.pow(drop, steps_taken)


lr_step = LearningRateScheduler(step_decay, verbose=1)
6. CSVLogger - CSV Logging
from tensorflow.keras.callbacks import CSVLogger

# Log to 'training.log' as comma-separated values; append=False means the
# callback is not asked to append to an existing file.
csv_logger = CSVLogger('training.log', append=False, separator=',')
7. ProgbarLogger - Progress Bar Logging
from tensorflow.keras.callbacks import ProgbarLogger

# NOTE(review): these constructor arguments look version-dependent -- newer
# Keras releases may not accept them; confirm against the installed version.
progbar = ProgbarLogger(
    stateful_metrics=['loss', 'accuracy'],
    count_mode='steps',
)
8. LambdaCallback - Custom Callback
from tensorflow.keras.callbacks import LambdaCallback

# Build a callback from inline functions, one per training hook.
lambda_callback = LambdaCallback(
    # Whole-run hooks
    on_train_begin=lambda logs: print("Training started"),
    on_train_end=lambda logs: print("Training ended"),
    # Per-epoch hooks
    on_epoch_begin=lambda epoch, logs: print(f"Epoch {epoch} started"),
    on_epoch_end=lambda epoch, logs: print(
        f"Epoch {epoch} ended, Loss: {logs['loss']:.4f}"
    ),
    # Per-batch hooks (deliberate no-ops)
    on_batch_begin=lambda batch, logs: None,
    on_batch_end=lambda batch, logs: None,
)
9. RemoteMonitor - Remote Monitoring
from tensorflow.keras.callbacks import RemoteMonitor

# Target endpoint is root + path, i.e. http://localhost:9000/publish/epoch/end/
remote_monitor = RemoteMonitor(
    root='http://localhost:9000',
    path='/publish/epoch/end/',
    headers=None,
    field='data',
    send_as_json=False,
)
10. BackupAndRestore - Backup and Restore
from tensorflow.keras.callbacks import BackupAndRestore

# Back up training state under 'backup' once per epoch.
backup_restore = BackupAndRestore(
    backup_dir='backup',
    delete_checkpoint=True,
    save_freq='epoch',
)
Custom Callbacks
Basic Custom Callback
from tensorflow.keras.callbacks import Callback


class CustomCallback(Callback):
    """Print progress messages at the main stages of training.

    Metrics are read from `logs` defensively: a metric that is absent (for
    example 'accuracy' when the model was compiled without it), or a `logs`
    argument of None, is skipped instead of raising.
    """

    def on_train_begin(self, logs=None):
        print("Training started")

    def on_train_end(self, logs=None):
        print("Training ended")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Epoch {epoch} started")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"Epoch {epoch} ended")
        for label, key in (("Loss", "loss"), ("Accuracy", "accuracy")):
            value = logs.get(key)
            if value is not None:
                print(f"{label}: {value:.4f}")

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        # Only report every 100th batch to keep the output readable.
        if batch % 100 != 0:
            return
        loss = (logs or {}).get('loss')
        if loss is not None:
            print(f"Batch {batch}, Loss: {loss:.4f}")
Using Custom Callback
# Build the callback once, then hand it to fit() through `callbacks`.
custom_callback = CustomCallback()

model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    callbacks=[custom_callback],
)
Advanced Custom Callback Examples
1. Learning Rate Recorder Callback
import tensorflow as tf


class LearningRateRecorder(Callback):
    """Record the optimizer's learning rate at the end of every epoch."""

    def __init__(self):
        super().__init__()
        self.lr_history = []

    def on_epoch_end(self, epoch, logs=None):
        lr = self.model.optimizer.learning_rate
        # A schedule object must be evaluated at the current step to get a value.
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(self.model.optimizer.iterations)
        # Convert once and reuse: formatting the raw variable/tensor with
        # ":.6f" is not guaranteed to be supported, a plain float always is.
        lr_value = float(lr)
        self.lr_history.append(lr_value)
        print(f"Epoch {epoch}: Learning Rate = {lr_value:.6f}")

    def get_lr_history(self):
        """Return the list of recorded learning rates, one per finished epoch."""
        return self.lr_history
2. Gradient Monitor Callback
class GradientMonitor(Callback):
    """Write per-variable gradient norms as summary scalars after each epoch.

    NOTE: `on_epoch_end` reads `x_train` and `y_train`, which this class does
    not define; they must exist in the enclosing module scope.
    """

    def __init__(self, log_dir='logs/gradients'):
        super().__init__()
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        model = self.model

        # Forward pass on the first training sample only, recorded on a tape.
        with tf.GradientTape() as tape:
            predictions = model(x_train[:1])
            loss = model.compiled_loss(y_train[:1], predictions)
        gradients = tape.gradient(loss, model.trainable_variables)

        with self.writer.as_default():
            for index, grad in enumerate(gradients):
                if grad is None:
                    continue
                grad_norm = tf.norm(grad)
                tf.summary.scalar(f'gradient_norm_{index}', grad_norm, step=epoch)
3. Model Weight Monitor Callback
class WeightMonitor(Callback):
    """Write the mean and standard deviation of every layer weight each epoch."""

    def __init__(self, log_dir='logs/weights'):
        super().__init__()
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        with self.writer.as_default():
            for layer_idx, layer in enumerate(self.model.layers):
                if not hasattr(layer, 'get_weights'):
                    continue
                for weight_idx, values in enumerate(layer.get_weights()):
                    mean = tf.reduce_mean(values)
                    std = tf.math.reduce_std(values)
                    tag = f'layer_{layer_idx}_weight_{weight_idx}'
                    tf.summary.scalar(f'{tag}_mean', mean, step=epoch)
                    tf.summary.scalar(f'{tag}_std', std, step=epoch)
4. Custom Early Stopping Callback
class CustomEarlyStopping(Callback):
    """Stop training when the monitored metric stops improving.

    The metric is minimised when its name contains 'loss' and maximised
    otherwise. The same rule is used both to initialise `best` and to compare
    values, so a monitor such as 'loss' is handled consistently.

    Args:
        monitor: Key to read from the epoch `logs` dict.
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum change that counts as an improvement.
    """

    def __init__(self, monitor='val_loss', patience=5, min_delta=0):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.best = None
        self.stopped_epoch = 0

    def _is_minimizing(self):
        """Return True when lower values of the monitored metric are better."""
        return 'loss' in self.monitor

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.best = float('inf') if self._is_minimizing() else -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Metric not reported this epoch; nothing to compare.
            return

        if self._is_minimizing():
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta

        if improved:
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            print(f"Early stopping at epoch {epoch}")
5. Mixed Precision Training Callback
import math


class MixedPrecisionCallback(Callback):
    """Detect a NaN/Inf batch loss and track a halved loss-scale value.

    `loss_scale` is bookkeeping held on this callback only. The optimizer's
    variables are deliberately left untouched: halving them would also halve
    the iteration counter and any moment estimates, which corrupts optimizer
    state without changing how the loss is scaled.
    NOTE(review): if real dynamic loss scaling is needed, it is normally
    provided by the mixed-precision loss-scale optimizer -- confirm for the
    TensorFlow version in use.
    """

    def __init__(self):
        super().__init__()
        self.loss_scale = 1.0

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get('loss')
        if loss is None:
            return
        # float() accepts Python numbers and scalar arrays/tensors alike.
        if not math.isfinite(float(loss)):
            print(f"NaN/Inf detected at batch {batch}, reducing loss scale")
            self.loss_scale /= 2.0
6. Data Augmentation Callback
class DataAugmentationCallback(Callback):
    """Skeleton showing where per-batch augmentation logic would be hooked in.

    `augmentation_fn` is stored on the instance but not called here.
    """

    def __init__(self, augmentation_fn):
        super().__init__()
        self.augmentation_fn = augmentation_fn

    def on_batch_begin(self, batch, logs=None):
        # Placeholder: a real implementation would access the current batch
        # here and needs considerably more machinery than this sketch.
        if self.model.trainable:
            pass
7. Model Ensemble Callback
import numpy as np
import tensorflow as tf


class ModelEnsembleCallback(Callback):
    """Collect model snapshots during training and average their predictions.

    Args:
        ensemble_size: Maximum number of snapshots to keep.
        snapshot_every: Take a snapshot on epochs divisible by this value.
    """

    def __init__(self, ensemble_size=5, snapshot_every=5):
        super().__init__()
        if snapshot_every < 1:
            raise ValueError("snapshot_every must be a positive integer")
        self.ensemble_size = ensemble_size
        self.snapshot_every = snapshot_every
        self.models = []

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.snapshot_every != 0:
            return
        if len(self.models) >= self.ensemble_size:
            return
        # clone_model copies the architecture only, so the current weights
        # are copied over explicitly.
        snapshot = tf.keras.models.clone_model(self.model)
        snapshot.set_weights(self.model.get_weights())
        self.models.append(snapshot)
        print(f"Saved model snapshot at epoch {epoch}")

    def predict_ensemble(self, x):
        """Return the element-wise mean of all snapshots' predictions on `x`.

        Raises:
            ValueError: If no snapshot has been saved yet.
        """
        if not self.models:
            raise ValueError("No model snapshots have been saved yet")
        predictions = [model.predict(x) for model in self.models]
        return np.mean(predictions, axis=0)
Combining Callbacks
# Each callback is built separately, then passed to fit() as one list.
best_model_checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)
lr_decay = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=3,
    verbose=1,
)
callbacks = [
    best_model_checkpoint,
    early_stopping,
    lr_decay,
    TensorBoard(log_dir='logs/fit'),
    CustomCallback(),
]

# Train with the combined callbacks.
model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=callbacks,
)
Callback Execution Order
The execution order of callbacks is as follows:
1. on_train_begin
2. on_epoch_begin
3. on_batch_begin
4. on_batch_end
5. on_epoch_end
6. on_train_end
Callback Best Practices
1. Choose Appropriate Monitoring Metrics
# Pick the monitored metric by task type; `classification` is a flag that
# must already be defined by the surrounding code.
early_stop = EarlyStopping(
    monitor='val_accuracy' if classification else 'val_loss',
    verbose=1,
    patience=5,
)
2. Save Best Model
# Overwrite 'best_model.h5' only when validation loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    'best_model.h5',
    mode='min',
    monitor='val_loss',
    save_best_only=True,
)
3. Use Learning Rate Scheduling
# Pair learning-rate reduction (patience 3) with early stopping (patience 10):
# the shorter patience lets the rate drop before training is stopped.
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', patience=3, factor=0.5),
    EarlyStopping(monitor='val_loss', patience=10),
]
4. Monitor Training Process
# TensorBoard logging to a fixed directory, with histograms and the graph.
tensorboard = TensorBoard(
    log_dir='logs/fit',
    write_graph=True,
    histogram_freq=1,
)
5. Avoid Excessive Logging
class EfficientCallback(Callback):
    """Print the loss on every fifth epoch instead of on every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Skip epochs that are not a multiple of 5 to keep output sparse.
        if epoch % 5 != 0:
            return
        print(f"Epoch {epoch}: Loss = {logs['loss']:.4f}")
6. Handle Exceptions
class RobustCallback(Callback):
    """Template for a callback whose own failures must not abort training."""

    def on_batch_end(self, batch, logs=None):
        try:
            pass  # Processing logic goes here.
        except Exception as exc:
            # Report the problem and carry on; training is not interrupted.
            print(f"Error in callback: {exc}")
Callback Application Scenarios
1. Long Training
# Long runs: checkpoint every epoch and keep a backup under 'backup'.
epoch_checkpoint = ModelCheckpoint('checkpoint.h5', save_freq='epoch')
callbacks = [
    epoch_checkpoint,
    BackupAndRestore(backup_dir='backup'),
]
2. Hyperparameter Tuning
# Hyperparameter search: stop poor trials early and lower the learning rate
# on plateaus.
callbacks = [
    EarlyStopping(patience=5, monitor='val_loss'),
    ReduceLROnPlateau(patience=2, factor=0.5, monitor='val_loss'),
]
3. Experiment Tracking
# Experiment tracking: TensorBoard logs plus a CSV file for the same run.
experiment_board = TensorBoard(log_dir='logs/experiment_1')
callbacks = [
    experiment_board,
    CSVLogger('experiment_1.csv'),
]
4. Production Deployment
# Production: keep only the best model and attach a monitoring callback.
# NOTE(review): CustomMonitoringCallback is not defined in the visible text;
# it is presumably a user-supplied Callback subclass -- confirm.
production_checkpoint = ModelCheckpoint('production_model.h5', save_best_only=True)
callbacks = [
    production_checkpoint,
    CustomMonitoringCallback(),
]
Summary
TensorFlow's callbacks provide powerful training control capabilities:
- Built-in callbacks: Provide common training control functions
- Custom callbacks: Implement specific training logic
- Flexible combination: Can combine multiple callbacks
- Execution order: Understand when callbacks execute
- Best practices: Use callbacks judiciously to improve training efficiency
Mastering callbacks will help you better control and monitor the model training process.