# Training module
This page describes all components related to medkit training.
Important

To use this module, you need to install PyTorch. You may install the additional dependencies with `pip install medkit-lib[training]`.
Note

For more details about all sub-packages, refer to medkit.training.
## Be trainable in medkit
A component can implement the TrainableComponent protocol to be trainable in medkit. With this protocol, you define how to preprocess data, how to call the model, and how to create the optimizer. The Trainer then uses these methods inside the training/evaluation loop.
The following table shows who makes each call and at which point:

| Who | Where | TrainableComponent method |
|---|---|---|
| TrainableComponent | Initialization | `load`: load/initialize the modules to be trained |
| Trainer | Initialization | `create_optimizer`: define an optimizer for the training/evaluation loop |
| Trainer | Data loading | `preprocess`: transform medkit annotations into input data |
| Trainer | Forward step | `forward`: call the internal model, return the loss and model output |
| Trainer | Saving checkpoint | `save`: save the trained modules |
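The call order above can be sketched in plain Python (toy stand-ins, not the real medkit or PyTorch classes; all names and data shapes here are illustrative):

```python
class ToyTrainableComponent:
    """Toy component mimicking the TrainableComponent call order:
    load -> create_optimizer -> preprocess -> forward -> save."""

    def load(self):
        # Initialization: load/initialize the module to be trained
        # (here, a single scalar weight stands in for a model)
        self.weight = 0.0

    def create_optimizer(self, lr):
        # Initialization (called by the Trainer): define an "optimizer";
        # here just a dict holding the learning rate
        return {"lr": lr}

    def preprocess(self, ann):
        # Data loading: turn a medkit-like annotation into model input
        return {"x": ann["value"], "y": ann["label"]}

    def forward(self, batch):
        # Forward step: run the model, return model output and loss
        pred = self.weight * batch["x"]
        loss = (pred - batch["y"]) ** 2
        return pred, loss

    def save(self, path):
        # Saving checkpoint: persist the trained module
        return {path: self.weight}


component = ToyTrainableComponent()
component.load()
optimizer = component.create_optimizer(lr=0.1)
sample = component.preprocess({"value": 2.0, "label": 4.0})
pred, loss = component.forward(sample)
print(loss)  # (0.0 - 4.0) ** 2 = 16.0
```

In real code these methods operate on medkit annotations and PyTorch tensors; the Trainer drives the same sequence of calls.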
## A trainable component to detect entities
A trainable component can define how to train a model from scratch or how to fine-tune a pretrained model. As a first implementation, medkit includes HFEntityMatcherTrainable, a trainable version of HFEntityMatcher. As you can see, an operation can contain a trainable component and expose it through the make_trainable() method.

You may refer to this tutorial for a fine-tuning use case for entity detection.
Important

Currently, medkit only supports training components built on PyTorch.

Note

For more details, refer to the medkit.training.trainable_component module.
## Trainer
The Trainer aims to train any component implementing the TrainableComponent protocol. For each step involving data transformation, the Trainer calls the corresponding method of the TrainableComponent.

For example, to train a SegmentClassifier, you can define how to preprocess a Segment with its Attribute into a dictionary of tensors for the model. Under the hood, the training loop calls SegmentClassifier.preprocess() and SegmentClassifier.collate() inside the training_dataloader to transform the medkit segments into a batch of tensors.
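The preprocess + collate pipeline described above can be sketched with plain Python lists instead of tensors (the segment dicts and function bodies below are illustrative, not the real medkit API):

```python
def preprocess(segment):
    # turn one segment with its label attribute into model-ready features
    return {"tokens": segment["text"].split(), "label": segment["label"]}

def collate(samples):
    # group per-sample dicts into a single batch dict
    return {
        "tokens": [s["tokens"] for s in samples],
        "labels": [s["label"] for s in samples],
    }

segments = [
    {"text": "patient has fever", "label": 1},
    {"text": "no complaints", "label": 0},
]
batch = collate([preprocess(seg) for seg in segments])
print(batch["labels"])  # [1, 0]
```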
```python
# 1. Initialize the trainable component, i.e. a segment classifier
segment_classifier = SegmentClassifier(...)

# 2. Load/prepare the set of medkit annotations (segments)

# 3. Define hyperparameters for the trainer
trainer_config = TrainerConfig(...)
trainer = Trainer(
    component=segment_classifier,  # trainable component
    config=trainer_config,         # configuration
    train_data=train_dataset,      # training documents
    eval_data=val_dataset,         # eval documents
)
```
### History
Once the trainer has been configured, you can start the training with trainer.train(). The method returns a dictionary with the training and evaluation metrics per epoch.

```python
history = trainer.train()
# ...
# log main information, metrics, and info about the checkpoint
```
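The exact keys in the returned dictionary depend on the metrics you configured; assuming a structure like the one below (an illustration, not a guaranteed format), the history could be used, for instance, to pick the best epoch:

```python
# assumed structure: phase -> list of per-epoch metric dicts
history = {
    "eval": [{"loss": 0.9}, {"loss": 0.4}, {"loss": 0.5}],
}

# index of the epoch with the lowest evaluation loss
best_epoch = min(
    range(len(history["eval"])),
    key=lambda i: history["eval"][i]["loss"],
)
print(best_epoch)  # 1 (zero-based)
```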
The trainer controls the calling of these methods and optional modules; here is a simplified version of the training loop:

```python
for input_data in training_dataloader:
    callback_on_step()
    input_data = input_data.to_device(device)
    output_data, loss = trainable_component.forward(input_data)
    loss.backward()
    optimizer.step()
    # if a metrics_computer is defined
    data_for_metrics.extend(metrics_computer.prepare_batch(input_data, output_data))
...
# compute metrics
metrics_computer.compute(data_for_metrics)
```
Note

For more details, refer to the medkit.training.trainer module.
## Custom training
### Hyperparameters
The TrainerConfig allows you to define learning parameters such as the learning rate, the number of epochs, etc.
### Metrics Computer
You can add custom metrics during training. You define how to prepare a batch for the metric and how to compute it. For more details, refer to the medkit.training.MetricsComputer protocol.
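As a sketch, a custom metrics computer following the protocol's prepare_batch/compute shape could look like this (toy accuracy over plain Python lists; the real protocol works on medkit batch data):

```python
class AccuracyComputer:
    """Toy metrics computer: prepare_batch pairs gold labels with
    predictions, compute aggregates them into an accuracy score."""

    def prepare_batch(self, input_batch, output_batch):
        # pair gold labels with predictions for later aggregation
        return list(zip(input_batch["labels"], output_batch["predictions"]))

    def compute(self, all_data):
        # fraction of (gold, prediction) pairs that agree
        correct = sum(1 for gold, pred in all_data if gold == pred)
        return {"accuracy": correct / len(all_data)}


computer = AccuracyComputer()
data = []
data.extend(computer.prepare_batch({"labels": [1, 0]}, {"predictions": [1, 1]}))
data.extend(computer.prepare_batch({"labels": [1]}, {"predictions": [1]}))
print(computer.compute(data))  # 2 of 3 predictions are correct
```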
Tip

For the moment, medkit includes SeqEvalMetricsComputer for entity detection. This is still in development; you can integrate more metrics depending on your task/modality.
### Learning rate scheduler
You can define how to adjust the learning rate. If you use PyTorch models, you can use a scheduler from torch.optim.lr_scheduler. For example, you can update the learning rate every 5 optimization steps:
```python
import torch

trainer = Trainer(
    ...,
    lr_scheduler_builder=lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=5),
)
```
If you use transformers models, you may refer to the get_scheduler method.
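To see what StepLR with step_size=5 does, its schedule can be reproduced with plain arithmetic (gamma defaults to 0.1 in torch.optim.lr_scheduler.StepLR; the starting learning rate of 0.01 below is just an example):

```python
initial_lr, gamma, step_size = 0.01, 0.1, 5

def lr_at(step):
    # the learning rate is multiplied by gamma once every step_size steps
    return initial_lr * gamma ** (step // step_size)

print([lr_at(s) for s in (0, 4, 5, 10)])  # lr drops at steps 5 and 10
```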
### Callbacks
medkit provides a set of callbacks that can be used, for example, to log information during training. To use these callbacks, you need to implement a class derived from TrainerCallback. If you do not provide your own to the Trainer, it will use the DefaultPrinterCallback.
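A minimal logging callback could be sketched as follows (the base class and hook name below are stand-ins; check medkit.training.callbacks for the actual TrainerCallback interface):

```python
class TrainerCallback:
    # stand-in base class; the real one lives in medkit.training.callbacks
    def on_epoch_end(self, epoch, metrics):
        pass


class PrintLossCallback(TrainerCallback):
    """Logs the loss at the end of each epoch."""

    def __init__(self):
        self.log = []

    def on_epoch_end(self, epoch, metrics):
        # format, store, and print one log line per epoch
        line = f"epoch {epoch}: loss={metrics['loss']:.3f}"
        self.log.append(line)
        print(line)


callback = PrintLossCallback()
callback.on_epoch_end(0, {"loss": 0.5})  # prints "epoch 0: loss=0.500"
```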
Note

For more details, refer to the medkit.training.callbacks module.
Note

This module is under development; future versions of medkit may support more powerful callbacks.