Training API
API reference for training operations: the Trainer class and the TrainingCallback hooks.
Trainer Class
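class Trainer:
    def __init__(self, agent, data_loader, epochs=100, learning_rate=0.001): ...
Creates a trainer that trains agent on data from data_loader for the given number of epochs and learning rate (defaults: 100 epochs, learning rate 0.001).
trainer = Trainer(agent, data_loader, epochs=100, learning_rate=0.001)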
Methods
train()
Runs the training process and returns the training history.
history = trainer.train()
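To make the roles of epochs and learning_rate concrete, here is a minimal, self-contained sketch of the kind of loop a train() call runs. It is not the library's implementation: the one-parameter linear model, the name toy_train, the list-of-batches data loader, and the shape of the returned history (a dict with a per-epoch "loss" list) are all assumptions for illustration.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]  # one mini-batch of (x, y) pairs


def toy_train(data_loader: List[Batch], epochs: int = 100,
              learning_rate: float = 0.001) -> Dict[str, List[float]]:
    """Hypothetical stand-in for Trainer.train(): fit y = w * x by mini-batch gradient descent."""
    w = 0.0  # the single model parameter
    history: Dict[str, List[float]] = {"loss": []}
    for _ in range(epochs):  # one epoch = one full pass over data_loader
        epoch_loss = 0.0
        for batch in data_loader:
            n = len(batch)
            # Mean squared error on this batch and its gradient with respect to w.
            loss = sum((w * x - y) ** 2 for x, y in batch) / n
            grad = sum(2 * (w * x - y) * x for x, y in batch) / n
            w -= learning_rate * grad  # step size is controlled by learning_rate
            epoch_loss += loss
        history["loss"].append(epoch_loss / len(data_loader))  # mean batch loss per epoch
    return history


history = toy_train([[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]], epochs=50, learning_rate=0.01)
```

Here the data satisfy y = 2x, so w moves toward 2 and the recorded loss falls epoch by epoch; history["loss"] holds one entry per epoch.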
validate()
Runs validation and returns the resulting metrics.
metrics = trainer.validate()
evaluate(test_loader)
Evaluates the model on the test set supplied by test_loader and returns the resulting metrics.
metrics = trainer.evaluate(test_loader)
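validate() and evaluate(test_loader) both report metrics rather than train. A minimal sketch of that pattern follows, assuming a fixed one-parameter model y = w * x and a metrics dict with "loss" (mean squared error) and "mae" (mean absolute error) keys; the function name toy_evaluate and the metric names are hypothetical, not the library's.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]  # one batch of (x, y) pairs


def toy_evaluate(w: float, test_loader: List[Batch]) -> Dict[str, float]:
    """Score a fixed model y = w * x on every batch; the parameter w is never updated."""
    total_sq_error = 0.0
    total_abs_error = 0.0
    num_samples = 0
    for batch in test_loader:
        for x, y in batch:
            error = w * x - y
            total_sq_error += error ** 2
            total_abs_error += abs(error)
            num_samples += 1
    # Average over samples rather than batches so uneven batch sizes are weighted fairly.
    return {"loss": total_sq_error / num_samples, "mae": total_abs_error / num_samples}


metrics = toy_evaluate(2.0, [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 7.5)]])
```

Only the last sample is off (prediction 6.0 against 7.5), so the squared error 2.25 and absolute error 1.5 are each averaged over three samples.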
save_checkpoint(path)
Saves a model checkpoint to path.
trainer.save_checkpoint("./checkpoint.pt")
load_checkpoint(path)
Loads a checkpoint previously written by save_checkpoint from path.
trainer.load_checkpoint("./checkpoint.pt")
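A checkpoint generally bundles what is needed to resume training, such as model parameters, the current epoch, and the learning rate. The .pt extension suggests PyTorch's torch.save/torch.load, but that is not confirmed here, so this sketch uses the standard-library pickle module as a stand-in. The free functions save_checkpoint/load_checkpoint and the dict keys are assumptions for illustration only.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict


def save_checkpoint(state: Dict[str, Any], path: str) -> None:
    """Write the state dict to a temporary file, then rename it over `path`."""
    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    tmp.replace(target)  # a crash mid-write leaves any previous checkpoint intact


def load_checkpoint(path: str) -> Dict[str, Any]:
    """Read a state dict written by save_checkpoint. Unpickling can run code: load trusted files only."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip through a throwaway directory (the keys are hypothetical).
state = {"epoch": 12, "weights": [0.5, -1.25], "learning_rate": 0.001}
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = str(Path(tmp_dir) / "checkpoint.pt")
    save_checkpoint(state, ckpt_path)
    restored = load_checkpoint(ckpt_path)
```

Writing to a temporary file and renaming it means a reader never sees a half-written checkpoint.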
set_scheduler(scheduler)
Sets the learning-rate scheduler used during training.
trainer.set_scheduler(scheduler)
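The scheduler interface that Trainer expects is not documented here. As an assumption, the sketch below models a common step-decay policy (similar in spirit to PyTorch's StepLR): the learning rate is multiplied by gamma every step_size epochs, and a trainer would call step() once per epoch. The StepDecayScheduler class and its methods are hypothetical.

```python
class StepDecayScheduler:
    """Hypothetical step-decay schedule: lr = base_lr * gamma ** (epoch // step_size)."""

    def __init__(self, base_lr: float, step_size: int, gamma: float = 0.1) -> None:
        if step_size <= 0:
            raise ValueError("step_size must be positive")
        self.base_lr = base_lr
        self.step_size = step_size
        self.gamma = gamma
        self.epoch = 0  # number of completed step() calls

    def get_lr(self) -> float:
        """Learning rate for the current epoch."""
        return self.base_lr * self.gamma ** (self.epoch // self.step_size)

    def step(self) -> float:
        """Advance one epoch and return the learning rate for the new epoch."""
        self.epoch += 1
        return self.get_lr()


scheduler = StepDecayScheduler(base_lr=0.001, step_size=10, gamma=0.5)
# lrs[k] is the rate for epoch k: 0.001 for epochs 0-9, 0.0005 for 10-19, 0.00025 at 20.
lrs = [scheduler.get_lr()] + [scheduler.step() for _ in range(20)]
```

If Trainer accepts an object with this shape, it would be registered with trainer.set_scheduler(scheduler).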
add_callback(callback)
Registers a callback whose hooks are invoked during training (see Training Callbacks).
trainer.add_callback(custom_callback)
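Callbacks let you hook into training without modifying the loop itself. One plausible way a trainer fans events out to registered callbacks is sketched below. MiniTrainer and EventRecorder are hypothetical stand-ins; the sketch assumes callbacks run in registration order and that on_epoch_end receives a metrics dict, with hook names taken from the TrainingCallback interface in this reference.

```python
from typing import Any, Dict, List


class MiniTrainer:
    """Hypothetical stand-in showing how add_callback and event dispatch could fit together."""

    def __init__(self, batches: List[float], epochs: int = 2) -> None:
        self.batches = batches  # each item stands in for one batch's loss
        self.epochs = epochs
        self.callbacks: List[Any] = []

    def add_callback(self, callback: Any) -> None:
        self.callbacks.append(callback)

    def _emit(self, hook: str, *args: Any) -> None:
        # Call the named hook on every registered callback, in registration order.
        for cb in self.callbacks:
            getattr(cb, hook)(*args)

    def train(self) -> Dict[str, List[float]]:
        history: Dict[str, List[float]] = {"loss": []}
        self._emit("on_train_start")
        for epoch in range(self.epochs):
            self._emit("on_epoch_start", epoch)
            for i, loss in enumerate(self.batches):
                self._emit("on_batch_end", i, loss)
            metrics = {"loss": sum(self.batches) / len(self.batches)}
            history["loss"].append(metrics["loss"])
            self._emit("on_epoch_end", epoch, metrics)
        self._emit("on_train_end", history)
        return history


class EventRecorder:
    """Records every hook call as a short string."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_start(self): self.events.append("train_start")
    def on_epoch_start(self, epoch): self.events.append(f"epoch_start:{epoch}")
    def on_batch_end(self, batch, loss): self.events.append(f"batch_end:{batch}")
    def on_epoch_end(self, epoch, metrics): self.events.append(f"epoch_end:{epoch}")
    def on_train_end(self, history): self.events.append("train_end")


trainer = MiniTrainer(batches=[0.5, 0.25], epochs=2)
recorder = EventRecorder()
trainer.add_callback(recorder)
history = trainer.train()
```

Dispatching by hook name keeps the loop unaware of what each callback does; any number of callbacks can observe the same run.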
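Training Callbacks
Subclass TrainingCallback and override only the hooks you need; every hook below is a no-op by default.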
class TrainingCallback:
    def on_train_start(self): pass
    def on_epoch_start(self, epoch): pass
    def on_batch_end(self, batch, loss): pass
    def on_epoch_end(self, epoch, metrics): pass
    def on_train_end(self, history): pass
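A custom callback subclasses TrainingCallback and overrides only the hooks it needs. The sketch below (the class name EpochLossAverager is hypothetical) averages the per-batch losses passed to on_batch_end and records one mean per epoch. It repeats the base class so it runs on its own, assumes loss arrives as a plain number, and drives the hooks by hand in the order a trainer would presumably call them.

```python
from typing import List


class TrainingCallback:
    """Base interface as listed in this reference; every hook is a no-op."""
    def on_train_start(self): pass
    def on_epoch_start(self, epoch): pass
    def on_batch_end(self, batch, loss): pass
    def on_epoch_end(self, epoch, metrics): pass
    def on_train_end(self, history): pass


class EpochLossAverager(TrainingCallback):
    """Hypothetical callback: mean of the batch losses seen in each epoch."""

    def __init__(self) -> None:
        self.epoch_means: List[float] = []
        self._losses: List[float] = []

    def on_epoch_start(self, epoch):
        self._losses = []  # reset the accumulator for the new epoch

    def on_batch_end(self, batch, loss):
        self._losses.append(float(loss))

    def on_epoch_end(self, epoch, metrics):
        if self._losses:
            self.epoch_means.append(sum(self._losses) / len(self._losses))


# Simulate two epochs of two batches each.
cb = EpochLossAverager()
cb.on_train_start()
for epoch, losses in enumerate([[1.0, 0.5], [0.25, 0.25]]):
    cb.on_epoch_start(epoch)
    for i, loss in enumerate(losses):
        cb.on_batch_end(i, loss)
    cb.on_epoch_end(epoch, {})
cb.on_train_end({})
print(cb.epoch_means)  # [0.75, 0.25]
```

With a real trainer, the simulation loop is replaced by trainer.add_callback(EpochLossAverager()).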