medkit.training

APIs

To access these APIs, import them as follows:

from medkit.training import <api_to_import>

This package requires extra dependencies that are not installed with the core medkit dependencies. To install them, run pip install medkit-lib[training].

Classes:

BatchData

A BatchData packs data, allowing both column and row access

DefaultPrinterCallback()

Default implementation of TrainerCallback

MetricsComputer(*args, **kwargs)

A MetricsComputer is the base protocol to compute metrics in training

TrainableComponent(*args, **kwargs)

TrainableComponent is the base protocol a component must implement to be trainable in medkit

Trainer(component, config, train_data, eval_data)

A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations

TrainerCallback()

A TrainerCallback is the base class for trainer callbacks

TrainerConfig(output_dir[, learning_rate, ...])

Trainer configuration

class TrainerCallback

A TrainerCallback is the base class for trainer callbacks

Methods:

on_epoch_begin(epoch)

Event called at the beginning of an epoch

on_epoch_end(metrics, epoch, epoch_time)

Event called at the end of an epoch

on_save(checkpoint_dir)

Event called on saving a checkpoint

on_step_begin(step_idx, nb_batches, phase)

Event called at the beginning of a step in training

on_step_end(step_idx, nb_batches, phase)

Event called at the end of a step in training

on_train_begin(config)

Event called at the beginning of training

on_train_end()

Event called at the end of training

on_train_begin(config)

Event called at the beginning of training

on_train_end()

Event called at the end of training

on_epoch_begin(epoch)

Event called at the beginning of an epoch

on_epoch_end(metrics, epoch, epoch_time)

Event called at the end of an epoch

on_step_begin(step_idx, nb_batches, phase)

Event called at the beginning of a step in training

on_step_end(step_idx, nb_batches, phase)

Event called at the end of a step in training

on_save(checkpoint_dir)

Event called on saving a checkpoint
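As a minimal sketch, a custom callback can subclass TrainerCallback and override the hooks it needs. The class name EventLoggerCallback and the event strings are hypothetical. The import falls back to object so the sketch also runs without medkit-lib[training], and phase is only formatted into a string, so no assumption is made about its type.

```python
try:
    # Real base class, available when medkit-lib[training] is installed
    from medkit.training import TrainerCallback
except ImportError:
    # Fallback so the sketch runs without medkit; the hooks are duck-typed
    TrainerCallback = object


class EventLoggerCallback(TrainerCallback):
    """Hypothetical callback that records training events in a list."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, config):
        self.events.append("train_begin")

    def on_train_end(self):
        self.events.append("train_end")

    def on_epoch_begin(self, epoch):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, metrics, epoch, epoch_time):
        self.events.append(f"epoch_end:{epoch}")

    def on_step_begin(self, step_idx, nb_batches, phase):
        pass  # called before each batch; nothing is recorded here

    def on_step_end(self, step_idx, nb_batches, phase):
        # Record progress through the current phase (step_idx is 0-based)
        self.events.append(f"{phase}:{step_idx + 1}/{nb_batches}")

    def on_save(self, checkpoint_dir):
        self.events.append(f"save:{checkpoint_dir}")
```

An instance would then be passed to the trainer with Trainer(..., callback=EventLoggerCallback()).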

class DefaultPrinterCallback

Default implementation of TrainerCallback

Methods:

on_epoch_begin(epoch)

Event called at the beginning of an epoch

on_epoch_end(metrics, epoch, epoch_duration)

Event called at the end of an epoch

on_save(checkpoint_dir)

Event called on saving a checkpoint

on_step_begin(step_idx, nb_batches, phase)

Event called at the beginning of a step in training

on_step_end(step_idx, nb_batches, phase)

Event called at the end of a step in training

on_train_begin(config)

Event called at the beginning of training

on_train_end()

Event called at the end of training

on_train_begin(config)

Event called at the beginning of training

on_epoch_end(metrics, epoch, epoch_duration)

Event called at the end of an epoch

on_train_end()

Event called at the end of training

on_save(checkpoint_dir)

Event called on saving a checkpoint

on_step_begin(step_idx, nb_batches, phase)

Event called at the beginning of a step in training

on_step_end(step_idx, nb_batches, phase)

Event called at the end of a step in training

on_epoch_begin(epoch)

Event called at the beginning of an epoch

class Trainer(component, config, train_data, eval_data, metrics_computer=None, lr_scheduler_builder=None, callback=None)

A trainer is a base training/eval loop for a TrainableComponent that uses PyTorch models to create medkit annotations

Parameters
  • component (TrainableComponent) – The component to train; it must implement the TrainableComponent protocol.

  • config (TrainerConfig) – A TrainerConfig with the training parameters; its output_dir parameter defines where checkpoints are saved.

  • train_data (Any) – The data to use for training. This should be a corpus of medkit objects. The data could be, for instance, a torch.utils.data.Dataset that returns medkit objects for training.

  • eval_data (Any) – The data to use for evaluation (not for testing). This should be a corpus of medkit objects. The data can be a torch.utils.data.Dataset that returns medkit objects for evaluation.

  • metrics_computer (Optional[MetricsComputer]) – Optional MetricsComputer object used to compute custom metrics during evaluation. By default, custom metrics are only computed on the evaluation data; set do_metrics_in_training in the config to also compute them during training.

  • lr_scheduler_builder (Optional[Callable[[Optimizer], Any]]) – Optional function that builds a learning rate scheduler to adjust the learning rate after each epoch. It must take an Optimizer and return a learning rate scheduler. If not provided, the learning rate does not change during training.

  • callback (Optional[TrainerCallback]) – Optional callback to customize training.

Methods:

evaluation_epoch(eval_dataloader)

Perform an epoch using the evaluation data.

get_dataloader(data, shuffle)

Return a DataLoader with transformations defined in the component to train

make_forward_pass(inputs, eval_mode)

Run the forward pass safely, on the same device as the component

save(epoch)

Save a checkpoint (trainer configuration, model weights, optimizer and scheduler)

train()

Main training method.

training_epoch()

Perform an epoch using the training data.

update_learning_rate(eval_metrics)

Call the learning rate scheduler if defined

get_dataloader(data, shuffle)

Return a DataLoader with transformations defined in the component to train

Return type

DataLoader

training_epoch()

Perform an epoch using the training data.

When metrics in training are enabled in the config (do_metrics_in_training is True), the additional metrics are prepared per batch.

Return a dictionary with metrics.

Return type

Dict[str, float]

evaluation_epoch(eval_dataloader)

Perform an epoch using the evaluation data.

The additional metrics are prepared per batch. Return a dictionary with metrics.

Return type

Dict[str, float]

make_forward_pass(inputs, eval_mode)

Run the forward pass safely, on the same device as the component

Return type

Tuple[BatchData, Tensor]

update_learning_rate(eval_metrics)

Call the learning rate scheduler if defined

train()

Main training method. Call the training / eval loop.

Return a list with the metrics per epoch.

Return type

List[Dict]

save(epoch)

Save a checkpoint (trainer configuration, model weights, optimizer and scheduler)

Parameters

epoch (int) – Epoch corresponding to the current training state (will be included in the checkpoint name)

Return type

str

Returns

str – Path of the saved checkpoint

class TrainerConfig(output_dir, learning_rate=1e-05, nb_training_epochs=3, dataloader_nb_workers=0, batch_size=1, seed=None, gradient_accumulation_steps=1, do_metrics_in_training=False, metric_to_track_lr='loss', checkpoint_period=1, checkpoint_metric='loss', minimize_checkpoint_metric=True)

Trainer configuration

Parameters
  • output_dir (str) – The output directory where checkpoints will be saved.

  • learning_rate (float) – The initial learning rate.

  • nb_training_epochs (int) – Total number of training/evaluation epochs to run.

  • dataloader_nb_workers (int) – Number of subprocesses to use for data loading. The default value, 0, means that the data will be loaded in the main process. If this config is for a HuggingFace model, do not change this value.

  • batch_size (int) – Number of samples per batch to load.

  • seed (Optional[int]) – Random seed to use with PyTorch and numpy. It should be set to ensure reproducibility between experiments.

  • gradient_accumulation_steps (int) – Number of steps to accumulate gradients over before performing an optimization step.

  • do_metrics_in_training (bool) – By default, custom metrics are only computed using eval_data. If set to True, they are also computed using train_data.

  • metric_to_track_lr (str) – Name of the eval metric to be tracked for updating the learning rate. By default, eval loss is tracked.

  • checkpoint_period (int) – How often, in number of epochs, a checkpoint should be saved. Use 0 to only save the last checkpoint.

  • checkpoint_metric (str) – Name of the eval metric to be tracked for selecting the best checkpoint. By default, eval loss is tracked.

  • minimize_checkpoint_metric (bool) – If True, the checkpoint with the lowest metric value will be selected as best, otherwise the checkpoint with the highest metric value.
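The following sketch illustrates how some of these parameters relate. The helper names are hypothetical, the accumulation rule is the common "optimizer step after every N batches" pattern rather than a statement about medkit's trainer internals, and epochs are numbered from 1 for illustration only.

```python
from typing import Dict, List, Optional


def effective_batch_size(batch_size: int, gradient_accumulation_steps: int) -> int:
    """Number of samples contributing to each optimizer step."""
    return batch_size * gradient_accumulation_steps


def is_optimizer_step(step_idx: int, gradient_accumulation_steps: int) -> bool:
    """Assumed accumulation rule: step after every N-th batch (step_idx is 0-based)."""
    return (step_idx + 1) % gradient_accumulation_steps == 0


def select_best_epoch(
    eval_metrics_per_epoch: List[Dict[str, float]],
    checkpoint_metric: str = "loss",
    minimize_checkpoint_metric: bool = True,
) -> Optional[int]:
    """Return the 1-based epoch with the best tracked eval metric, or None if empty."""
    if not eval_metrics_per_epoch:
        return None
    values = [metrics[checkpoint_metric] for metrics in eval_metrics_per_epoch]
    # Lowest value wins when minimizing (e.g. loss), highest otherwise (e.g. F1)
    best = min(values) if minimize_checkpoint_metric else max(values)
    return values.index(best) + 1
```

For example, batch_size=8 with gradient_accumulation_steps=4 gives 32 samples per optimizer step, and tracking an F1 score as checkpoint_metric would require minimize_checkpoint_metric=False.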

class BatchData

A BatchData packs data, allowing both column and row access

Methods:

to_device(device)

Ensure that Tensors in the BatchData object are on the specified device

to_device(device)

Ensure that Tensors in the BatchData object are on the specified device

Parameters

device (device) – A torch.device object representing the device on which tensors will be allocated.

Return type

BatchData

Returns

BatchData – A new object with the tensors on the proper device.
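To illustrate what column and row access means, here is a minimal pure-Python stand-in. ToyBatchData is hypothetical and is not medkit's implementation; its to_device uses a duck-typed check for a .to() method in place of an isinstance check against torch.Tensor, so the sketch runs without PyTorch.

```python
from typing import Any, Union


class ToyBatchData(dict):
    """Hypothetical stand-in for BatchData: a dict of equally long columns.

    String keys give column access; integer indices give row access.
    """

    def __getitem__(self, key: Union[str, int]) -> Any:
        if isinstance(key, str):
            # Column access: a normal dict lookup
            return super().__getitem__(key)
        # Row access: take the key-th element of every column
        return {name: column[key] for name, column in self.items()}

    def to_device(self, device: Any) -> "ToyBatchData":
        """Return a new batch with .to(device) applied to tensor-like values."""
        moved = ToyBatchData()
        for name, column in self.items():
            # Duck-typed stand-in for isinstance(column, torch.Tensor)
            moved[name] = column.to(device) if hasattr(column, "to") else column
        return moved
```

For example, ToyBatchData(text=["a", "b"], label=[0, 1])["label"] gives the whole column [0, 1], while indexing the same batch with 1 gives the row {"text": "b", "label": 1}.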

clear() → None. Remove all items from D.

copy() → a shallow copy of D

fromkeys(iterable, value=None, /)

Create a new dictionary with keys from iterable and values set to value.

get(key, default=None, /)

Return the value for key if key is in the dictionary, else default.

items() → a set-like object providing a view on D's items

keys() → a set-like object providing a view on D's keys

pop(k[, d]) → v, remove specified key and return the corresponding value.

If key is not found, d is returned if given, otherwise KeyError is raised

popitem()

Remove and return a (key, value) pair as a 2-tuple.

Pairs are returned in LIFO (last-in, first-out) order. Raises KeyError if the dict is empty.

setdefault(key, default=None, /)

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

update([E, ]**F) → None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]
If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v
In either case, this is followed by: for k in F: D[k] = F[k]

values() → an object providing a view on D's values

class MetricsComputer(*args, **kwargs)

A MetricsComputer is the base protocol to compute metrics in training

Methods:

compute(all_data)

Compute metrics using 'all_data'

prepare_batch(model_output, input_batch)

Prepare a batch of data to compute the metrics

prepare_batch(model_output, input_batch)

Prepare a batch of data to compute the metrics

Parameters
  • model_output (BatchData) – Output data after a model forward pass.

  • input_batch (BatchData) – Preprocessed input batch

Return type

Dict[str, List[Any]]

Returns

Dict[str, List[Any]] – A dictionary with the required data to calculate the metric

compute(all_data)

Compute metrics using ‘all_data’

Parameters

all_data (Dict[str, List[Any]]) – A dictionary with the data needed to compute the metrics, e.g. a list of ‘references’ and a list of ‘predictions’.

Return type

Dict[str, float]

Returns

Dict[str, float] – A dictionary with the results
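Because MetricsComputer is described as a protocol, a class that provides prepare_batch and compute with matching signatures should satisfy it without subclassing. The sketch below computes accuracy. The class name and the "predictions" and "labels" keys are assumptions about what the component's forward and collate produce, and the values are assumed to be plain lists (with tensors, one would call .tolist() first).

```python
from typing import Any, Dict, List


class AccuracyComputer:
    """Hypothetical MetricsComputer computing accuracy over all prepared batches."""

    def prepare_batch(self, model_output: Dict[str, Any], input_batch: Dict[str, Any]) -> Dict[str, List[Any]]:
        # Keep only what compute() needs, as lists that can be concatenated
        return {
            "predictions": list(model_output["predictions"]),
            "references": list(input_batch["labels"]),
        }

    def compute(self, all_data: Dict[str, List[Any]]) -> Dict[str, float]:
        predictions = all_data["predictions"]
        references = all_data["references"]
        if not references:
            return {"accuracy": 0.0}
        correct = sum(p == r for p, r in zip(predictions, references))
        return {"accuracy": correct / len(references)}
```

Per-batch results of prepare_batch are expected to be merged into a single dictionary of lists before compute is called, for example {k: b1[k] + b2[k] for k in b1}; how the trainer performs this merge is an assumption here.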

class TrainableComponent(*args, **kwargs)

TrainableComponent is the base protocol a component must implement to be trainable in medkit

Methods:

collate(batch)

Collate a list of data processed by preprocess to form a batch

configure_optimizer(lr)

Create optimizer using the learning rate

forward(input_batch, return_loss, eval_mode)

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

load(path)

Load weights from disk

preprocess(data_item)

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

save(path)

Save model to disk

configure_optimizer(lr)

Create optimizer using the learning rate

Return type

Optimizer

preprocess(data_item)

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

This method is intended to preprocess a single input; self.collate must then be used to generate batches so that self.forward runs properly. The preprocessed output should include the labels needed to compute a loss.

Return type

Dict[str, Any]

collate(batch)

Collate a list of data processed by preprocess to form a batch

Return type

BatchData

forward(input_batch, return_loss, eval_mode)

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

Before running the model, this method must set it to training or evaluation mode depending on eval_mode. In PyTorch, the mode is set with model.train() or model.eval().

Return type

Tuple[BatchData, Optional[Tensor]]

save(path)

Save model to disk

load(path)

Load weights from disk
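As a partial, hypothetical sketch, the preprocess and collate steps of a toy text classifier could look like the following. The data items here are plain dicts with "text" and "label" keys rather than medkit objects, the vocabulary and class name are made up, and the batch holds Python lists where a real component would typically build torch tensors. configure_optimizer, forward, save and load are omitted.

```python
from typing import Any, Dict, List

try:
    from medkit.training import BatchData
except ImportError:
    BatchData = dict  # fallback so the sketch runs without medkit


class ToyTextClassifierComponent:
    """Hypothetical, partial TrainableComponent: only preprocess and collate."""

    def __init__(self, vocab: Dict[str, int], label_to_id: Dict[str, int], pad_id: int = 0, unk_id: int = 1):
        self.vocab = vocab
        self.label_to_id = label_to_id
        self.pad_id = pad_id
        self.unk_id = unk_id

    def preprocess(self, data_item: Dict[str, str]) -> Dict[str, Any]:
        # Map each whitespace-separated token to an id (unknown tokens -> unk_id)
        token_ids = [self.vocab.get(tok, self.unk_id) for tok in data_item["text"].split()]
        # Keep the label so that forward() can compute a loss
        return {"token_ids": token_ids, "label": self.label_to_id[data_item["label"]]}

    def collate(self, batch: List[Dict[str, Any]]) -> BatchData:
        # Pad every sequence to the longest one in the batch
        max_len = max(len(item["token_ids"]) for item in batch)
        padded = [item["token_ids"] + [self.pad_id] * (max_len - len(item["token_ids"])) for item in batch]
        # Column-oriented batch: one list per key
        return BatchData(token_ids=padded, labels=[item["label"] for item in batch])
```

Here preprocess is called once per data item, and collate receives the list of preprocessed items to build one batch, which is what forward then consumes.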

Subpackages / Submodules

medkit.training.callbacks

medkit.training.trainable_component

medkit.training.trainer

medkit.training.trainer_config

medkit.training.utils