Fine-tuning a Transformers model with medkit
Note
This example may require optional medkit modules. Use the following command to install them:
pip install medkit-lib[training,hf-entity-matcher]
In recent years, Large Language Models (LLMs) have achieved very good performance on natural language processing (NLP) tasks. However, training an LLM (with billions of parameters) from scratch requires substantial computing resources and large quantities of text.
Because these models are trained on large general-domain corpora, they learn complex patterns. We can adapt (fine-tune) their last layers to a specific task using our own data and limited resources. Pretrained LLMs are accessible through libraries such as 🤗 Transformers, and medkit provides components to fine-tune them.
Prepare DrBERT for entity recognition
In this example, we show how to use the medkit Trainer to fine-tune DrBERT, a pretrained French model for the biomedical and clinical domains, to detect the following entities: problem, treatment and test.
DrBERT [1] is a French RoBERTa model trained with masked language modeling on an open-source corpus of French medical documents. As mentioned above, we can adapt it to a different task, such as entity classification.
Using a custom medkit dataset
Let’s start by defining a dataset of medkit documents. For this example, we use CorpusCASM2, an internal corpus of clinical cases annotated by master’s students. The corpus contains more than 5000 medkit documents (roughly one sentence each) with entities to detect. The splits are predefined, so all we need to do is load the documents from the file of the desired split (train or validation).
Tip
You can test this tutorial with your own data: create medkit documents, add entities to them and export them to JSONL files.
from medkit.core.text import TextDocument, Entity, Span
from medkit.io.medkit_json import save_text_documents
document = TextDocument(
    "Your custom phrase with entities",
    # Span offsets are character positions: text[24:32] == "entities"
    anns=[Entity(label="CUSTOM", spans=[Span(24, 32)], text="entities")],
)
# save your list of documents
train_docs = [document]
save_text_documents(train_docs, output_file="./train.jsonl")
You may refer to medkit_json for more information.
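In the example above, Span(24, 32) uses character offsets with an exclusive end, so text[24:32] is "entities". Counting characters by hand is error-prone, so here is a minimal sketch that computes the offsets from the entity text. find_span is a hypothetical helper, not part of medkit; its result could then be passed as Span(start, end).

```python
def find_span(text: str, fragment: str, start: int = 0) -> tuple:
    """Return (start, end) character offsets of `fragment` in `text`.

    `end` is exclusive, as in Python slicing, so text[start:end] == fragment.
    Raises ValueError if the fragment does not occur at or after `start`.
    """
    begin = text.find(fragment, start)
    if begin == -1:
        raise ValueError(f"{fragment!r} not found in {text!r}")
    return begin, begin + len(fragment)


text = "Your custom phrase with entities"
start, end = find_span(text, "entities")
print(start, end)       # 24 32
print(text[start:end])  # entities
```

str.find returns the first occurrence only; pass a start position to locate later repetitions of the same word.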
from torch.utils.data import Dataset
from medkit.io.medkit_json import load_text_documents

class CorpusCASM2(Dataset):
    """A dataset of clinical cases from CORPUS CAS (medkit version)"""

    def __init__(self, split):
        print(f"Creating CorpusCASM2 corpus {split}")
        self.labels_set = ["treatment", "problem", "test"]
        # each split is stored as a JSONL file of medkit documents
        data_path = f"{split}.jsonl"
        self.documents = list(load_text_documents(data_path))

    def __getitem__(self, idx):
        return self.documents[idx]

    def __len__(self):
        return len(self.documents)
To see what a document looks like, let’s print the first example from the test split.
doc = CorpusCASM2(split="test")[0]
msg = "|".join(f"'{entity.label}':{entity.text}" for entity in doc.anns.entities)
print(f"Text: '{doc.text}'\n{msg}")
Text: 'Une tachycardie et une fibrillation ventriculaire ont été observées.'
'problem':tachycardie|'problem':fibrillation ventriculaire
We can now define the datasets to use with the trainer.
train_dataset = CorpusCASM2(split="train")
val_dataset = CorpusCASM2(split="validation")
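Before training, it can be useful to check how entity labels are distributed across a split. The sketch below relies only on the access pattern used above (doc.anns.entities and entity.label). count_entity_labels and fake_doc are hypothetical helpers, and the SimpleNamespace objects merely stand in for medkit documents so the example runs without the corpus.

```python
from collections import Counter
from types import SimpleNamespace


def count_entity_labels(documents) -> Counter:
    """Count entity labels over an iterable of documents.

    Each document is assumed to expose `doc.anns.entities`, an iterable of
    objects with a `label` attribute.
    """
    counts = Counter()
    for doc in documents:
        counts.update(entity.label for entity in doc.anns.entities)
    return counts


def fake_doc(*labels):
    """Hypothetical stand-in for a medkit document, for illustration only."""
    entities = [SimpleNamespace(label=label) for label in labels]
    return SimpleNamespace(anns=SimpleNamespace(entities=entities))


docs = [fake_doc("problem", "problem"), fake_doc("treatment"), fake_doc()]
print(count_entity_labels(docs))  # Counter({'problem': 2, 'treatment': 1})
```

With the real data, count_entity_labels(train_dataset) should also work: iterating over an object that only defines __getitem__ falls back to indexing from 0 until IndexError is raised.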
Creating a trainable entity matcher
Once documents have been collected, we need a component that implements the TrainableComponent protocol.
Medkit supports entity detection with Hugging Face models in both inference and fine-tuning contexts. HFEntityMatcher exposes its trainable version as a ready-to-use component that defines the preprocessing, the forward pass and the optimizer.
See also
More information about this component is available in HFEntityMatcherTrainable.
Let’s define a trainable instance for this example.
from medkit.text.ner.hf_entity_matcher import HFEntityMatcher
hf_config = dict(
model_name_or_path="Dr-BERT/DrBERT-4GB-CP-PubMedBERT", # name in HF hub
labels=["problem", "treatment", "test"], # labels to fine-tune
tokenizer_max_length=128, # max length per item
tagging_scheme="iob2", # scheme to tag documents
tag_subtokens=False,                   # only tag the first sub-token of each word
)
hf_trainable = HFEntityMatcher.make_trainable(**hf_config)
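The tagging_scheme="iob2" and tag_subtokens=False options control how entity spans become per-token training labels. The following sketch illustrates the idea on whitespace-separated words: the first word of an entity gets B-<label>, the following words get I-<label>, and all other words get O. It then expands word tags to sub-tokens, keeping the tag on the first sub-token only and marking the others with -100, a value commonly used with PyTorch and 🤗 Transformers so that the loss ignores those positions. This is an illustration under these assumptions, not medkit’s implementation; the sub-token counts are invented, and DrBERT’s tokenizer decides the real splits.

```python
def word_offsets(text):
    """Split on whitespace and return (start, end) offsets for each word."""
    offsets, start = [], None
    for i, char in enumerate(text + " "):  # trailing space closes the last word
        if char.isspace():
            if start is not None:
                offsets.append((start, i))
                start = None
        elif start is None:
            start = i
    return offsets


def iob2_tags(words, entities):
    """Tag words with IOB2 labels.

    words: list of (start, end) character offsets, one per word.
    entities: list of (start, end, label) character spans.
    A word belongs to an entity when it lies entirely inside the entity span.
    """
    tags = ["O"] * len(words)
    for ent_start, ent_end, label in entities:
        inside = [
            i for i, (w_start, w_end) in enumerate(words)
            if w_start >= ent_start and w_end <= ent_end
        ]
        for position, i in enumerate(inside):
            tags[i] = ("B-" if position == 0 else "I-") + label
    return tags


def first_subtoken_labels(word_tags, subtokens_per_word, ignore=-100):
    """Expand word tags to sub-tokens, keeping the tag on the first one only."""
    labels = []
    for tag, n_subtokens in zip(word_tags, subtokens_per_word):
        labels.append(tag)
        labels.extend([ignore] * (n_subtokens - 1))
    return labels


text = "Une tachycardie et une fibrillation ventriculaire ont été observées."
entities = []
for fragment in ("tachycardie", "fibrillation ventriculaire"):
    start = text.find(fragment)
    entities.append((start, start + len(fragment), "problem"))

tags = iob2_tags(word_offsets(text), entities)
print(tags)
# ['O', 'B-problem', 'O', 'O', 'B-problem', 'I-problem', 'O', 'O', 'O']

# Hypothetical number of sub-tokens per word
print(first_subtoken_labels(tags, [1, 3, 1, 1, 2, 2, 1, 1, 2]))
```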
Fine-tuning with the medkit Trainer
At this point, we have prepared the data and the component to fine-tune. All we need to do is define the trainer with its configuration.
from medkit.training import Trainer, TrainerConfig
trainer_config = TrainerConfig(
output_dir="DrBert-CASM2", # output directory
batch_size=4,
do_metrics_in_training=False,
learning_rate=5e-6,
nb_training_epochs=5,
seed=0,
)
trainer = Trainer(
component=hf_trainable, # trainable component
config=trainer_config, # configuration
train_data=train_dataset, # training documents
eval_data=val_dataset, # eval documents
)
history = trainer.train()
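With a batch size of 4, the number of optimizer steps grows quickly with the size of the training split, so a rough estimate helps anticipate training time. This minimal sketch assumes one optimizer step per batch (no gradient accumulation) and that the last, possibly smaller, batch of each epoch is kept; training_steps is a hypothetical helper and the split size below is made up.

```python
import math


def training_steps(nb_documents: int, batch_size: int, nb_epochs: int) -> int:
    """Estimate optimizer steps: one per batch, last partial batch included."""
    steps_per_epoch = math.ceil(nb_documents / batch_size)
    return steps_per_epoch * nb_epochs


# Hypothetical split size; use len(train_dataset) for the real value.
print(training_steps(nb_documents=4000, batch_size=4, nb_epochs=5))  # 5000
```

With the objects defined above, training_steps(len(train_dataset), trainer_config.batch_size, trainer_config.nb_training_epochs) gives the estimate for this run.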
Training history
The trainer has a callback that displays basic training information such as the loss, the elapsed time and, if configured, the metrics. The trainer.train() method returns a dictionary with the training history and saves a checkpoint with the tuned model.
Here is an example of the log output:
2023-05-03 21:13:07,304 - DefaultPrinterCallback - INFO - Training metrics : loss: 0.219
2023-05-03 21:13:07,305 - DefaultPrinterCallback - INFO - Evaluation metrics : loss: 0.20|
2023-05-03 21:13:07,306 - DefaultPrinterCallback - INFO - Epoch state: |epoch_id: 5 | time: 2348.17s
2023-05-03 21:13:07,307 - DefaultPrinterCallback - INFO - Saving checkpoint in DrBert-CASM2/checkpoint_03-05-2023_21:13
Adding metrics during training
By default, only the loss defined by the trainable component is computed during the training and evaluation loops. We can add more metrics using a class that implements MetricsComputer. For entity detection, we can instantiate SeqEvalMetricsComputer directly. This object handles both the processing and the metric computation during training (using PyTorch tensors).
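SeqEvalMetricsComputer reports entity-level scores such as macro_precision or problem_f1-score. The sketch below shows how such scores are commonly computed from IOB2 tag sequences: tags are grouped into entity chunks, and a predicted entity counts as correct only if its label, start and end all match a gold entity. It assumes a lenient rule where a stray I-<label> starts a new entity, and it averages per-label F1 scores to obtain the macro F1. This is a plain-Python illustration, not the code of SeqEvalMetricsComputer or of the seqeval library, whose details (strict mode, averaging) may differ; the metric names only mirror those in the logs.

```python
from collections import defaultdict


def extract_entities(tags):
    """Group an IOB2 tag sequence into a set of (label, start, end) chunks.

    `end` is exclusive. An I-<label> that does not continue an entity with the
    same label starts a new entity (a lenient rule).
    """
    entities, label, start = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes the last chunk
        prefix, _, tag_label = tag.partition("-")
        continues = prefix == "I" and tag_label == label
        if label is not None and not continues:
            entities.add((label, start, i))
            label, start = None, None
        if prefix in ("B", "I") and not continues:
            label, start = tag_label, i
    return entities


def entity_metrics(gold_sequences, pred_sequences):
    """Per-label precision/recall/F1 on exact entity matches, plus macro averages."""
    tp, n_pred, n_gold = defaultdict(int), defaultdict(int), defaultdict(int)
    for gold, pred in zip(gold_sequences, pred_sequences):
        gold_ents, pred_ents = extract_entities(gold), extract_entities(pred)
        for label, _, _ in gold_ents:
            n_gold[label] += 1
        for label, _, _ in pred_ents:
            n_pred[label] += 1
        for label, _, _ in gold_ents & pred_ents:  # same label, start and end
            tp[label] += 1

    labels = sorted(set(n_gold) | set(n_pred))
    metrics = {}
    for label in labels:
        precision = tp[label] / n_pred[label] if n_pred[label] else 0.0
        recall = tp[label] / n_gold[label] if n_gold[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[f"{label}_precision"] = precision
        metrics[f"{label}_recall"] = recall
        metrics[f"{label}_f1-score"] = f1
        metrics[f"{label}_support"] = n_gold[label]
    for name in ("precision", "recall", "f1-score"):
        values = [metrics[f"{label}_{name}"] for label in labels]
        metrics[f"macro_{name}"] = sum(values) / len(values) if values else 0.0
    return metrics


gold = [
    ["O", "B-problem", "O", "O", "B-problem", "I-problem", "O"],
    ["B-treatment", "O"],
]
pred = [
    ["O", "B-problem", "O", "O", "B-problem", "O", "O"],  # second entity cut short
    ["B-treatment", "O"],
]
metrics = entity_metrics(gold, pred)
print(metrics["problem_f1-score"], metrics["treatment_f1-score"], metrics["macro_f1-score"])
# 0.5 1.0 0.75
```

Exact matching is stricter than token accuracy: a prediction that covers only part of a gold entity (positions 4 to 5 instead of 4 to 6 in the example) counts as both a false positive and a false negative.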
from medkit.text.metrics.ner import SeqEvalMetricsComputer
mc_seqeval = SeqEvalMetricsComputer(
id_to_label=hf_trainable.id_to_label, # mapping int value to tag
tagging_scheme=hf_trainable.tagging_scheme, # tagging scheme to compute
return_metrics_by_label=True,  # include metrics by label in results
)
Warning
The Trainer updates the trainable component (i.e. the model’s weights) during training. If you want to run a new experiment, you need to create a new instance of the trainable component.
Running with metrics
Note
By default, the Trainer computes custom metrics only on the evaluation data. Set do_metrics_in_training=True in the trainer configuration to also compute them on the training data.
trainer_with_metrics = Trainer(
component=HFEntityMatcher.make_trainable(**hf_config), # a new instance
config=trainer_config, # configuration
train_data=train_dataset, # training documents
eval_data=val_dataset, # eval documents
metrics_computer=mc_seqeval,  # metrics computer
)
history_with_metrics = trainer_with_metrics.train()
The custom metrics are included in history_with_metrics, and the logs look like this:
2023-05-04 20:33:59,128 - DefaultPrinterCallback - INFO - Training metrics : loss: 0.227
2023-05-04 20:33:59,129 - DefaultPrinterCallback - INFO - Evaluation metrics : loss: 0.286|macro_precision: 0.626|macro_recall: 0.722|macro_f1-score: 0.670|support:3542.000|accuracy: 0.899|problem_precision: 0.609|problem_recall: 0.690|problem_f1-score: 0.647|problem_support:1812.000|test_precision: 0.667|test_recall: 0.780|test_f1-score: 0.719|test_support: 937.000|treatment_precision: 0.614|treatment_recall: 0.728|treatment_f1-score: 0.666|treatment_support: 793.000
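The evaluation line packs many values together. history_with_metrics already contains these metrics in Python, but when only the console output is available, a small parser can turn such a line into a dictionary for comparing labels. parse_metrics is a hypothetical helper that assumes the "name: value" pairs separated by "|" seen in the log excerpts; the sample line is abbreviated from the output above.

```python
def parse_metrics(log_line: str) -> dict:
    """Turn '... metrics : name: value|name: value' into {name: float(value)}."""
    _, _, payload = log_line.partition("metrics : ")
    metrics = {}
    for item in payload.split("|"):
        name, sep, value = item.partition(":")
        if sep and value.strip():  # skip empty items, e.g. from a trailing '|'
            metrics[name.strip()] = float(value)
    return metrics


line = (
    "2023-05-04 20:33:59,129 - DefaultPrinterCallback - INFO - Evaluation metrics : "
    "loss: 0.286|macro_f1-score: 0.670|problem_f1-score: 0.647|"
    "test_f1-score: 0.719|treatment_f1-score: 0.666"
)
metrics = parse_metrics(line)

suffix = "_f1-score"
f1_by_label = {
    name[: -len(suffix)]: value
    for name, value in metrics.items()
    if name.endswith(suffix) and not name.startswith("macro_")
}
best_label = max(f1_by_label, key=f1_by_label.get)
print(f1_by_label)  # {'problem': 0.647, 'test': 0.719, 'treatment': 0.666}
print(best_label)   # test
```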
Detecting entities at inference time
We now have an entity matcher fine-tuned on our custom dataset. We can use the last checkpoint to define an HFEntityMatcher and detect problem, treatment and test entities in French documents.
Hint
In this version, the trainer saves a single checkpoint at the end of training. Its path is {trainer_config.output_dir}/checkpoint_{DATETIME_END_TRAINING}.
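If several trainings write to the same output directory (output_dir in TrainerConfig), it may contain more than one checkpoint. Because the timestamp in the example path checkpoint_03-05-2023_21:13 puts the day first, sorting directory names alphabetically does not give chronological order. The sketch below parses the timestamp instead. The naming pattern is inferred from that example path and latest_checkpoint is a hypothetical helper; adjust CHECKPOINT_FORMAT if your checkpoint names differ.

```python
import tempfile
from datetime import datetime
from pathlib import Path

CHECKPOINT_PREFIX = "checkpoint_"
# Assumed timestamp layout, taken from names like checkpoint_03-05-2023_21:13
CHECKPOINT_FORMAT = "%d-%m-%Y_%H:%M"


def latest_checkpoint(output_dir) -> Path:
    """Return the checkpoint directory in `output_dir` with the latest timestamp."""
    candidates = []
    for path in Path(output_dir).iterdir():
        if not (path.is_dir() and path.name.startswith(CHECKPOINT_PREFIX)):
            continue
        stamp = path.name[len(CHECKPOINT_PREFIX):]
        try:
            candidates.append((datetime.strptime(stamp, CHECKPOINT_FORMAT), path))
        except ValueError:
            continue  # skip directories that do not follow the pattern
    if not candidates:
        raise FileNotFoundError(f"no checkpoint found in {output_dir}")
    return max(candidates, key=lambda item: item[0])[1]


# Demonstration with a temporary directory and two empty fake checkpoints
with tempfile.TemporaryDirectory() as tmp:
    names = ["checkpoint_05-12-2023_10:00", "checkpoint_04-01-2024_09:30"]
    for name in names:
        (Path(tmp) / name).mkdir()
    latest_name = latest_checkpoint(tmp).name

print(sorted(names)[-1])  # checkpoint_05-12-2023_10:00 (alphabetical, not the latest)
print(latest_name)        # checkpoint_04-01-2024_09:30
```

The returned path could then be passed to the matcher, for example HFEntityMatcher(model=str(latest_checkpoint("DrBert-CASM2"))).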
from medkit.core.text import TextDocument
from medkit.text.ner.hf_entity_matcher import HFEntityMatcher
matcher = HFEntityMatcher(model="./DrBert-CASM2/checkpoint_03-05-2023_21:13")
test_doc = TextDocument("Elle souffre d'asthme mais n'a pas besoin d'Allegra")
# detect entities in the raw segment
detected_entities = matcher.run([test_doc.raw_segment])
msg = "|".join(f"'{entity.label}':{entity.text}" for entity in detected_entities)
print(f"Text: '{test_doc.text}'\n{msg}")
Text: 'Elle souffre d'asthme mais n'a pas besoin d'Allegra'
'problem':asthme|'treatment':Allegra
References
[1] Yanis Labrak, Adrien Bazoge, Richard Dufour, Mickael Rouvier, Emmanuel Morin, Béatrice Daille, and Pierre-Antoine Gourraud. (2023). DrBERT: A Robust Pre-trained Model in French for Biomedical and Clinical domains.