Train the Model on the Device

Once the training artifacts are generated, the model can be trained on the device using the onnxruntime training Python API.

The expected training artifacts are:

  1. The training onnx model

  2. The checkpoint state

  3. The optimizer onnx model

  4. The eval onnx model (optional)

Sample usage:

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

# Create the optimizer
optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)

Bases: object

Class that represents a model parameter

This class represents a model parameter and provides access to its data, gradient and other properties. This class is not expected to be instantiated directly. Instead, it is returned by the CheckpointState object.

Parameters:
  • parameter – The C.Parameter object that holds the underlying parameter data.

  • state – The C.CheckpointState object that holds the underlying session state.

property name: str

The name of the parameter

property data: ndarray

The data of the parameter

property grad: ndarray

The gradient of the parameter

property requires_grad: bool

Whether or not the parameter requires its gradient to be computed

__repr__() → str

Returns a string representation of the parameter

class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)

Bases: object

Class that holds all the model parameters

This class holds all the model parameters and provides access to them. This class is not expected to be instantiated directly. Instead, it is returned by the CheckpointState’s parameters attribute. This class behaves like a dictionary and provides access to the parameters by name.

Parameters:

state – The C.CheckpointState object that holds the underlying session state.

__getitem__(name: str) → Parameter

Gets the parameter associated with the given name

Searches for the name in the parameters of the checkpoint state.

Parameters:

name – The name of the parameter

Returns:

The Parameter object associated with the given name

Raises:

KeyError – If the parameter is not found

__setitem__(name: str, value: ndarray) → None

Sets the parameter value for the given name

Searches for the name in the parameters of the checkpoint state. If the name is found in parameters, the value is updated.

Parameters:
  • name – The name of the parameter

  • value – The value of the parameter as a numpy array

Raises:

KeyError – If the parameter is not found
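
As a small illustration of the dictionary-style access described above, the sketch below overwrites one parameter only after checking that it exists and that the new array has the same shape as the current data. The helper name `overwrite_parameter` is hypothetical and not part of onnxruntime. It relies only on the documented `in`, `[]` and `data` members, and assumes that `value` exposes a `shape` attribute the way a numpy array does.

```python
def overwrite_parameter(parameters, name, value):
    """Replace the data of parameter `name` with `value` after basic checks.

    `parameters` is expected to behave like CheckpointState.parameters.
    This helper is illustrative and not part of the onnxruntime API.
    """
    if name not in parameters:
        raise KeyError(f"Unknown parameter: {name}")

    # Compare shapes before writing, so a bad array fails with a clear message.
    current_shape = tuple(parameters[name].data.shape)
    new_shape = tuple(value.shape)
    if current_shape != new_shape:
        raise ValueError(
            f"Shape mismatch for {name}: expected {current_shape}, got {new_shape}"
        )

    # Delegates to Parameters.__setitem__, which updates the stored value.
    parameters[name] = value
```

With a loaded checkpoint, a call would look like `overwrite_parameter(state.parameters, name, new_value)`, where `name` is one of the parameter names in the training model.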

__contains__(name: str) → bool

Checks if the parameter exists in the state

Parameters:

name – The name of the parameter

Returns:

True if the name is a parameter, False otherwise

__iter__()

Returns an iterator over the parameters

__repr__() → str

Returns a string representation of the parameters

__len__() → int

Returns the number of parameters

class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)

Bases: object

__getitem__(name: str) → int | float | str

Gets the property associated with the given name

Searches for the name in the properties of the checkpoint state.

Parameters:

name – The name of the property

Returns:

The value of the property

Raises:

KeyError – If the property is not found

__setitem__(name: str, value: int | float | str) → None

Sets the property value for the given name

Searches for the name in the properties of the checkpoint state. The value is added or updated in the properties.

Parameters:
  • name – The name of the property

  • value – The value of the property. Properties only support int, float and str values.
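
Because properties accept only int, float and str values, a training script may want to coerce its bookkeeping values before storing them. The following is a minimal sketch under that assumption. The helper name `record_progress` and the property names `epoch` and `best_loss` are hypothetical and are not defined by onnxruntime.

```python
def record_progress(properties, epoch, loss):
    """Store the epoch counter and the best loss seen so far.

    `properties` is expected to behave like CheckpointState.properties.
    The keys "epoch" and "best_loss" are hypothetical names for this sketch.
    """
    # Coerce to plain Python types, since only int, float and str are supported.
    properties["epoch"] = int(epoch)

    loss = float(loss)
    if "best_loss" not in properties or loss < properties["best_loss"]:
        properties["best_loss"] = loss
    return properties["best_loss"]
```

Since the properties belong to the checkpoint state, values recorded this way should be written out the next time the state is saved with `CheckpointState.save_checkpoint`.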

__contains__(name: str) → bool

Checks if the property exists in the state

Parameters:

name – The name of the property

Returns:

True if the name is a property, False otherwise

__iter__()

Returns an iterator over the properties

__repr__() → str

Returns a string representation of the properties

__len__() → int

Returns the number of properties

class onnxruntime.training.api.CheckpointState(state: CheckpointState)

Bases: object

Class that holds the state of the training session

This class holds all the state information of the training session, such as the model parameters, their gradients, the optimizer state and user-defined properties.

To create the CheckpointState, use the CheckpointState.load_checkpoint method.

Parameters:

state – The C.CheckpointState object that holds the underlying session state.

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) → CheckpointState

Loads the checkpoint state from the checkpoint file

Parameters:

checkpoint_uri – The path to the checkpoint file.

Returns:

The checkpoint state object.

Return type:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) → None

Saves the checkpoint state to the checkpoint file

Parameters:
  • state – The checkpoint state object.

  • checkpoint_uri – The path to the checkpoint file.

  • include_optimizer_state – If True, the optimizer state is also saved to the checkpoint file.

property parameters: Parameters

Returns the model parameters from the checkpoint state

property properties: Properties

Returns the properties from the checkpoint state

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)

Bases: object

Trainer class that provides training and evaluation methods for ONNX models.

Before instantiating the Module class, it is expected that the training artifacts have been generated using the onnxruntime.training.artifacts.generate_artifacts utility.

The training artifacts include:
  • The training model

  • The evaluation model (optional)

  • The optimizer model (optional)

  • The checkpoint file

training

True if the model is in training mode, False if it is in evaluation mode.

Type:

bool

Parameters:
  • train_model_uri – The path to the training model.

  • state – The checkpoint state object.

  • eval_model_uri – The path to the evaluation model.

  • device – The device to run the model on. Default is “cpu”.

  • session_options – The session options to use for the model.

__call__(*user_inputs) → tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue

Invokes either the training or the evaluation step of the model.

Parameters:

*user_inputs – The inputs to the model. The user inputs can be either numpy arrays or OrtValues.

Returns:

The outputs of the model.

train(mode: bool = True) → Module

Sets the Module in training mode.

Parameters:

mode – whether to set the model to training mode (True) or evaluation mode (False). Default: True.

Returns:

self

eval() → Module

Sets the Module in evaluation mode.

Returns:

self

lazy_reset_grad()

Lazily resets the training gradients.

This function schedules the module gradients to be reset just before new gradients are computed on the next training step, that is, the next call to the module in training mode.
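
One pattern this lazy reset makes possible is gradient accumulation: skip the reset for a few batches so that their gradients add up, then apply a single optimizer step. The sketch below assumes that gradients accumulate across training calls until `lazy_reset_grad()` is invoked, and that the training model returns a single scalar loss. The helper name `train_with_accumulation` and its `accumulation_steps` argument are hypothetical.

```python
def train_with_accumulation(module, optimizer, batches, accumulation_steps=1):
    """Run one pass over `batches`, stepping every `accumulation_steps` batches.

    `module` and `optimizer` are expected to behave like the Module and
    Optimizer classes documented here. Returns the mean training loss.
    Gradients from a trailing partial group are left unapplied in this sketch.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be at least 1")

    module.train()
    total_loss = 0.0
    count = 0
    for count, inputs in enumerate(batches, start=1):
        # Each call in training mode computes gradients for the batch.
        total_loss += float(module(*inputs))

        if count % accumulation_steps == 0:
            optimizer.step()
            # Schedule the gradients to be cleared before the next training call.
            module.lazy_reset_grad()

    if count == 0:
        raise ValueError("batches is empty")
    return total_loss / count
```

With `accumulation_steps=1` this reduces to the plain loop in the sample usage at the top of the page: one optimizer step and one reset per batch.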

get_contiguous_parameters(trainable_only: bool = False) → OrtValue

Creates a contiguous buffer of the training session parameters

Parameters:

trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.

Returns:

The contiguous buffer of the training session parameters.

get_parameters_size(trainable_only: bool = True) → int

Returns the size of the parameters.

Parameters:

trainable_only – If True, only trainable parameters are considered. Otherwise, all parameters are considered.

Returns:

The number of primitive (for example, floating point) elements in the parameters.

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) → None

Copies the OrtValue buffer to the training session parameters.

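Parameters:
  • buffer – The OrtValue buffer to copy to the training session parameters.

  • trainable_only – If True, only the trainable parameters are updated from the buffer. Otherwise, all parameters are updated.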
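
The two buffer methods can be combined to snapshot the parameters and restore them later. Note from the signatures above that `get_contiguous_parameters` defaults to `trainable_only=False` while `copy_buffer_to_parameters` defaults to `trainable_only=True`, so the sketch passes the flag explicitly to both. It assumes the returned buffer is an independent copy of the parameters. `run_with_rollback` and `trial` are hypothetical names.

```python
def run_with_rollback(module, trial, trainable_only=True):
    """Run `trial(module)` and restore the parameters if it returns False.

    `module` is expected to behave like the Module class documented here.
    `trial` is any callable that trains the module and reports success.
    """
    # Pass trainable_only explicitly: the two methods have different defaults.
    snapshot = module.get_contiguous_parameters(trainable_only=trainable_only)

    keep = bool(trial(module))
    if not keep:
        # Undo the trial by copying the saved buffer back into the parameters.
        module.copy_buffer_to_parameters(snapshot, trainable_only=trainable_only)
    return keep
```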

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) → None

Exports the model for inferencing.

Once training is complete, this function can be used to drop the training specific nodes in the onnx model. In particular, this function does the following:

  • Parse the training graph and identify the nodes that generate the given output names.

  • Drop all subsequent nodes in the graph since they are not relevant to the inference graph.

Parameters:
  • inference_model_uri – The path to the inference model.

  • graph_output_names – The list of output names that are required for inferencing.

input_names() → list[str]

Returns the input names of the training or eval model.

output_names() → list[str]

Returns the output names of the training or eval model.

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)

Bases: object

Class that provides methods to update the model parameters based on the computed gradients.

Parameters:
  • optimizer_uri – The path to the optimizer model.

  • module – The module to be trained.

step() → None

Updates the model parameters based on the computed gradients.

This method updates the model parameters by taking a step in the direction of the computed gradients. The optimizer used depends on the optimizer model provided.

set_learning_rate(learning_rate: float) → None

Sets the learning rate for the optimizer.

Parameters:

learning_rate – The learning rate to be set.

get_learning_rate() → float

Gets the current learning rate of the optimizer.

Returns:

The current learning rate.
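
The getter and setter can be combined for simple manual schedules. The following is a minimal sketch of a multiplicative decay. The helper name `decay_learning_rate` and its `factor` and `min_lr` arguments are hypothetical and not part of onnxruntime.

```python
def decay_learning_rate(optimizer, factor=0.5, min_lr=0.0):
    """Multiply the optimizer's learning rate by `factor`, not going below `min_lr`.

    `optimizer` is expected to behave like the Optimizer class documented here.
    Returns the learning rate that was set.
    """
    if not 0.0 < factor <= 1.0:
        raise ValueError("factor must be in (0, 1]")

    new_lr = max(optimizer.get_learning_rate() * factor, min_lr)
    optimizer.set_learning_rate(new_lr)
    return new_lr
```

A script might call this when the evaluation loss stops improving, for example once per epoch.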

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)

Bases: object

Linearly updates the learning rate in the optimizer

The linear learning rate scheduler decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session down to 0. The decay begins after the initial warm-up phase, during which the learning rate is increased linearly from 0 to the initial learning rate provided.

Parameters:
  • optimizer – User’s onnxruntime training Optimizer

  • warmup_step_count – The number of steps in the warm up phase.

  • total_step_count – The total number of training steps.

  • initial_lr – The initial learning rate.
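
The schedule described above can be written as a plain function of the step index. This is a sketch of a typical linear warm-up followed by linear decay, not the library's implementation. The exact handling of boundary steps in onnxruntime may differ, and `linear_lr` is a hypothetical name.

```python
def linear_lr(step, warmup_step_count, total_step_count, initial_lr):
    """Learning rate at `step` for linear warm-up followed by linear decay to 0."""
    if step < warmup_step_count:
        # Warm up: rise linearly from 0 to initial_lr.
        factor = step / max(1, warmup_step_count)
    else:
        # Decay: fall linearly from initial_lr to 0 at total_step_count.
        remaining = total_step_count - step
        factor = max(0.0, remaining / max(1, total_step_count - warmup_step_count))
    return initial_lr * factor
```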

step() → None

Updates the learning rate of the optimizer linearly.

This method should be called at each step of training to ensure that the learning rate is properly adjusted.