medkit.training.trainer
Classes:
Trainer – A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations
- class Trainer(component, config, train_data, eval_data, metrics_computer=None, lr_scheduler_builder=None, callback=None)
A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations
- Parameters
component (TrainableComponent) – The component to train. It must implement the TrainableComponent protocol.
config (TrainerConfig) – A TrainerConfig with the parameters for training. Its output_dir parameter defines the path of the checkpoints.
train_data (Any) – The data to use for training. This should be a corpus of medkit objects. The data could be, for instance, a torch.utils.data.Dataset that returns medkit objects for training.
eval_data (Any) – The data to use for evaluation (not for testing). This should be a corpus of medkit objects. The data can be a torch.utils.data.Dataset that returns medkit objects for evaluation.
metrics_computer (Optional[MetricsComputer]) – Optional MetricsComputer object used to compute custom metrics during evaluation. By default, these metrics are computed only during evaluation; setting do_metrics_in_training in the config also computes them during training.
lr_scheduler_builder (Optional[Callable[[Optimizer], Any]]) – Optional function that builds an lr_scheduler to adjust the learning rate after an epoch. It must take an Optimizer and return an lr_scheduler. If not provided, the learning rate does not change during training.
callback (Optional[TrainerCallback]) – Optional callback to customize training.
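The lr_scheduler_builder contract described above (a callable that takes an optimizer and returns a scheduler) can be sketched in plain Python. This is only an illustration of the callable's shape: ToyOptimizer, HalvingScheduler, build_scheduler and run_epochs are hypothetical stand-ins, not medkit or PyTorch classes, and the loop is not the Trainer's actual implementation. In real use the builder would typically return a scheduler from torch.optim.lr_scheduler.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ToyOptimizer:
    """Hypothetical stand-in for an Optimizer; it only stores a learning rate."""

    lr: float


class HalvingScheduler:
    """Hypothetical scheduler that halves the learning rate on every step."""

    def __init__(self, optimizer: ToyOptimizer) -> None:
        self.optimizer = optimizer

    def step(self) -> None:
        self.optimizer.lr *= 0.5


def build_scheduler(optimizer: ToyOptimizer) -> Any:
    """Builder with the documented shape: takes an optimizer, returns a scheduler."""
    return HalvingScheduler(optimizer)


def run_epochs(
    builder: Callable[[ToyOptimizer], Any], nb_epochs: int, lr: float
) -> List[float]:
    """Assumed usage pattern: build the scheduler once, step it after each epoch."""
    optimizer = ToyOptimizer(lr=lr)
    scheduler = builder(optimizer)
    history = []
    for _ in range(nb_epochs):
        history.append(optimizer.lr)  # learning rate in effect during this epoch
        scheduler.step()  # adjust the learning rate after the epoch
    return history


lr_history = run_epochs(build_scheduler, nb_epochs=3, lr=1.0)
```

When no builder is passed, the documented behaviour is that the learning rate stays constant for the whole training run.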
Methods:
evaluation_epoch(eval_dataloader) – Perform an epoch using the evaluation data.
get_dataloader(data, shuffle) – Return a DataLoader with transformations defined in the component to train.
make_forward_pass(inputs, eval_mode) – Run forward safely, same device as the component.
save(epoch) – Save a checkpoint (trainer configuration, model weights, optimizer and scheduler).
train() – Main training method.
training_epoch() – Perform an epoch using the training data.
update_learning_rate(eval_metrics) – Call the learning rate scheduler if defined.
- get_dataloader(data, shuffle)
Return a DataLoader with transformations defined in the component to train
- Return type
DataLoader
- training_epoch()
Perform an epoch using the training data.
When the config enables metrics in training (do_metrics_in_training is True), the additional metrics are prepared per batch.
Return a dictionary with metrics.
- Return type
Dict[str, float]
- evaluation_epoch(eval_dataloader)
Perform an epoch using the evaluation data.
The additional metrics are prepared per batch. Return a dictionary with metrics.
- Return type
Dict[str, float]
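Both epoch methods return a Dict[str, float] built from values gathered per batch. A minimal sketch of that pattern follows. It assumes the per-batch value is a loss and that the reduction is a mean; this page states neither the actual keys nor the reduction, and run_epoch is a hypothetical helper rather than a medkit function.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def run_epoch(batch_losses: Iterable[float]) -> Dict[str, float]:
    """Collect one value per batch, then reduce each series to a single float."""
    per_batch: Dict[str, List[float]] = defaultdict(list)
    for loss in batch_losses:
        # Assumption: each batch contributes one loss value under the key "loss".
        per_batch["loss"].append(loss)
    # Assumption: the epoch-level metric is the mean over batches.
    return {name: sum(values) / len(values) for name, values in per_batch.items()}


epoch_metrics = run_epoch([0.5, 0.25, 0.75])
```

Custom metrics from a MetricsComputer would presumably add further keys to the same dictionary, which is why the return type maps metric names to floats.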
- make_forward_pass(inputs, eval_mode)
Run a forward pass safely, on the same device as the component
- Return type
Tuple[BatchData, Tensor]