Class TrainingSession

Trainer class that provides training, evaluation and optimizer methods for training an ONNX model.

The training session requires four training artifacts:

  • The training ONNX model
  • The evaluation ONNX model (optional)
  • The optimizer ONNX model
  • The checkpoint directory

These artifacts can be generated using the onnxruntime-training Python utility.

This class implements IDisposable and must be disposed of, either through an explicit call to Dispose() or by wrapping it in a using block. If a TrainingSession is a member of another class, that class must also implement IDisposable and dispose of the TrainingSession in its own Dispose() method.
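The ownership rule above can be sketched as follows. This is a minimal illustration, not a prescribed pattern: the class name Trainer and its fields are hypothetical, and it assumes the checkpoint is loaded with CheckpointState.LoadCheckpoint and that CheckpointState is itself disposable.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

// Hypothetical owner class; the name and layout are illustrative only.
public sealed class Trainer : IDisposable
{
    private readonly CheckpointState _state;
    private readonly TrainingSession _session;
    private bool _disposed;

    public Trainer(string checkpointPath, string trainModelPath, string optimizerModelPath)
    {
        // Assumption: the checkpoint artifact is loaded from disk with LoadCheckpoint.
        _state = CheckpointState.LoadCheckpoint(checkpointPath);
        try
        {
            _session = new TrainingSession(_state, trainModelPath, optimizerModelPath);
        }
        catch
        {
            // Do not leak the checkpoint state if session creation fails.
            _state.Dispose();
            throw;
        }
    }

    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        // Dispose the session first, then the state it was created from.
        _session.Dispose();
        _state.Dispose();
        _disposed = true;
    }
}
```

A caller would then write `using var trainer = new Trainer(...)` so that the session is released deterministically.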

Inheritance
Object
TrainingSession
Implements
IDisposable
Inherited Members
Object.Equals(Object)
Object.Equals(Object, Object)
Object.GetHashCode()
Object.GetType()
Object.MemberwiseClone()
Object.ReferenceEquals(Object, Object)
Object.ToString()
Namespace: Microsoft.ML.OnnxRuntime
Assembly: Microsoft.ML.OnnxRuntime.dll
Syntax
public class TrainingSession : IDisposable

Constructors

TrainingSession(CheckpointState, String, String)

Create a training session that can be used to begin or resume training.

This constructor instantiates a training session, based on the env and session options provided, that can begin or resume training from a given checkpoint state for the given ONNX models. The checkpoint state represents the parameters of the training session, which are moved to the device specified through the session options if necessary.

Declaration
public TrainingSession(CheckpointState state, string trainModelPath, string optimizerModelPath)
Parameters
Type Name Description
CheckpointState state

Training states that the training session uses as a starting point for training.

String trainModelPath

Model to be used to perform training.

String optimizerModelPath

Model to be used to perform weight update.

TrainingSession(CheckpointState, String, String, String)

Create a training session that can be used to begin or resume training.

This constructor instantiates a training session, based on the env and session options provided, that can begin or resume training from a given checkpoint state for the given ONNX models. The checkpoint state represents the parameters of the training session, which are moved to the device specified through the session options if necessary.

Declaration
public TrainingSession(CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
Parameters
Type Name Description
CheckpointState state

Training states that the training session uses as a starting point for training.

String trainModelPath

Model to be used to perform training.

String evalModelPath

Model to be used to perform evaluation.

String optimizerModelPath

Model to be used to perform weight update.

TrainingSession(SessionOptions, CheckpointState, String, String, String)

Create a training session that can be used to begin or resume training.

This constructor instantiates a training session, based on the env and session options provided, that can begin or resume training from a given checkpoint state for the given ONNX models. The checkpoint state represents the parameters of the training session, which are moved to the device specified through the session options if necessary.

Declaration
public TrainingSession(SessionOptions options, CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
Parameters
Type Name Description
SessionOptions options

SessionOptions that the user can customize for this training session.

CheckpointState state

Training states that the training session uses as a starting point for training.

String trainModelPath

Model to be used to perform training.

String evalModelPath

Model to be used to perform evaluation.

String optimizerModelPath

Model to be used to perform weight update.
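As a rough sketch of how the five-argument constructor might be used: the file names below are placeholders, and CheckpointState.LoadCheckpoint is assumed to be the way the checkpoint artifact is loaded.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

public static class CreateSessionExample
{
    public static void Run()
    {
        // All paths are hypothetical; substitute the artifacts generated for your model.
        using var options = new SessionOptions();
        using var state = CheckpointState.LoadCheckpoint("checkpoint");
        using var session = new TrainingSession(
            options,
            state,
            "training_model.onnx",
            "eval_model.onnx",
            "optimizer_model.onnx");

        // The session is ready for TrainStep / OptimizerStep / EvalStep calls here.
        Console.WriteLine("Initial learning rate: " + session.GetLearningRate());
    }
}
```

The using declarations dispose the session before the checkpoint state and the options, since locals are disposed in reverse order of declaration.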

Methods

Dispose()

IDisposable implementation

Declaration
public void Dispose()

Dispose(Boolean)

IDisposable implementation

Declaration
protected virtual void Dispose(bool disposing)
Parameters
Type Name Description
Boolean disposing

true if invoked from Dispose() method

EvalStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the eval model for the given inputs.

This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step is performed based on the eval model that was provided to the training session.

Declaration
public void EvalStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type Name Description
RunOptions options

Specify RunOptions for step.

IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the eval model.

IReadOnlyCollection<FixedBufferOnnxValue> outputValues

Specify a collection of FixedBufferOnnxValue that indicates the output values of the eval model.

EvalStep(IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the eval model for the given inputs.

This function performs an eval step that computes the outputs of the eval model for the given inputs. The eval step is performed based on the eval model that was provided to the training session.

Declaration
public void EvalStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type Name Description
IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the eval model.

IReadOnlyCollection<FixedBufferOnnxValue> outputValues

Specify a collection of FixedBufferOnnxValue that indicates the output values of the eval model.

EvalStep(IReadOnlyCollection<OrtValue>)

This function performs an eval step that computes the outputs of the eval model for the given inputs. Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was provided to the training session.

Example usage:

using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
List<OrtValue> inputValues = new List<OrtValue> { x, label };
using (var loss = trainingSession.EvalStep(inputValues))
{
    // process output values
}
Declaration
public IDisposableReadOnlyCollection<OrtValue> EvalStep(IReadOnlyCollection<OrtValue> inputValues)
Parameters
Type Name Description
IReadOnlyCollection<OrtValue> inputValues

Specify a collection of OrtValue that indicates the input values to the eval model.

Returns
Type Description
IDisposableReadOnlyCollection<OrtValue>

ExportModelForInferencing(String, IReadOnlyCollection<String>)

Exports a model that can be used for inferencing. If the training session was provided with an eval model, it can generate an inference model, provided that it is told the inference graph outputs. The supplied graph output names are used to prune the eval model so that the inference model's outputs align with them. The exported model is saved at the given path and can be used for inferencing with InferenceSession. Note that this function reloads the eval model from the path that was passed to TrainingSession, so that path must still be valid.

Declaration
public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection<string> graphOutputNames)
Parameters
Type Name Description
String inferenceModelPath

Path where the inference model should be serialized to.

IReadOnlyCollection<String> graphOutputNames

Names of the outputs that are needed in the inference model.
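A minimal sketch of exporting and then loading the inference model. The output name "output" and the file name are placeholders; the real graph output names depend on how the eval model was built.

```csharp
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;

public static class ExportExample
{
    // Assumes `session` was created with an eval model whose path is still valid.
    public static void Export(TrainingSession session)
    {
        // Hypothetical graph output name; replace with the outputs your model exposes.
        var graphOutputNames = new List<string> { "output" };
        session.ExportModelForInferencing("inference_model.onnx", graphOutputNames);

        // The exported file can be opened like any other ONNX model.
        using var inference = new InferenceSession("inference_model.onnx");
    }
}
```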

Finalize()

Finalizer.

Declaration
protected void Finalize()

FromBuffer(OrtValue, Boolean)

Loads the training session model parameters from a contiguous buffer.

Declaration
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
Parameters
Type Name Description
OrtValue ortValue

Contiguous buffer to load the parameters from.

Boolean onlyTrainable

Whether to only load trainable parameters or to load all parameters.

GetLearningRate()

Gets the current learning rate for this training session.

This function allows users to get the learning rate for the training session. The current learning rate is maintained by the training session, and users can query it for the purpose of implementing their own learning rate schedulers.

Declaration
public float GetLearningRate()
Returns
Type Description
Single

float representing the current learning rate.

InputNames(Boolean)

Retrieves the names of the user inputs for the training and eval models.

Declaration
public List<string> InputNames(bool training)
Parameters
Type Name Description
Boolean training

true to request the training model input names; false to request the eval model input names.

Returns
Type Description
List<String>

LazyResetGrad()

Reset the gradients of all trainable parameters to zero lazily.

This function sets the internal state of the training session so that the gradients of the trainable parameters in the CheckpointState are scheduled to be reset just before new gradients are computed on the next invocation of TrainStep.

Declaration
public void LazyResetGrad()

OptimizerStep()

Performs the weight updates for the trainable parameters using the optimizer model.

This function performs the weight update step, which updates the trainable parameters based on their gradients so that the loss decreases (as in gradient descent, where each parameter steps against its gradient). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next TrainStep call.

Declaration
public void OptimizerStep()

OptimizerStep(RunOptions)

Performs the weight updates for the trainable parameters using the optimizer model.

This function performs the weight update step, which updates the trainable parameters based on their gradients so that the loss decreases (as in gradient descent, where each parameter steps against its gradient). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next TrainStep call.

Declaration
public void OptimizerStep(RunOptions options)
Parameters
Type Name Description
RunOptions options

Specify RunOptions for step.
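The way TrainStep, OptimizerStep and LazyResetGrad fit together can be sketched as a simple loop. This is an illustration under stated assumptions: the helper name TrainOneEpoch is hypothetical, and the caller is assumed to supply batches whose OrtValue inputs match the training model's input order.

```csharp
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;

public static class TrainingLoopExample
{
    // `batches` is assumed to be prepared by the caller (for example, features and labels).
    public static void TrainOneEpoch(
        TrainingSession session,
        IEnumerable<IReadOnlyCollection<OrtValue>> batches)
    {
        foreach (IReadOnlyCollection<OrtValue> inputs in batches)
        {
            // Forward and backward pass; gradients are stored in the session state.
            using (var outputs = session.TrainStep(inputs))
            {
                // Inspect the outputs (typically the loss) here before they are disposed.
            }

            // Apply the stored gradients to the trainable parameters.
            session.OptimizerStep();

            // Schedule the gradients to be zeroed before the next TrainStep.
            session.LazyResetGrad();
        }
    }
}
```

Skipping LazyResetGrad for several iterations would accumulate gradients across batches, which is one way to emulate a larger effective batch size.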

OutputNames(Boolean)

Retrieves the names of the user outputs for the training and eval models.

Declaration
public List<string> OutputNames(bool training)
Parameters
Type Name Description
Boolean training

true to request the training model output names; false to request the eval model output names.

Returns
Type Description
List<String>
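A short sketch of querying the graph names, which can be useful for checking the order in which inputs must be supplied. The helper name is hypothetical; passing false instead of true would query the eval model, which assumes an eval model was provided.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;

public static class GraphNamesExample
{
    public static void Print(TrainingSession session)
    {
        // true selects the training model.
        List<string> trainInputs = session.InputNames(true);
        List<string> trainOutputs = session.OutputNames(true);

        Console.WriteLine("Training inputs:  " + string.Join(", ", trainInputs));
        Console.WriteLine("Training outputs: " + string.Join(", ", trainOutputs));
    }
}
```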

RegisterLinearLRScheduler(Int64, Int64, Single)

Registers a linear learning rate scheduler for the training session.

Registers a linear learning rate scheduler that decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session down to 0. The decay is performed after the initial warmup phase, during which the learning rate is linearly increased from 0 to the initial learning rate provided.

Declaration
public void RegisterLinearLRScheduler(long warmupStepCount, long totalStepCount, float initialLearningRate)
Parameters
Type Name Description
Int64 warmupStepCount

Number of warmup steps

Int64 totalStepCount

Number of total steps

Single initialLearningRate

Initial learning rate
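The shape of this schedule can be sketched in plain C#. The code below is an assumption about the conventional warmup-then-linear-decay formula, written only to illustrate the description above; the scheduler inside ONNX Runtime may differ in edge cases, and LinearSchedule is a hypothetical helper, not part of the library.

```csharp
using System;

public static class LinearSchedule
{
    // Multiplicative factor applied to the initial learning rate at a given step.
    public static float Factor(long step, long warmupStepCount, long totalStepCount)
    {
        if (step < warmupStepCount)
        {
            // Warmup: ramp linearly from 0 towards 1.
            return (float)step / Math.Max(1L, warmupStepCount);
        }

        // Decay: ramp linearly from 1 at the end of warmup down to 0 at totalStepCount.
        float remaining = totalStepCount - step;
        float decaySpan = Math.Max(1L, totalStepCount - warmupStepCount);
        return Math.Max(0f, remaining / decaySpan);
    }

    public static float LearningRate(
        long step, long warmupStepCount, long totalStepCount, float initialLearningRate)
    {
        return initialLearningRate * Factor(step, warmupStepCount, totalStepCount);
    }
}
```

With 10 warmup steps and 110 total steps, the factor under this sketch is 0.5 at step 5, 1 at step 10, 0.5 at step 60, and 0 from step 110 onward.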

SchedulerStep()

Update the learning rate based on the registered learning rate scheduler.

Takes a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as determined necessary to update the learning rate being used by the training session.

Note: A valid predefined learning rate scheduler must first be registered before this function is invoked.

Declaration
public void SchedulerStep()

SetLearningRate(Single)

Sets the learning rate for this training session.

This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.

Note: This function does not set the initial learning rate that may be needed by the predefined learning rate schedulers. To set the initial learning rate for a learning rate scheduler, see RegisterLinearLRScheduler.

Declaration
public void SetLearningRate(float learningRate)
Parameters
Type Name Description
Single learningRate

Desired learning rate to be set.
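A custom scheduler built on GetLearningRate and SetLearningRate might look like the sketch below. The step-decay policy, the class name StepDecay, and the interval and gamma values are all hypothetical choices made for illustration; no predefined scheduler is assumed to be registered.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

public static class StepDecay
{
    // Pure policy: multiply the rate by gamma every `interval` steps.
    public static float Next(float current, long step, long interval, float gamma)
    {
        bool atBoundary = step > 0 && step % interval == 0;
        return atBoundary ? current * gamma : current;
    }

    // Reads the current rate from the session and writes back the decayed value.
    public static void Apply(TrainingSession session, long step)
    {
        float current = session.GetLearningRate();
        float updated = Next(current, step, 100, 0.5f);
        session.SetLearningRate(updated);
    }
}
```

Calling `StepDecay.Apply(session, step)` before each OptimizerStep would halve the learning rate every 100 steps under these assumed settings.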

ToBuffer(Boolean)

Returns a contiguous buffer that holds a copy of all training state parameters.

Declaration
public OrtValue ToBuffer(bool onlyTrainable)
Parameters
Type Name Description
Boolean onlyTrainable

Whether to only copy trainable parameters or to copy all parameters.

Returns
Type Description
OrtValue
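One possible use of ToBuffer together with FromBuffer is to snapshot the parameters and restore them later. The sketch below relies only on the signatures documented on this page; the helper name and the surrounding workflow are hypothetical.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

public static class ParameterSnapshotExample
{
    // Runs `experiment` and then restores the trainable parameters to their prior values.
    public static void RunAndRollBack(TrainingSession session, Action experiment)
    {
        // Copy only the trainable parameters into a contiguous buffer.
        using OrtValue snapshot = session.ToBuffer(true);

        // For example, a few TrainStep / OptimizerStep calls.
        experiment();

        // Load the saved values back; the flag matches the one used for ToBuffer.
        session.FromBuffer(snapshot, true);
    }
}
```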

TrainStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.

Declaration
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
Parameters
Type Name Description
RunOptions options

Specify RunOptions for step.

IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model.

Returns
Type Description
IDisposableReadOnlyCollection<DisposableNamedOnnxValue>

Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.

TrainStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.

Declaration
public void TrainStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type Name Description
RunOptions options

Specify RunOptions for step.

IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model.

IReadOnlyCollection<FixedBufferOnnxValue> outputValues

Specify a collection of FixedBufferOnnxValue that indicates the output values of the training model.

TrainStep(IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.

Declaration
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
Parameters
Type Name Description
IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model.

Returns
Type Description
IDisposableReadOnlyCollection<DisposableNamedOnnxValue>

Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.

TrainStep(IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)

Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.

Declaration
public void TrainStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type Name Description
IReadOnlyCollection<FixedBufferOnnxValue> inputValues

Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model.

IReadOnlyCollection<FixedBufferOnnxValue> outputValues

Specify a collection of FixedBufferOnnxValue that indicates the output values of the training model.

TrainStep(IReadOnlyCollection<OrtValue>)

This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.

Example usage:

using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
List<OrtValue> inputValues = new List<OrtValue> { x, label };
using (var loss = trainingSession.TrainStep(inputValues))
{
    // process output values
}
Declaration
public IDisposableReadOnlyCollection<OrtValue> TrainStep(IReadOnlyCollection<OrtValue> inputValues)
Parameters
Type Name Description
IReadOnlyCollection<OrtValue> inputValues

Specify a collection of OrtValue that indicates the input values to the training model.

Returns
Type Description
IDisposableReadOnlyCollection<OrtValue>

Output tensors in a collection of OrtValue. The user must dispose of the output.

Implements

System.IDisposable