medkit.training
APIs#
To use these APIs, import them as follows:
from medkit.training import <api_to_import>
This package requires extra dependencies that are not installed with the core dependencies of medkit. To install them, use pip install medkit-lib[training].
Classes:
BatchData – A BatchData packs data, allowing both column and row access
DefaultPrinterCallback – Default implementation of TrainerCallback
MetricsComputer – A MetricsComputer is the base protocol to compute metrics in training
TrainableComponent – TrainableComponent is the base protocol to be trainable in medkit
Trainer – A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations
TrainerCallback – A TrainerCallback is the base class for trainer callbacks
TrainerConfig – Trainer configuration
- class TrainerCallback[source]#
A TrainerCallback is the base class for trainer callbacks
Methods:
on_epoch_begin
(epoch)Event called at the beginning of an epoch
on_epoch_end
(metrics, epoch, epoch_time)Event called at the end of an epoch
on_save
(checkpoint_dir)Event called on saving a checkpoint
on_step_begin
(step_idx, nb_batches, phase)Event called at the beginning of a step in training
on_step_end
(step_idx, nb_batches, phase)Event called at the end of a step in training
on_train_begin
(config)Event called at the beginning of training
on_train_end
()Event called at the end of training
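The events listed above can be overridden selectively in a subclass. The following is a minimal sketch, not part of medkit: MetricsHistoryCallback and its attributes are hypothetical names, the method parameters follow the base class signatures listed above, and the structure of metrics is deliberately not assumed. A fallback base class is defined so the sketch also runs when the training extras are not installed.

```python
try:
    from medkit.training import TrainerCallback
except ImportError:
    # Stand-in so the sketch still runs without medkit-lib[training] installed.
    class TrainerCallback:
        pass


class MetricsHistoryCallback(TrainerCallback):
    """Hypothetical callback that records what the trainer reports."""

    def __init__(self):
        super().__init__()
        self.history = []
        self.saved_checkpoints = []

    def on_epoch_end(self, metrics, epoch, epoch_time):
        # `metrics` is stored as received; its exact layout is not assumed here.
        self.history.append({"epoch": epoch, "metrics": metrics, "seconds": epoch_time})

    def on_save(self, checkpoint_dir):
        # Keep track of every checkpoint directory that was written.
        self.saved_checkpoints.append(checkpoint_dir)


# Simulate the events by hand to show what gets recorded.
callback = MetricsHistoryCallback()
callback.on_epoch_end({"loss": 0.5}, 1, 2.0)
callback.on_save("checkpoint_dir_example")
print(callback.history)
```

An instance of such a class would be passed to the Trainer through its callback parameter.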
- class DefaultPrinterCallback[source]#
Default implementation of TrainerCallback
Methods:
on_epoch_begin
(epoch)Event called at the beginning of an epoch
on_epoch_end
(metrics, epoch, epoch_duration)Event called at the end of an epoch
on_save
(checkpoint_dir)Event called on saving a checkpoint
on_step_begin
(step_idx, nb_batches, phase)Event called at the beginning of a step in training
on_step_end
(step_idx, nb_batches, phase)Event called at the end of a step in training
on_train_begin
(config)Event called at the beginning of training
on_train_end
()Event called at the end of training
- on_step_begin(step_idx, nb_batches, phase)[source]#
Event called at the beginning of a step in training
- on_epoch_begin(epoch)#
Event called at the beginning of an epoch
- class Trainer(component, config, train_data, eval_data, metrics_computer=None, lr_scheduler_builder=None, callback=None)[source]#
A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations
- Parameters
component (TrainableComponent) – The component to train. It must implement the TrainableComponent protocol.
config (TrainerConfig) – A TrainerConfig with the parameters for training. Its output_dir parameter defines the path where checkpoints are saved.
train_data (Any) – The data to use for training. This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects for training.
eval_data (Any) – The data to use for evaluation during training (not for testing). This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects for evaluation.
metrics_computer (Optional[MetricsComputer]) – Optional MetricsComputer object used to compute custom metrics during evaluation. By default, custom metrics are computed only on the evaluation data; setting do_metrics_in_training in the config also computes them during training.
lr_scheduler_builder (Optional[Callable[[Optimizer], Any]]) – Optional function that builds a lr_scheduler to adjust the learning rate after an epoch. It must take an Optimizer and return a lr_scheduler. If not provided, the learning rate does not change during training.
callback (Optional[TrainerCallback]) – Optional callback to customize training.
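As an illustration of how these parameters fit together, here is a hedged sketch. The helper names build_lr_scheduler and run_training are hypothetical, the choice of ExponentialLR and its gamma value is only an example, and the component and datasets are expected to be supplied by the caller. Imports are done inside the functions so the definitions load even when the training extras are missing.

```python
def build_lr_scheduler(optimizer):
    """Example lr_scheduler_builder: takes an Optimizer, returns a scheduler."""
    # Imported lazily; requires PyTorch to be installed when actually called.
    from torch.optim.lr_scheduler import ExponentialLR

    # gamma=0.9 is an arbitrary illustrative decay factor.
    return ExponentialLR(optimizer, gamma=0.9)


def run_training(component, train_data, eval_data, output_dir):
    """Wire a TrainableComponent and its data into a Trainer, then train."""
    # Imported lazily; requires medkit-lib[training] when actually called.
    from medkit.training import Trainer, TrainerConfig

    config = TrainerConfig(output_dir=output_dir, nb_training_epochs=3, batch_size=4)
    trainer = Trainer(
        component=component,
        config=config,
        train_data=train_data,
        eval_data=eval_data,
        lr_scheduler_builder=build_lr_scheduler,
    )
    trainer.train()
    return trainer
```

Leaving out lr_scheduler_builder keeps the learning rate constant, as stated in the parameter description.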
Methods:
evaluation_epoch
(eval_dataloader)Perform an epoch using the evaluation data.
get_dataloader
(data, shuffle)Return a DataLoader with transformations defined in the component to train
make_forward_pass
(inputs, eval_mode)Run forward safely, same device as the component
save
(epoch)Save a checkpoint (trainer configuration, model weights, optimizer and scheduler)
train
()Main training method.
training_epoch
()Perform an epoch using the training data.
update_learning_rate
(eval_metrics)Call the learning rate scheduler if defined
- get_dataloader(data, shuffle)[source]#
Return a DataLoader with transformations defined in the component to train
- Return type
DataLoader
- training_epoch()[source]#
Perform an epoch using the training data.
When the config enables metrics in training (‘do_metrics_in_training’ is True), the additional metrics are prepared per batch.
Return a dictionary with metrics.
- Return type
Dict[str, float]
- evaluation_epoch(eval_dataloader)[source]#
Perform an epoch using the evaluation data.
The additional metrics are prepared per batch. Return a dictionary with metrics.
- Return type
Dict[str, float]
- make_forward_pass(inputs, eval_mode)[source]#
Run forward safely, same device as the component
- Return type
Tuple[BatchData, Tensor]
- class TrainerConfig(output_dir, learning_rate=1e-05, nb_training_epochs=3, dataloader_nb_workers=0, batch_size=1, seed=None, gradient_accumulation_steps=1, do_metrics_in_training=False, metric_to_track_lr='loss', checkpoint_period=1, checkpoint_metric='loss', minimize_checkpoint_metric=True)[source]#
Trainer configuration
- Parameters
output_dir (str) – The output directory where the checkpoint will be saved.
learning_rate (float) – The initial learning rate.
nb_training_epochs (int) – Total number of training/evaluation epochs to do.
dataloader_nb_workers (int) – Number of subprocesses used for data loading. The default value is 0, meaning that the data will be loaded in the main process. If this config is for a HuggingFace model, do not change this value.
batch_size (int) – Number of samples per batch to load.
seed (Optional[int]) – Random seed to use with PyTorch and numpy. It should be set to ensure reproducibility between experiments.
gradient_accumulation_steps (int) – Number of steps to accumulate gradient before performing an optimization step.
do_metrics_in_training (bool) – By default, the custom metrics are computed using eval_data only. If set to True, the custom metrics are also computed using the training data.
metric_to_track_lr (str) – Name of the eval metric to be tracked for updating the learning rate. By default, eval loss is tracked.
checkpoint_period (int) – How often, in number of epochs, a checkpoint is saved. Use 0 to save only the last checkpoint.
checkpoint_metric (str) – Name of the eval metric to be tracked for selecting the best checkpoint. By default, eval loss is tracked.
minimize_checkpoint_metric (bool) – If True, the checkpoint with the lowest metric value will be selected as best, otherwise the checkpoint with the highest metric value.
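The sketch below illustrates two of the parameters above under stated assumptions. The values and the output path are arbitrary examples, make_config and pick_best_epoch are hypothetical helpers, and pick_best_epoch only mirrors the selection rule described for minimize_checkpoint_metric; it is not medkit's implementation.

```python
# Keyword arguments mirroring the TrainerConfig parameters listed above.
config_kwargs = dict(
    output_dir="checkpoints",  # hypothetical path
    learning_rate=1e-5,
    nb_training_epochs=3,
    batch_size=4,
    gradient_accumulation_steps=8,
    seed=0,
    checkpoint_period=1,
    checkpoint_metric="loss",
    minimize_checkpoint_metric=True,
)

# Number of samples contributing to each optimization step.
effective_batch_size = (
    config_kwargs["batch_size"] * config_kwargs["gradient_accumulation_steps"]
)


def pick_best_epoch(metric_by_epoch, minimize):
    """Illustrative rule: lowest metric value if minimize, else highest."""
    chooser = min if minimize else max
    return chooser(metric_by_epoch, key=metric_by_epoch.get)


def make_config():
    """Build the real config; requires medkit-lib[training] when called."""
    from medkit.training import TrainerConfig

    return TrainerConfig(**config_kwargs)


print(effective_batch_size)
print(pick_best_epoch({1: 0.9, 2: 0.4, 3: 0.6}, minimize=True))
```

With a metric such as accuracy, where higher is better, minimize_checkpoint_metric would be set to False.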
- class BatchData[source]#
A BatchData packs data, allowing both column and row access
Methods:
to_device
(device)Ensure that Tensors in the BatchData object are on the specified device
- to_device(device)[source]#
Ensure that Tensors in the BatchData object are on the specified device
- Parameters
device (device) – A torch.device object representing the device on which tensors will be allocated.
- Return type
BatchData
- Returns
BatchData – A new object with the tensors on the proper device.
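To make the idea of "column and row access" concrete, here is a small stand-in written with the standard library only. MiniBatch is a hypothetical class used for illustration; it is not medkit's BatchData and makes no claim about how BatchData is implemented. It assumes that a string key selects a column and an integer index selects a row.

```python
class MiniBatch(dict):
    """Hypothetical stand-in: a dict whose values are equal-length columns."""

    def __getitem__(self, index):
        if isinstance(index, int):
            # Row access: gather the index-th element of every column.
            return {key: values[index] for key, values in self.items()}
        # Column access: regular dictionary lookup by key.
        return super().__getitem__(index)


batch = MiniBatch(input_ids=[[1, 2], [3, 4]], labels=[0, 1])
column = batch["labels"]  # all labels in the batch
row = batch[1]            # everything belonging to the second sample
print(column, row)
```

Because the stand-in is still a dict, the inherited dictionary methods remain available on it.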
- clear() → None. Remove all items from D.#
- copy() → a shallow copy of D#
- fromkeys(iterable, value=None, /)#
Create a new dictionary with keys from iterable and values set to value.
- get(key, default=None, /)#
Return the value for key if key is in the dictionary, else default.
- items() → a set-like object providing a view on D's items#
- keys() → a set-like object providing a view on D's keys#
- pop(k[, d]) → v, remove specified key and return the corresponding value.#
If key is not found, d is returned if given, otherwise KeyError is raised
- popitem()#
Remove and return a (key, value) pair as a 2-tuple.
Pairs are returned in LIFO (last-in, first-out) order. Raises KeyError if the dict is empty.
- setdefault(key, default=None, /)#
Insert key with a value of default if key is not in the dictionary.
Return the value for key if key is in the dictionary, else default.
- update([E, ]**F) → None. Update D from dict/iterable E and F.#
If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
- values() → an object providing a view on D's values#
- class MetricsComputer(*args, **kwargs)[source]#
A MetricsComputer is the base protocol to compute metrics in training
Methods:
compute
(all_data)Compute metrics using 'all_data'
prepare_batch
(model_output, input_batch)Prepare a batch of data to compute the metrics
- compute(all_data)[source]#
Compute metrics using ‘all_data’
- Parameters
all_data (Dict[str, List[Any]]) – A dictionary with the data needed to compute the metrics, e.g. a dictionary with a list of ‘references’ and a list of ‘predictions’.
- Return type
Dict[str, float]
- Returns
Dict[str, float] – A dictionary with the results
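A possible shape for a class following this protocol is sketched below with plain Python containers. AccuracyComputer, the ‘logits’ and ‘labels’ keys, and the way batches are accumulated are all assumptions made for the example. In particular, it is assumed that prepare_batch returns a dictionary of lists and that the per-batch lists end up concatenated into all_data.

```python
class AccuracyComputer:
    """Hypothetical metrics computer reporting classification accuracy."""

    def prepare_batch(self, model_output, input_batch):
        # Predicted class = index of the highest score for each sample.
        predictions = [
            max(range(len(scores)), key=scores.__getitem__)
            for scores in model_output["logits"]
        ]
        return {"predictions": predictions, "references": list(input_batch["labels"])}

    def compute(self, all_data):
        predictions = all_data["predictions"]
        references = all_data["references"]
        correct = sum(p == r for p, r in zip(predictions, references))
        return {"accuracy": correct / len(references)}


# Accumulate two batches by hand, then compute the metric once.
computer = AccuracyComputer()
all_data = {"predictions": [], "references": []}
batches = [
    ({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 0]}),
    ({"logits": [[0.3, 0.7]]}, {"labels": [0]}),
]
for model_output, input_batch in batches:
    prepared = computer.prepare_batch(model_output, input_batch)
    for key, values in prepared.items():
        all_data[key].extend(values)
metrics = computer.compute(all_data)
print(metrics)
```

Splitting the work this way keeps the per-batch step cheap and defers the actual metric computation to the end of the epoch.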
- class TrainableComponent(*args, **kwargs)[source]#
TrainableComponent is the base protocol to be trainable in medkit
Methods:
collate
(batch)Collate a list of data processed by preprocess to form a batch
create_optimizer
(lr)Create optimizer using the learning rate
forward
(input_batch, return_loss, eval_mode)Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.
load
(path)Load weights from disk
preprocess
(data_item)Preprocess the input data item and return a dictionary with everything needed for the forward pass.
save
(path)Save model to disk
- preprocess(data_item)[source]#
Preprocess the input data item and return a dictionary with everything needed for the forward pass.
This method is intended to preprocess a single input; self.collate must then be used to generate batches so that self.forward can run properly. The preprocessed output should include the labels needed to compute a loss.
- Return type
Dict[str, Any]
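The relationship between preprocess and collate can be sketched with plain functions. Everything here is illustrative: the (text, label) input pair, the whitespace tokenization and the dictionary keys are assumptions, and a plain dict stands in for the batch object that a real component would return from collate.

```python
def preprocess(data_item):
    """Turn one hypothetical (text, label) pair into a dict for the forward pass."""
    text, label = data_item
    tokens = text.lower().split()
    # The label is kept so that a loss can be computed later.
    return {"tokens": tokens, "length": len(tokens), "label": label}


def collate(batch):
    """Group a list of preprocessed items into one dict of columns."""
    return {key: [item[key] for item in batch] for key in batch[0]}


items = [("Chest pain", 1), ("No fever reported", 0)]
collated = collate([preprocess(item) for item in items])
print(collated)
```

In a real component the collate step would typically also pad sequences and convert the columns to tensors.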
- forward(input_batch, return_loss, eval_mode)[source]#
Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.
Before running the model, this method must set it to training or evaluation mode depending on eval_mode. PyTorch models provide two methods to set the mode: model.train() and model.eval().
- Return type
Tuple[BatchData, Optional[Tensor]]
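The mode switch and the optional loss described above can be sketched without PyTorch. ToyModel and ToyComponent are hypothetical stand-ins: ToyModel only imitates the train()/eval() switches of a PyTorch module, and plain dicts and floats replace BatchData and Tensor.

```python
class ToyModel:
    """Stand-in exposing train()/eval() switches like a PyTorch module."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, values):
        # Trivial "model": double every input value.
        return [2.0 * value for value in values]


class ToyComponent:
    def __init__(self):
        self.model = ToyModel()

    def forward(self, input_batch, return_loss, eval_mode):
        # Set the mode first, as required before running the model.
        if eval_mode:
            self.model.eval()
        else:
            self.model.train()
        outputs = self.model(input_batch["values"])
        loss = None
        if return_loss:
            # Mean squared error against the targets in the batch.
            targets = input_batch["targets"]
            loss = sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)
        return {"outputs": outputs}, loss


component = ToyComponent()
example_batch = {"values": [1.0, 2.0], "targets": [2.0, 5.0]}
train_output, train_loss = component.forward(example_batch, return_loss=True, eval_mode=False)
eval_output, eval_loss = component.forward(example_batch, return_loss=False, eval_mode=True)
print(train_loss, eval_loss)
```

The second element of the returned tuple is None whenever return_loss is False, which matches the Optional in the return type.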