medkit.training.trainer#

Classes:

Trainer(component, config, train_data, eval_data)

A Trainer is a base training/evaluation loop for a TrainableComponent that uses PyTorch models to create medkit annotations.

class Trainer(component, config, train_data, eval_data, metrics_computer=None, lr_scheduler_builder=None, callback=None)[source]#

A Trainer is a base training/evaluation loop for a TrainableComponent that uses PyTorch models to create medkit annotations.

Parameters
  • component (TrainableComponent) – The component to train; it must implement the TrainableComponent protocol.

  • config (TrainerConfig) – A TrainerConfig with the training parameters; its output_dir parameter defines where checkpoints are saved.

  • train_data (Any) – The data to use for training. This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.

  • eval_data (Any) – The data to use for evaluation (not for testing). This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.

  • metrics_computer (Optional[MetricsComputer]) – Optional MetricsComputer object used to compute custom metrics during evaluation. By default, metrics are computed only during evaluation; set do_metrics_in_training in the config to also compute them during training.

  • lr_scheduler_builder (Optional[Callable[[Optimizer], Any]]) – Optional function that builds an lr_scheduler to adjust the learning rate after each epoch. It must take an Optimizer and return an lr_scheduler. If not provided, the learning rate does not change during training.

  • callback (Optional[TrainerCallback]) – Optional callback to customize training.
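The train_data and eval_data parameters accept any map-style corpus of medkit objects. A minimal sketch of such a dataset, where CorpusDataset is a hypothetical name and the dicts are placeholders standing in for real medkit annotations:

```python
# Minimal stand-in showing the expected shape of train_data / eval_data:
# a map-style dataset (duck-typed like torch.utils.data.Dataset) whose
# items are medkit objects. The dicts below are placeholders, not real
# medkit annotations.
class CorpusDataset:
    """Map-style dataset: only __len__ and __getitem__ are required."""

    def __init__(self, items):
        self._items = list(items)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


train_data = CorpusDataset([
    {"text": "patient presents with fever"},
    {"text": "no sign of infection"},
])
print(len(train_data))        # 2
print(train_data[0]["text"])  # patient presents with fever
```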

Methods:

evaluation_epoch(eval_dataloader)

Perform an epoch using the evaluation data.

get_dataloader(data, shuffle)

Return a DataLoader with the transformations defined in the component to train.

make_forward_pass(inputs, eval_mode)

Run the forward pass on the same device as the component.

save(epoch)

Save a checkpoint (trainer configuration, model weights, optimizer, and scheduler).

train()

Main training method.

training_epoch()

Perform an epoch using the training data.

update_learning_rate(eval_metrics)

Call the learning rate scheduler, if one was defined.

get_dataloader(data, shuffle)[source]#

Return a DataLoader with the transformations defined in the component to train.

Return type

DataLoader

training_epoch()[source]#

Perform an epoch using the training data.

When metrics are enabled during training (do_metrics_in_training is True in the config), the additional metrics are prepared per batch.

Return a dictionary with metrics.

Return type

Dict[str, float]

evaluation_epoch(eval_dataloader)[source]#

Perform an epoch using the evaluation data.

The additional metrics are prepared per batch. Return a dictionary with metrics.

Return type

Dict[str, float]

make_forward_pass(inputs, eval_mode)[source]#

Run the forward pass on the same device as the component.

Return type

Tuple[BatchData, Tensor]

update_learning_rate(eval_metrics)[source]#

Call the learning rate scheduler, if one was defined.
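The scheduler invoked here is produced by the lr_scheduler_builder passed to the constructor. An illustrative sketch of that contract, where DummyOptimizer and HalvingScheduler are stand-ins (with a real model you would typically return a torch.optim.lr_scheduler instance instead):

```python
# Illustrative sketch of the lr_scheduler_builder contract: a callable that
# receives the optimizer and returns a scheduler whose step() adjusts the
# learning rate. DummyOptimizer and HalvingScheduler are stand-ins, not
# part of medkit's API.
class DummyOptimizer:
    """Mimics the param_groups attribute of a torch Optimizer."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class HalvingScheduler:
    """Halves the learning rate each time step() is called."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        for group in self.optimizer.param_groups:
            group["lr"] *= 0.5


def lr_scheduler_builder(optimizer):
    # Takes an Optimizer, returns an lr_scheduler, as Trainer requires
    return HalvingScheduler(optimizer)


opt = DummyOptimizer(lr=0.1)
scheduler = lr_scheduler_builder(opt)
scheduler.step()
print(opt.param_groups[0]["lr"])  # 0.05
```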

train()[source]#

Main training method. Runs the training/evaluation loop.

Return a list with the metrics per epoch.

Return type

List[Dict]
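The returned list contains one metrics dictionary per epoch. A hypothetical way to post-process it, where the keys "epoch" and "eval_loss" are assumptions used only to illustrate the shape, not guaranteed metric names:

```python
# Hypothetical post-processing of the list returned by Trainer.train():
# one metrics dict per epoch. The keys used here are illustrative
# assumptions, not names guaranteed by medkit.
history = [
    {"epoch": 1, "eval_loss": 0.92},
    {"epoch": 2, "eval_loss": 0.61},
    {"epoch": 3, "eval_loss": 0.70},
]

# Pick the epoch with the lowest evaluation loss
best = min(history, key=lambda metrics: metrics["eval_loss"])
print(best["epoch"])  # 2
```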

save(epoch)[source]#

Save a checkpoint (trainer configuration, model weights, optimizer, and scheduler).

Parameters

epoch (int) – Epoch corresponding to the current training state (it will be included in the checkpoint name).

Return type

str

Returns

str – Path of the saved checkpoint.