Skip to content

Training API

Complete API reference for training operations.

Trainer Class

class Trainer:
    """Trainer constructor signature (the implementation is not shown on this page)."""

    def __init__(self, agent, data_loader, epochs=100, learning_rate=0.001):
        """Configure a trainer.

        Args:
            agent: The agent to train (expected type not specified here).
            data_loader: Source of training data (expected type not specified here).
            epochs: Number of training epochs; defaults to 100.
            learning_rate: Learning rate; defaults to 0.001.
        """

Methods

train()

Start the training process and return the training history.

# Starts training; the structure of the returned history is not documented on this page.
history = trainer.train()

validate()

Run validation and return the resulting metrics.

# NOTE(review): validate() takes no data argument; presumably it uses data configured on the Trainer. Confirm.
metrics = trainer.validate()

evaluate(test_loader)

Evaluate on the test set supplied by `test_loader` and return the resulting metrics.

# test_loader supplies the held-out test data; the format of the returned metrics is not specified here.
metrics = trainer.evaluate(test_loader)

save_checkpoint(path)

Save a model checkpoint to `path`.

# Writes a checkpoint to the given path (here relative to the current working directory).
trainer.save_checkpoint("./checkpoint.pt")

load_checkpoint(path)

Load a checkpoint from `path`.

# Loads the checkpoint at the given path; presumably one written by save_checkpoint(). Confirm format compatibility.
trainer.load_checkpoint("./checkpoint.pt")

set_scheduler(scheduler)

Set the learning-rate scheduler.

# scheduler: a learning-rate scheduler object; the accepted scheduler types are not specified here.
trainer.set_scheduler(scheduler)

add_callback(callback)

Add a training callback.

# custom_callback: presumably an instance of a TrainingCallback subclass (see Training Callbacks). Confirm.
trainer.add_callback(custom_callback)

Training Callbacks

class TrainingCallback:
    """Base set of training hooks.

    Every hook here does nothing and returns None. Override only the hooks
    you need; when the trainer invokes each hook is not shown on this page.
    """

    def on_train_start(self):
        """Hook for the start of training; no-op by default."""

    def on_epoch_start(self, epoch):
        """Hook for the start of an epoch, given ``epoch``; no-op by default."""

    def on_batch_end(self, batch, loss):
        """Hook for the end of a batch, given ``batch`` and its ``loss``; no-op by default."""

    def on_epoch_end(self, epoch, metrics):
        """Hook for the end of an epoch, given ``epoch`` and ``metrics``; no-op by default."""

    def on_train_end(self, history):
        """Hook for the end of training, given ``history``; no-op by default."""

See Also