medkit.training.utils

Classes:

BatchData

A BatchData packs data and allows both column and row access

MetricsComputer(*args, **kwargs)

A MetricsComputer is the base protocol for computing metrics during training

class BatchData

A BatchData packs data and allows both column and row access

Methods:

to_device(device)

Ensure that Tensors in the BatchData object are on the specified device

to_device(device)

Ensure that Tensors in the BatchData object are on the specified device

Parameters

device (device) – A torch.device object representing the device on which tensors will be allocated.

Return type

BatchData

Returns

BatchData – A new object with the tensors on the specified device.
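To illustrate what "column and row access" can mean in practice, here is a minimal pure-Python sketch. ToyBatch is a hypothetical stand-in written for this example, not medkit's BatchData implementation; it assumes that a string key selects a column and an integer index selects a row, and it leaves out to_device entirely.

```python
from typing import Any, Union


class ToyBatch(dict):
    """Hypothetical stand-in for BatchData: a dict of equally sized columns."""

    def __getitem__(self, index: Union[str, int]) -> Any:
        # A string key selects a whole column, as in a regular dict.
        if isinstance(index, str):
            return super().__getitem__(index)
        # An integer selects one row: the index-th item of every column.
        return {key: values[index] for key, values in self.items()}


batch = ToyBatch({"tokens": ["a", "b", "c"], "labels": [0, 1, 0]})

labels_column = batch["labels"]  # column access
second_row = batch[1]            # row access
```

With this layout, code that works feature by feature (for example, moving one tensor column to a device) and code that works example by example can share the same container.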

class MetricsComputer(*args, **kwargs)

A MetricsComputer is the base protocol for computing metrics during training

Methods:

compute(all_data)

Compute metrics using 'all_data'

prepare_batch(model_output, input_batch)

Prepare a batch of data to compute the metrics

prepare_batch(model_output, input_batch)

Prepare a batch of data to compute the metrics

Parameters
  • model_output (BatchData) – Output data after a model forward pass.

  • input_batch (BatchData) – Preprocessed input batch

Return type

Dict[str, List[Any]]

Returns

Dict[str, List[Any]] – A dictionary with the required data to calculate the metric

compute(all_data)

Compute metrics using ‘all_data’

Parameters

all_data (Dict[str, List[Any]]) – A dictionary holding the data needed to compute the metrics, for example a dictionary with a list of ‘references’ and a list of ‘predictions’.

Return type

Dict[str, float]

Returns

Dict[str, float] – A dictionary with the results
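The two methods are meant to work together: prepare_batch extracts what the metric needs from each batch, and compute turns the collected lists into scores. The sketch below shows one way a class could satisfy this protocol. AccuracyComputer, the evaluate helper, and the "logits" and "labels" keys are all assumptions made for this example, and plain dicts stand in for BatchData so the code runs without medkit or torch. How a training loop actually accumulates the per-batch dictionaries is also assumed here.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


class AccuracyComputer:
    """Hypothetical metrics computer following the two-method protocol."""

    def prepare_batch(
        self, model_output: Dict[str, Any], input_batch: Dict[str, Any]
    ) -> Dict[str, List[Any]]:
        # Assumed keys: "logits" holds one list of class scores per example,
        # "labels" holds the expected class index per example.
        predictions = [
            max(range(len(scores)), key=scores.__getitem__)
            for scores in model_output["logits"]
        ]
        references = list(input_batch["labels"])
        return {"predictions": predictions, "references": references}

    def compute(self, all_data: Dict[str, List[Any]]) -> Dict[str, float]:
        predictions = all_data["predictions"]
        references = all_data["references"]
        if not references:
            return {"accuracy": 0.0}
        correct = sum(p == r for p, r in zip(predictions, references))
        return {"accuracy": correct / len(references)}


def evaluate(
    computer: AccuracyComputer,
    batches: Iterable[Tuple[Dict[str, Any], Dict[str, Any]]],
) -> Dict[str, float]:
    """Assumed accumulation loop: gather per-batch lists, then compute once."""
    all_data: Dict[str, List[Any]] = defaultdict(list)
    for model_output, input_batch in batches:
        prepared = computer.prepare_batch(model_output, input_batch)
        for key, values in prepared.items():
            all_data[key].extend(values)
    return computer.compute(dict(all_data))


# Two toy batches of (model_output, input_batch) pairs.
batches = [
    ({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 1]}),
    ({"logits": [[0.3, 0.7], [0.6, 0.4]]}, {"labels": [1, 0]}),
]
metrics = evaluate(AccuracyComputer(), batches)
```

Keeping prepare_batch cheap and deferring the actual scoring to compute means the metric is calculated once over the whole evaluation set rather than averaged over batches of possibly different sizes.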