ONNX Runtime
Ort::TrainingSession Class Reference

Trainer class that provides training, evaluation and optimizer methods for training ONNX models.

#include <onnxruntime_training_cxx_api.h>

Ort::TrainingSession inherits from Ort::detail::Base< OrtTrainingSession >.

Public Member Functions

Constructing the Training Session
 TrainingSession (const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt)
 Create a training session that can be used to begin or resume training.
 
 TrainingSession (const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::vector< uint8_t > &train_model_data, const std::vector< uint8_t > &eval_model_data={}, const std::vector< uint8_t > &optim_model_data={})
 Create a training session that can be used to begin or resume training. This constructor lets users load the models from buffers instead of files.
 
Implementing The Training Loop
std::vector< Value > TrainStep (const std::vector< Value > &input_values)
 Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
 
void LazyResetGrad ()
 Reset the gradients of all trainable parameters to zero lazily.
 
std::vector< Value > EvalStep (const std::vector< Value > &input_values)
 Computes the outputs for the eval model for the given inputs.
 
void SetLearningRate (float learning_rate)
 Sets the learning rate for this training session.
 
float GetLearningRate () const
 Gets the current learning rate for this training session.
 
void RegisterLinearLRScheduler (int64_t warmup_step_count, int64_t total_step_count, float initial_lr)
 Registers a linear learning rate scheduler for the training session.
 
void SchedulerStep ()
 Update the learning rate based on the registered learning rate scheduler.
 
void OptimizerStep ()
 Performs the weight updates for the trainable parameters using the optimizer model.
 
Prepare For Inferencing
void ExportModelForInferencing (const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names)
 Export a model that can be used for inferencing.
 
Model IO Information
std::vector< std::string > InputNames (const bool training)
 Retrieves the names of the user inputs for the training and eval models.
 
std::vector< std::string > OutputNames (const bool training)
 Retrieves the names of the user outputs for the training and eval models.
 
Accessing The Training Session State
Value ToBuffer (const bool only_trainable)
 Returns a contiguous buffer that holds a copy of all training state parameters.
 
void FromBuffer (Value &buffer)
 Loads the training session model parameters from a contiguous buffer.
 
- Public Member Functions inherited from Ort::detail::Base< OrtTrainingSession >
constexpr Base ()=default
 
constexpr Base (contained_type *p) noexcept
 
 Base (const Base &)=delete
 
 Base (Base &&v) noexcept
 
 ~Base ()
 
Base & operator= (const Base &)=delete
 
Base & operator= (Base &&v) noexcept
 
constexpr operator contained_type * () const noexcept
 
contained_type * release ()
 Relinquishes ownership of the contained C object pointer. The underlying object is not destroyed.
 

Additional Inherited Members

- Public Types inherited from Ort::detail::Base< OrtTrainingSession >
using contained_type = OrtTrainingSession
 
- Protected Attributes inherited from Ort::detail::Base< OrtTrainingSession >
contained_type * p_
 

Detailed Description

Trainer class that provides training, evaluation and optimizer methods for training ONNX models.

The training session requires four training artifacts:

  • The training ONNX model
  • The evaluation ONNX model (optional)
  • The optimizer ONNX model
  • The checkpoint file

These artifacts can be generated using the onnxruntime-training Python utility.
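Before constructing a session, it can help to confirm that the artifact files exist, so that a missing file is reported with a clear message instead of surfacing later as a load error. The following C++17 sketch groups the paths in a hypothetical `TrainingArtifacts` struct (not part of ONNX Runtime) whose optional members mirror the `std::optional` parameters of the file-based constructor; `MissingArtifacts` is likewise a hypothetical helper that uses only the standard library.

```cpp
#include <cassert>
#include <filesystem>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper type, not part of ONNX Runtime: the artifact paths a
// training application needs before it can create an Ort::TrainingSession.
struct TrainingArtifacts {
  std::string checkpoint;                      // checkpoint file
  std::string training_model;                  // always required
  std::optional<std::string> eval_model;       // optional, as in the constructor
  std::optional<std::string> optimizer_model;  // needed for OptimizerStep
};

// Hypothetical helper: returns every configured path that does not exist on
// disk, in the order checkpoint, training, eval, optimizer.
std::vector<std::string> MissingArtifacts(const TrainingArtifacts& artifacts) {
  assert(!artifacts.training_model.empty());  // the training model is mandatory
  std::vector<std::string> missing;
  auto check = [&missing](const std::string& path) {
    if (!std::filesystem::exists(path)) missing.push_back(path);
  };
  check(artifacts.checkpoint);
  check(artifacts.training_model);
  if (artifacts.eval_model) check(*artifacts.eval_model);
  if (artifacts.optimizer_model) check(*artifacts.optimizer_model);
  return missing;
}
```

An empty result means every configured artifact is present; otherwise the returned paths can be printed before giving up.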

Constructor & Destructor Documentation

◆ TrainingSession() [1/2]

Ort::TrainingSession::TrainingSession ( const Env &  env,
const SessionOptions &  session_options,
CheckpointState &  checkpoint_state,
const std::basic_string< char > &  train_model_path,
const std::optional< std::basic_string< char > > &  eval_model_path = std::nullopt,
const std::optional< std::basic_string< char > > &  optimizer_model_path = std::nullopt 
)

Create a training session that can be used to begin or resume training.

This constructor instantiates a training session from the provided env and session options. The session can begin or resume training from the given checkpoint state for the given ONNX models. The checkpoint state represents the parameters of the training session; if necessary, these are moved to the device specified by the user through the session options.

Parameters
[in] env: Env to be used for the training session.
[in] session_options: SessionOptions that the user can customize for this training session.
[in] checkpoint_state: Training states that the training session uses as a starting point for training.
[in] train_model_path: Model to be used to perform training.
[in] eval_model_path: Model to be used to perform evaluation.
[in] optimizer_model_path: Model to be used to perform gradient descent.

◆ TrainingSession() [2/2]

Ort::TrainingSession::TrainingSession ( const Env &  env,
const SessionOptions &  session_options,
CheckpointState &  checkpoint_state,
const std::vector< uint8_t > &  train_model_data,
const std::vector< uint8_t > &  eval_model_data = {},
const std::vector< uint8_t > &  optim_model_data = {} 
)

Create a training session that can be used to begin or resume training. This constructor lets users load the models from buffers instead of files.

Parameters
[in] env: Env to be used for the training session.
[in] session_options: SessionOptions that the user can customize for this training session.
[in] checkpoint_state: Training states that the training session uses as a starting point for training.
[in] train_model_data: Buffer containing the training model data.
[in] eval_model_data: Buffer containing the evaluation model data.
[in] optim_model_data: Buffer containing the optimizer model (used to perform the weight/parameter update).
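The buffer-based constructor takes each model as a `std::vector< uint8_t >`. A minimal sketch for filling such a buffer from a file on disk follows; `ReadFileBytes` is a hypothetical helper built only on the C++ standard library, and the file names used below are placeholders.

```cpp
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: reads a whole file (e.g. an exported training,
// eval or optimizer model) into the byte-vector form the buffer-based
// Ort::TrainingSession constructor accepts.
std::vector<uint8_t> ReadFileBytes(const std::string& path) {
  std::ifstream file(path, std::ios::binary);  // binary: no newline translation
  if (!file) {
    throw std::runtime_error("cannot open " + path);
  }
  // Copy every byte from the stream buffer into the vector.
  return std::vector<uint8_t>(std::istreambuf_iterator<char>(file),
                              std::istreambuf_iterator<char>());
}
```

Assuming `env`, `session_options` and `checkpoint_state` already exist, the results can be passed straight to the constructor, e.g. `Ort::TrainingSession session(env, session_options, checkpoint_state, ReadFileBytes("training_model.onnx"), ReadFileBytes("eval_model.onnx"), ReadFileBytes("optimizer_model.onnx"));`, since the parameters are const references that bind to temporaries.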

Member Function Documentation

◆ EvalStep()

std::vector< Value > Ort::TrainingSession::EvalStep ( const std::vector< Value > &  input_values)

Computes the outputs for the eval model for the given inputs.

This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step is performed based on the eval model that was provided to the training session.

Parameters
[in] input_values: The user inputs to the eval model.
Returns
A std::vector of Ort::Value objects that represents the output of the eval pass.
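An evaluation pass usually runs EvalStep over every batch of a validation set and aggregates a metric. The sketch below is a hypothetical helper written as a template, so it relies only on the documented `EvalStep` member; with `Ort::TrainingSession`, `Batch` would be `std::vector< Ort::Value >`. Which output holds the loss depends on the exported eval model, so reading it is left to a caller-supplied `extract_loss` (for example one returning `outputs[0].GetTensorData<float>()[0]`, assuming the loss is the first output).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: averages a scalar metric over evaluation batches.
// `Session` must provide EvalStep(const Batch&); `extract_loss` converts the
// returned outputs into a double.
template <typename Session, typename Batch, typename ExtractLoss>
double AverageEvalLoss(Session& session, const std::vector<Batch>& batches,
                       ExtractLoss extract_loss) {
  assert(!batches.empty());  // an average over zero batches is undefined
  double total = 0.0;
  for (const Batch& batch : batches) {
    auto outputs = session.EvalStep(batch);  // eval model outputs only
    total += extract_loss(outputs);
  }
  return total / static_cast<double>(batches.size());
}
```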

◆ ExportModelForInferencing()

void Ort::TrainingSession::ExportModelForInferencing ( const std::basic_string< char > &  inference_model_path,
const std::vector< std::string > &  graph_output_names 
)

Export a model that can be used for inferencing.

If the training session was provided with an eval model, it can generate an inference model given the names of the inference graph outputs. These output names are used to prune the eval model so that the inference model's outputs match the requested outputs. The exported model is saved at the given path and can be used for inferencing with Ort::Session.

Note
This function reloads the eval model from the path provided to Ort::TrainingSession and expects that path to still be valid.
Parameters
[in] inference_model_path: Path where the inference model should be serialized to.
[in] graph_output_names: Names of the outputs that are needed in the inference model.

◆ FromBuffer()

void Ort::TrainingSession::FromBuffer ( Value buffer)

Loads the training session model parameters from a contiguous buffer.

If the training session was created with a nominal checkpoint, this function must be invoked to load the updated parameters into the checkpoint and complete it.

Parameters
[in] buffer: Contiguous buffer to load the parameters from.
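ToBuffer returns an `Ort::Value` by value, while FromBuffer takes a non-const `Value &`, so the buffer has to be held in a named variable before it is passed on; a call such as `destination.FromBuffer(source.ToBuffer(true))` would not compile, because a temporary cannot bind to a non-const reference. The hypothetical helper below copies parameters between two sessions and is written as a template over the two documented members. Whether a trainable-only buffer is accepted by FromBuffer in every case is an assumption to check against your ONNX Runtime version.

```cpp
#include <cassert>

// Hypothetical helper: copies parameters from `source` into `destination`
// through a contiguous buffer. Works with any types exposing ToBuffer(bool)
// and FromBuffer(Buffer&), such as Ort::TrainingSession.
template <typename Source, typename Destination>
void CopyParameters(Source& source, Destination& destination, bool only_trainable) {
  // Keep the buffer in a named variable: FromBuffer takes a non-const
  // lvalue reference, which a temporary cannot bind to.
  auto buffer = source.ToBuffer(only_trainable);
  destination.FromBuffer(buffer);
}
```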

◆ GetLearningRate()

float Ort::TrainingSession::GetLearningRate ( ) const

Gets the current learning rate for this training session.

This function allows users to get the learning rate for the training session. The current learning rate is maintained by the training session, and users can query it, for example, to implement their own learning rate schedulers.

Returns
float representing the current learning rate.
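As an example of a user-defined schedule built on GetLearningRate, the sketch below implements step decay: every `step_size` steps, the current learning rate is multiplied by `gamma`. Step decay is a swapped-in technique chosen for illustration, not an ONNX Runtime scheduler, and `StepDecay` is a hypothetical helper. Because it calls SetLearningRate, it should not be combined with a registered predefined scheduler.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical step-decay schedule: at steps step_size, 2*step_size, ...
// the current learning rate is multiplied by gamma. Uses only the documented
// GetLearningRate/SetLearningRate members.
template <typename Session>
void StepDecay(Session& session, int64_t step, int64_t step_size, float gamma) {
  assert(step_size > 0);
  if (step > 0 && step % step_size == 0) {
    session.SetLearningRate(session.GetLearningRate() * gamma);
  }
}
```

Calling `StepDecay(session, step, 5, 0.5f)` once per step with `step` running from 1 to 10 halves the learning rate at steps 5 and 10.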

◆ InputNames()

std::vector< std::string > Ort::TrainingSession::InputNames ( const bool  training)

Retrieves the names of the user inputs for the training and eval models.

This function returns the names of inputs of the training or eval model that can be associated with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep function.

Parameters
[in] training: Whether the input names of the training model (true) or of the eval model (false) are requested.
Returns
Graph input names for either the training model or the eval model.
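TrainStep and EvalStep take their inputs as a positional `std::vector< Value >`, so one way to avoid ordering mistakes is to look each value up by the names returned from InputNames. This sketch assumes that the i-th input corresponds to the i-th name; `OrderInputs` is a hypothetical helper. Values are moved out of the map because `Ort::Value` is move-only.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: builds the positional input vector from values keyed
// by graph input name, in the order given by `input_names`
// (e.g. session.InputNames(true) for TrainStep).
template <typename Value>
std::vector<Value> OrderInputs(const std::vector<std::string>& input_names,
                               std::map<std::string, Value>& values_by_name) {
  std::vector<Value> ordered;
  ordered.reserve(input_names.size());
  for (const std::string& name : input_names) {
    auto it = values_by_name.find(name);
    if (it == values_by_name.end()) {
      throw std::invalid_argument("no value supplied for input '" + name + "'");
    }
    ordered.push_back(std::move(it->second));  // move: works for move-only values
  }
  return ordered;
}
```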

◆ LazyResetGrad()

void Ort::TrainingSession::LazyResetGrad ( )

Reset the gradients of all trainable parameters to zero lazily.

This function sets the internal state of the training session so that the gradients of the trainable parameters in the OrtCheckpointState are reset just before new gradients are computed on the next invocation of Ort::TrainingSession::TrainStep.
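Because the reset is deferred, LazyResetGrad also makes gradient accumulation straightforward: call TrainStep on several micro-batches, then OptimizerStep, then LazyResetGrad. The sketch below assumes that gradients from consecutive TrainStep calls are accumulated in the training state until a reset is scheduled. `TrainWithAccumulation` is a hypothetical helper, written as a template over the documented members so that it also accepts `Ort::TrainingSession`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: one optimizer update per `accumulation_steps`
// micro-batches (plus a final update for any leftover micro-batches).
template <typename Session, typename Batch>
void TrainWithAccumulation(Session& session, const std::vector<Batch>& micro_batches,
                           std::size_t accumulation_steps) {
  assert(accumulation_steps > 0);
  for (std::size_t i = 0; i < micro_batches.size(); ++i) {
    // Assumed to add this micro-batch's gradients to the accumulated ones.
    session.TrainStep(micro_batches[i]);
    const bool last = (i + 1 == micro_batches.size());
    if ((i + 1) % accumulation_steps == 0 || last) {
      session.OptimizerStep();  // update using the accumulated gradients
      session.LazyResetGrad();  // next TrainStep starts from zeroed gradients
    }
  }
}
```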

◆ OptimizerStep()

void Ort::TrainingSession::OptimizerStep ( )

Performs the weight updates for the trainable parameters using the optimizer model.

This function performs the weight update step, which updates the trainable parameters so that they take a step in the direction opposite to their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next Ort::TrainingSession::TrainStep call.
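To make "a step in the direction opposite to the gradients" concrete, the toy function below applies one plain SGD update, w ← w − lr · g, to a flat parameter vector. This is only an illustration: OptimizerStep runs whatever optimizer graph was exported (for example AdamW), which may keep additional state such as moment estimates. `SgdUpdate` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy illustration, not ONNX Runtime code: one plain SGD update applied
// element-wise, weights[i] -= learning_rate * gradients[i].
void SgdUpdate(std::vector<float>& weights, const std::vector<float>& gradients,
               float learning_rate) {
  assert(weights.size() == gradients.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    weights[i] -= learning_rate * gradients[i];  // step against the gradient
  }
}
```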

◆ OutputNames()

std::vector< std::string > Ort::TrainingSession::OutputNames ( const bool  training)

Retrieves the names of the user outputs for the training and eval models.

This function returns the names of outputs of the training or eval model that can be associated with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep function.

Parameters
[in] training: Whether the output names of the training model (true) or of the eval model (false) are requested.
Returns
Graph output names for either the training model or the eval model.
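OutputNames can be used to locate a specific output, such as the loss, in the vector returned by TrainStep or EvalStep. The sketch assumes the returned outputs follow the order reported by OutputNames; `OutputIndex` is a hypothetical helper, and the output names in the usage note are placeholders that depend on the exported model.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: position of `wanted` in `output_names`, or nullopt if
// the model has no output with that name.
std::optional<std::size_t> OutputIndex(const std::vector<std::string>& output_names,
                                       const std::string& wanted) {
  for (std::size_t i = 0; i < output_names.size(); ++i) {
    if (output_names[i] == wanted) return i;
  }
  return std::nullopt;
}
```

With `names = session.OutputNames(true)`, `OutputIndex(names, "loss")` would give the index into the vector returned by TrainStep, assuming the training model actually exposes an output called `loss`.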

◆ RegisterLinearLRScheduler()

void Ort::TrainingSession::RegisterLinearLRScheduler ( int64_t  warmup_step_count,
int64_t  total_step_count,
float  initial_lr 
)

Registers a linear learning rate scheduler for the training session.

Registers a linear learning rate scheduler that decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate (initial_lr) down to 0. The decay is performed after an initial warmup phase in which the learning rate is increased linearly from 0 to the initial learning rate.

Parameters
[in] warmup_step_count: Number of warmup steps, during which the learning rate increases linearly from 0 to initial_lr.
[in] total_step_count: Total number of steps.
[in] initial_lr: The initial learning rate to be used by the training session.
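The schedule described above can be written as a closed-form function of the step count. The sketch below is an illustration under stated assumptions, not ONNX Runtime's implementation: it assumes the multiplicative factor is `step / warmup_step_count` during warmup and `(total_step_count - step) / (total_step_count - warmup_step_count)` afterwards, clamped at 0. The exact point at which ONNX Runtime advances its internal step counter may differ. `LinearScheduleLR` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical closed form of a warmup-then-linear-decay schedule:
// 0 -> initial_lr over the warmup steps, then initial_lr -> 0 at
// total_step_count, and 0 afterwards.
float LinearScheduleLR(int64_t step, int64_t warmup_step_count,
                       int64_t total_step_count, float initial_lr) {
  assert(step >= 0 && total_step_count >= warmup_step_count);
  float factor;
  if (step < warmup_step_count) {
    // Warmup: factor grows linearly from 0 towards 1.
    factor = static_cast<float>(step) / static_cast<float>(warmup_step_count);
  } else {
    // Decay: factor shrinks linearly from 1 to 0, never below 0.
    const int64_t decay_steps = std::max<int64_t>(1, total_step_count - warmup_step_count);
    factor = std::max(0.0f, static_cast<float>(total_step_count - step) /
                                static_cast<float>(decay_steps));
  }
  return initial_lr * factor;
}
```

For `warmup_step_count = 2`, `total_step_count = 10` and `initial_lr = 1.0f`, this gives 0 at step 0, 0.5 at step 1, 1.0 at step 2, 0.5 at step 6, and 0 from step 10 on.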

◆ SchedulerStep()

void Ort::TrainingSession::SchedulerStep ( )

Update the learning rate based on the registered learning rate scheduler.

Takes a scheduler step that updates the learning rate used by the training session. This function should typically be called before invoking the optimizer step in each round, or whenever the learning rate used by the training session needs to be updated.

Note
A valid predefined learning rate scheduler must be registered before this function is invoked.
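Putting the scheduler calls together: register the linear scheduler once, then take one scheduler step per round before the optimizer step, as the description above recommends. `TrainWithLinearSchedule` is a hypothetical helper written as a template over the documented members, so it also accepts `Ort::TrainingSession`; using one step per batch for `total_step_count` is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: one epoch driven by the predefined linear scheduler,
// with one scheduler step per batch.
template <typename Session, typename Batch>
void TrainWithLinearSchedule(Session& session, const std::vector<Batch>& batches,
                             int64_t warmup_step_count, float initial_lr) {
  const auto total_step_count = static_cast<int64_t>(batches.size());
  assert(warmup_step_count <= total_step_count);
  // Register once, before the loop; SchedulerStep requires it.
  session.RegisterLinearLRScheduler(warmup_step_count, total_step_count, initial_lr);
  for (const Batch& batch : batches) {
    session.TrainStep(batch);   // forward + backward
    session.SchedulerStep();    // update the learning rate for this round
    session.OptimizerStep();    // weight update with the current learning rate
    session.LazyResetGrad();    // zero gradients before the next TrainStep
  }
}
```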

◆ SetLearningRate()

void Ort::TrainingSession::SetLearningRate ( float  learning_rate)

Sets the learning rate for this training session.

This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.

Note
This function does not set the initial learning rate used by the predefined learning rate schedulers. To set the initial learning rate for a learning rate scheduler, see Ort::TrainingSession::RegisterLinearLRScheduler.
Parameters
[in] learning_rate: Desired learning rate to be set.
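A custom scheduler computes the learning rate outside the session and pushes it in with SetLearningRate. The sketch below uses cosine annealing, a swapped-in technique that is not one of ONNX Runtime's predefined schedulers; `CosineLR` and `ApplyCosineLR` are hypothetical names. As noted above, no predefined scheduler should be registered while this is in use.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical cosine-annealing schedule: max_lr at step 0, min_lr at
// total_steps, and min_lr for any later step.
float CosineLR(int64_t step, int64_t total_steps, float max_lr, float min_lr) {
  assert(step >= 0 && total_steps > 0);
  const double pi = 3.14159265358979323846;
  const double progress =
      static_cast<double>(std::min(step, total_steps)) / static_cast<double>(total_steps);
  return static_cast<float>(min_lr + 0.5 * (max_lr - min_lr) * (1.0 + std::cos(pi * progress)));
}

// Pushes the scheduled value into any session exposing SetLearningRate.
template <typename Session>
void ApplyCosineLR(Session& session, int64_t step, int64_t total_steps,
                   float max_lr, float min_lr) {
  session.SetLearningRate(CosineLR(step, total_steps, max_lr, min_lr));
}
```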

◆ ToBuffer()

Value Ort::TrainingSession::ToBuffer ( const bool  only_trainable)

Returns a contiguous buffer that holds a copy of the training state parameters (either all parameters or only the trainable ones).

Parameters
[in] only_trainable: Whether to copy only the trainable parameters or all parameters.
Returns
Contiguous buffer containing the model parameters.
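The idea behind a "contiguous buffer" is that several parameter tensors are copied back to back into one flat array and later scattered back in the same order. The toy functions below illustrate only that idea; the parameter order and element type ONNX Runtime actually uses are not specified here, and `Flatten`/`Unflatten` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy illustration: concatenates parameter tensors into one flat buffer.
std::vector<float> Flatten(const std::vector<std::vector<float>>& params) {
  std::vector<float> buffer;
  for (const auto& p : params) buffer.insert(buffer.end(), p.begin(), p.end());
  return buffer;
}

// Toy illustration: scatters a flat buffer back into tensors whose sizes are
// already set; the sizes in `params` define how the buffer is split.
void Unflatten(const std::vector<float>& buffer, std::vector<std::vector<float>>& params) {
  std::size_t offset = 0;
  for (auto& p : params) {
    assert(offset + p.size() <= buffer.size());
    for (auto& x : p) x = buffer[offset++];
  }
  assert(offset == buffer.size());  // buffer and tensor sizes must match exactly
}
```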

◆ TrainStep()

std::vector< Value > Ort::TrainingSession::TrainStep ( const std::vector< Value > &  input_values)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. Ort::TrainingSession::TrainStep is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the Ort::TrainingSession::OptimizerStep function. The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.

Parameters
[in] input_values: The user inputs to the training model.
Returns
A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
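The member functions documented on this page combine into the basic training loop: TrainStep (forward and backward), OptimizerStep (weight update), then LazyResetGrad so that the next step does not build on stale gradients. `RunTrainingEpoch` is a hypothetical helper written as a template, so it compiles against `Ort::TrainingSession` (with `Batch = std::vector< Ort::Value >`) as well as against a test double; `on_outputs` receives each step's outputs, for example to log the loss.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: one epoch of the basic loop. Returns the number of
// training steps taken.
template <typename Session, typename Batch, typename OnOutputs>
std::size_t RunTrainingEpoch(Session& session, const std::vector<Batch>& batches,
                             OnOutputs on_outputs) {
  std::size_t steps = 0;
  for (const Batch& batch : batches) {
    auto outputs = session.TrainStep(batch);  // forward + backward propagation
    on_outputs(outputs);                      // e.g. record the loss
    session.OptimizerStep();                  // update the trainable parameters
    session.LazyResetGrad();                  // reset gradients before the next step
    ++steps;
  }
  return steps;
}
```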