medkit.training.trainable_component

Classes:

TrainableComponent(*args, **kwargs)

Base protocol that a component must implement to be trainable in medkit.

class TrainableComponent(*args, **kwargs)

Base protocol that a component must implement to be trainable in medkit.
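As a rough illustration of what "protocol" means here, the documented methods can be mirrored with typing.Protocol. This is a hypothetical stand-in written for this page, not medkit's source: TrainableComponentSketch, DummyComponent and IncompleteComponent are illustrative names, BatchData is approximated by a plain dict, and the optimizer and tensor types are replaced by Any so the sketch runs without PyTorch.

```python
from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable

# Hypothetical stand-in: medkit defines its own BatchData type.
BatchData = Dict[str, Any]


@runtime_checkable
class TrainableComponentSketch(Protocol):
    """Illustrative mirror of the documented TrainableComponent methods."""

    def configure_optimizer(self, lr: float) -> Any: ...
    def preprocess(self, data_item: Any) -> Dict[str, Any]: ...
    def collate(self, batch: List[Dict[str, Any]]) -> BatchData: ...
    def forward(
        self, input_batch: BatchData, return_loss: bool, eval_mode: bool
    ) -> Tuple[BatchData, Optional[Any]]: ...
    def save(self, path: str) -> None: ...
    def load(self, path: str) -> None: ...


class DummyComponent:
    """Defines every method of the protocol (bodies are placeholders)."""

    def configure_optimizer(self, lr: float) -> Any:
        return None

    def preprocess(self, data_item: Any) -> Dict[str, Any]:
        return {"item": data_item}

    def collate(self, batch: List[Dict[str, Any]]) -> BatchData:
        return {"items": [d["item"] for d in batch]}

    def forward(self, input_batch: BatchData, return_loss: bool, eval_mode: bool):
        return input_batch, None

    def save(self, path: str) -> None:
        pass

    def load(self, path: str) -> None:
        pass


class IncompleteComponent:
    """Only defines preprocess, so it does not satisfy the protocol."""

    def preprocess(self, data_item: Any) -> Dict[str, Any]:
        return {"item": data_item}
```

Note that a runtime_checkable isinstance check only verifies that the methods exist; it does not check their signatures or return types.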

Methods:

collate(batch)

Collate a list of data items processed by preprocess into a batch.

configure_optimizer(lr)

Create an optimizer using the given learning rate.

forward(input_batch, return_loss, eval_mode)

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

load(path)

Load model weights from disk.

preprocess(data_item)

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

save(path)

Save the model to disk.

configure_optimizer(lr)

Create an optimizer using the given learning rate.

Return type

Optimizer

preprocess(data_item)

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

This method preprocesses a single input item; self.collate must then be used to group preprocessed items into batches so that self.forward runs properly. The returned dictionary should include the labels needed to compute a loss.

Return type

Dict[str, Any]
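The preprocess → collate → forward chain described above can be sketched with a toy component. MeanRegressor is hypothetical and uses plain Python floats instead of tensors; it only illustrates that preprocess works on one item and keeps the label, collate groups items into a batch, and forward consumes the batch and uses the labels to compute a loss.

```python
from typing import Any, Dict, List, Optional, Tuple


class MeanRegressor:
    """Hypothetical toy component that always predicts a single bias value."""

    def __init__(self, bias: float = 0.0) -> None:
        self.bias = bias

    def preprocess(self, data_item: Tuple[float, float]) -> Dict[str, Any]:
        # One raw item -> one dict; the label "y" is kept so forward can compute a loss.
        x, y = data_item
        return {"x": x, "y": y}

    def collate(self, batch: List[Dict[str, Any]]) -> Dict[str, List[float]]:
        # Group the per-item values into per-key lists.
        return {key: [item[key] for item in batch] for key in ("x", "y")}

    def forward(
        self, input_batch: Dict[str, List[float]], return_loss: bool, eval_mode: bool
    ) -> Tuple[Dict[str, List[float]], Optional[float]]:
        # This toy has no train/eval distinction, so eval_mode is not used here.
        preds = [self.bias for _ in input_batch["x"]]
        loss = None
        if return_loss:
            # Mean squared error against the labels kept by preprocess.
            errors = [(p - y) ** 2 for p, y in zip(preds, input_batch["y"])]
            loss = sum(errors) / len(errors)
        return {"preds": preds}, loss


component = MeanRegressor(bias=1.0)
raw_items = [(0.0, 1.0), (1.0, 3.0)]
batch = component.collate([component.preprocess(item) for item in raw_items])
output, loss = component.forward(batch, return_loss=True, eval_mode=True)
```

Here batch is {"x": [0.0, 1.0], "y": [1.0, 3.0]} and the loss is the mean of 0.0 and 4.0, i.e. 2.0.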

collate(batch)

Collate a list of data items processed by preprocess into a batch.

Return type

BatchData
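For variable-length inputs such as token sequences, collate typically pads items to a common length. The sketch below assumes preprocess produced hypothetical "input_ids" and "label" keys and that 0 is the padding id; a plain dict stands in for BatchData, and lists stand in for tensors.

```python
from typing import Any, Dict, List

PAD_ID = 0  # hypothetical padding token id


def collate(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Pad 'input_ids' to the longest item and build a matching attention mask."""
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, attention_mask, labels = [], [], []
    for item in batch:
        ids = item["input_ids"]
        n_pad = max_len - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        labels.append(item["label"])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = collate([
    {"input_ids": [5, 8, 2], "label": 1},
    {"input_ids": [7], "label": 0},
])
```

Padding at collate time rather than in preprocess keeps each item independent of the others, since the target length depends on which items end up in the same batch.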

forward(input_batch, return_loss, eval_mode)

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

Before the forward pass, this method must set the model to training or evaluation mode depending on eval_mode. In PyTorch, the mode is set with model.train() and model.eval().

Return type

Tuple[BatchData, Optional[Tensor]]
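The mode-switching requirement can be sketched without PyTorch by using a hypothetical ToyModel that mimics the train()/eval() methods of a PyTorch module (both return the model itself and toggle a training flag). ToyComponent, the "x"/"y" keys and the mean-absolute-error loss are illustrative assumptions; a real component would call its own module and loss function.

```python
from typing import Dict, List, Optional, Tuple


class ToyModel:
    """Hypothetical stand-in for a PyTorch module: tracks its mode in `training`."""

    def __init__(self) -> None:
        self.training = True

    def train(self) -> "ToyModel":
        self.training = True
        return self

    def eval(self) -> "ToyModel":
        self.training = False
        return self

    def __call__(self, xs: List[float]) -> List[float]:
        return [2.0 * x for x in xs]


class ToyComponent:
    def __init__(self) -> None:
        self._model = ToyModel()

    def forward(
        self, input_batch: Dict[str, List[float]], return_loss: bool, eval_mode: bool
    ) -> Tuple[Dict[str, List[float]], Optional[float]]:
        # Set the mode before running the model, as the contract requires.
        if eval_mode:
            self._model.eval()
        else:
            self._model.train()
        preds = self._model(input_batch["x"])
        loss = None
        if return_loss:
            # Mean absolute error against the labels in the batch.
            diffs = [abs(p - y) for p, y in zip(preds, input_batch["y"])]
            loss = sum(diffs) / len(diffs)
        return {"preds": preds}, loss


component = ToyComponent()
batch = {"x": [1.0, 2.0], "y": [2.0, 5.0]}
output, loss = component.forward(batch, return_loss=True, eval_mode=True)
```

With eval_mode=True the model ends up in evaluation mode, the predictions are [2.0, 4.0] and the loss is 0.5; calling forward with return_loss=False returns None as the second element.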

save(path)

Save the model to disk.

load(path)

Load model weights from disk.
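A save/load pair only needs to round-trip the model state. The sketch below is a hypothetical illustration, not medkit's storage format: it treats path as a directory and writes the "weights" (a dict of float lists) to an assumed weights.json file with the standard json module. A PyTorch-based component would more likely serialize its module's state dict.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List, Union


class ToyComponent:
    """Hypothetical component whose weights are a dict of float lists."""

    WEIGHTS_FILE = "weights.json"  # hypothetical file name inside `path`

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {"linear": [0.0, 0.0]}

    def save(self, path: Union[str, Path]) -> None:
        directory = Path(path)
        directory.mkdir(parents=True, exist_ok=True)
        (directory / self.WEIGHTS_FILE).write_text(json.dumps(self.weights))

    def load(self, path: Union[str, Path]) -> None:
        self.weights = json.loads((Path(path) / self.WEIGHTS_FILE).read_text())


# Round trip: save trained weights, then restore them into a fresh component.
with tempfile.TemporaryDirectory() as tmp:
    trained = ToyComponent()
    trained.weights = {"linear": [0.5, -1.25]}
    trained.save(tmp)
    restored = ToyComponent()
    restored.load(tmp)
```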