Class TrainingSession
Trainer class that provides training, evaluation and optimizer methods for training an ONNX model.
The training session requires four training artifacts:
- The training ONNX model
- The evaluation ONNX model (optional)
- The optimizer ONNX model
- The checkpoint directory
These artifacts can be generated using the onnxruntime-training Python utility.
This is an IDisposable class and must be disposed of either by an explicit call to the Dispose() method or with a using block. If a TrainingSession is a member of another class, that class must also implement IDisposable and dispose of the TrainingSession in its own Dispose() method.
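The ownership rule above can be sketched as follows. This is a minimal illustration, not the library's prescribed pattern: `ModelTrainer` and its parameter names are hypothetical, and it assumes the checkpoint is loaded with `CheckpointState.LoadCheckpoint` and that the session should be disposed of before the checkpoint state it was created from.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

// Hypothetical owner class: it holds a TrainingSession as a member, so it
// implements IDisposable itself and disposes of the session in Dispose().
public sealed class ModelTrainer : IDisposable
{
    private readonly CheckpointState _state;
    private readonly TrainingSession _session;
    private bool _disposed;

    public ModelTrainer(string checkpointPath, string trainModelPath, string optimizerModelPath)
    {
        // Load the starting parameters, then build the session on top of them.
        _state = CheckpointState.LoadCheckpoint(checkpointPath);
        try
        {
            _session = new TrainingSession(_state, trainModelPath, optimizerModelPath);
        }
        catch
        {
            // Do not leak the checkpoint state if session creation fails.
            _state.Dispose();
            throw;
        }
    }

    public TrainingSession Session => _session;

    public void Dispose()
    {
        if (_disposed) return;
        // Assumption: release the session before the state it was built from.
        _session.Dispose();
        _state.Dispose();
        _disposed = true;
    }
}
```

A caller would then typically write `using var trainer = new ModelTrainer(checkpointPath, trainModelPath, optimizerModelPath);` so that both native resources are released when the scope ends.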
Namespace: Microsoft.ML.OnnxRuntime
Assembly: Microsoft.ML.OnnxRuntime.dll
Syntax
public class TrainingSession : IDisposable
Constructors
TrainingSession(CheckpointState, String, String)
Create a training session that can be used to begin or resume training.
This constructor instantiates a training session that can begin or resume training from the given checkpoint state for the given ONNX models. This overload takes neither session options nor an evaluation model. The checkpoint state represents the parameters of the training session, which are moved to the target device if necessary.
Declaration
public TrainingSession(CheckpointState state, string trainModelPath, string optimizerModelPath)
Parameters
Type | Name | Description |
---|---|---|
CheckpointState | state | Training states that the training session uses as a starting point for training. |
String | trainModelPath | Model to be used to perform training. |
String | optimizerModelPath | Model to be used to perform weight update. |
TrainingSession(CheckpointState, String, String, String)
Create a training session that can be used to begin or resume training.
This constructor instantiates a training session that can begin or resume training from the given checkpoint state for the given ONNX models. This overload takes no session options. The checkpoint state represents the parameters of the training session, which are moved to the target device if necessary.
Declaration
public TrainingSession(CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
Parameters
Type | Name | Description |
---|---|---|
CheckpointState | state | Training states that the training session uses as a starting point for training. |
String | trainModelPath | Model to be used to perform training. |
String | evalModelPath | Model to be used to perform evaluation. |
String | optimizerModelPath | Model to be used to perform weight update. |
TrainingSession(SessionOptions, CheckpointState, String, String, String)
Create a training session that can be used to begin or resume training.
This constructor instantiates a training session, based on the session options provided, that can begin or resume training from the given checkpoint state for the given ONNX models. The checkpoint state represents the parameters of the training session, which are moved to the device specified by the user through the session options if necessary.
Declaration
public TrainingSession(SessionOptions options, CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
Parameters
Type | Name | Description |
---|---|---|
SessionOptions | options | SessionOptions that the user can customize for this training session. |
CheckpointState | state | Training states that the training session uses as a starting point for training. |
String | trainModelPath | Model to be used to perform training. |
String | evalModelPath | Model to be used to perform evaluation. |
String | optimizerModelPath | Model to be used to perform weight update. |
Methods
Dispose()
IDisposable implementation
Declaration
public void Dispose()
Dispose(Boolean)
IDisposable implementation
Declaration
protected virtual void Dispose(bool disposing)
Parameters
Type | Name | Description |
---|---|---|
Boolean | disposing | true if invoked from Dispose() method |
EvalStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the eval model for the given inputs. This function performs an eval step based on the eval model that was provided to the training session and writes the results to the provided output values.
Declaration
public void EvalStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type | Name | Description |
---|---|---|
RunOptions | options | Specify RunOptions for step. |
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the eval model. |
IReadOnlyCollection<FixedBufferOnnxValue> | outputValues | Specify a collection of FixedBufferOnnxValue that indicates the output values of the eval model. |
EvalStep(IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the eval model for the given inputs. This function performs an eval step based on the eval model that was provided to the training session and writes the results to the provided output values.
Declaration
public void EvalStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type | Name | Description |
---|---|---|
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the eval model. |
IReadOnlyCollection<FixedBufferOnnxValue> | outputValues | Specify a collection of FixedBufferOnnxValue that indicates the output values of the eval model. |
EvalStep(IReadOnlyCollection<OrtValue>)
This function performs an eval step that computes the outputs of the eval model for the given inputs. Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was provided to the training session. Example usage:
using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
List<OrtValue> inputValues = new List<OrtValue> { x, label };
using (var loss = trainingSession.EvalStep(inputValues))
{
// process output values
}
Declaration
public IDisposableReadOnlyCollection<OrtValue> EvalStep(IReadOnlyCollection<OrtValue> inputValues)
Parameters
Type | Name | Description |
---|---|---|
IReadOnlyCollection<OrtValue> | inputValues | Specify a collection of OrtValue that indicates the input values to the eval model. |
Returns
Type | Description |
---|---|
IDisposableReadOnlyCollection<OrtValue> | Outputs of the eval model as a collection of OrtValue. The caller must dispose of the output. |
ExportModelForInferencing(String, IReadOnlyCollection<String>)
Export a model that can be used for inferencing. If the training session was provided with an eval model, the training session can generate an inference model if it knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the inference model's outputs align with the provided outputs. The exported model is saved at the path provided and can be used for inferencing with InferenceSession. Note that the function re-loads the eval model from the path provided to TrainingSession and expects that this path still be valid.
Declaration
public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollection<string> graphOutputNames)
Parameters
Type | Name | Description |
---|---|---|
String | inferenceModelPath | Path where the inference model should be serialized to. |
IReadOnlyCollection<String> | graphOutputNames | Names of the outputs that are needed in the inference model. |
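As a rough sketch of the export flow described above: the helper below exports the pruned model and then opens it with InferenceSession. The class name `InferenceExport` and the graph output name `"output"` are hypothetical; substitute the output names of your own eval model. It assumes the training session was created with an eval model whose path is still valid.

```csharp
using System.Collections.Generic;
using Microsoft.ML.OnnxRuntime;

public static class InferenceExport
{
    // Exports an inference model from the training session and loads it.
    // "output" is a hypothetical graph output name used for illustration only.
    public static InferenceSession ExportAndLoad(TrainingSession session, string inferenceModelPath)
    {
        var graphOutputNames = new List<string> { "output" };

        // Prunes the eval model down to the requested outputs and saves it.
        session.ExportModelForInferencing(inferenceModelPath, graphOutputNames);

        // The exported file is a regular ONNX model usable for inferencing.
        return new InferenceSession(inferenceModelPath);
    }
}
```

The returned InferenceSession is itself disposable, so the caller is responsible for disposing of it.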
Finalize()
Finalizer.
Declaration
protected void Finalize()
FromBuffer(OrtValue, Boolean)
Loads the training session model parameters from a contiguous buffer.
Declaration
public void FromBuffer(OrtValue ortValue, bool onlyTrainable)
Parameters
Type | Name | Description |
---|---|---|
OrtValue | ortValue | Contiguous buffer to load the parameters from. |
Boolean | onlyTrainable | Whether to only load trainable parameters or to load all parameters. |
GetLearningRate()
Gets the current learning rate for this training session.
This function allows users to get the learning rate for the training session. The current learning rate is maintained by the training session, and users can query it for the purpose of implementing their own learning rate schedulers.
Declaration
public float GetLearningRate()
Returns
Type | Description |
---|---|
Single | float representing the current learning rate. |
InputNames(Boolean)
Retrieves the names of the user inputs for either the training model or the eval model.
Declaration
public List<string> InputNames(bool training)
Parameters
Type | Name | Description |
---|---|---|
Boolean | training | true to request the training model input names; false to request the eval model input names. |
Returns
Type | Description |
---|---|
List<String> | Names of the user inputs of the requested model. |
LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
This function sets the internal state of the training session such that the gradients of the trainable parameters in the OrtCheckpointState are scheduled to be reset just before the new gradients are computed on the next invocation of TrainStep.
Declaration
public void LazyResetGrad()
OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
This function performs the weight update step, which updates the trainable parameters so that they take a step against the direction of their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next TrainStep function call.
Declaration
public void OptimizerStep()
OptimizerStep(RunOptions)
Performs the weight updates for the trainable parameters using the optimizer model.
This function performs the weight update step, which updates the trainable parameters so that they take a step against the direction of their gradients (gradient descent). The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next TrainStep function call.
Declaration
public void OptimizerStep(RunOptions options)
Parameters
Type | Name | Description |
---|---|---|
RunOptions | options | Specify RunOptions for step. |
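The TrainStep, OptimizerStep, and LazyResetGrad calls documented on this page fit together roughly as shown below. This is a minimal sketch under stated assumptions: `TrainingLoop.RunEpoch` is a hypothetical helper, the batches are assumed to be prepared by the caller as collections of OrtValue, and the first output of the training model is assumed to be a float loss tensor.

```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;

public static class TrainingLoop
{
    // Runs one pass over the batches and returns the loss reported by the last step.
    // Assumption: the first training-model output is a float loss tensor.
    public static float RunEpoch(TrainingSession session,
                                 IEnumerable<IReadOnlyCollection<OrtValue>> batches)
    {
        float lastLoss = float.NaN;
        foreach (IReadOnlyCollection<OrtValue> batch in batches)
        {
            // Forward and backward pass; gradients are kept in the session state.
            using (var outputs = session.TrainStep(batch))
            {
                lastLoss = outputs.First().GetTensorDataAsSpan<float>()[0];
            }

            // Update the trainable parameters from the stored gradients.
            session.OptimizerStep();

            // Schedule the gradients to be reset before the next TrainStep.
            session.LazyResetGrad();
        }
        return lastLoss;
    }
}
```

Skipping the LazyResetGrad call for several iterations is one way to accumulate gradients across batches before a single OptimizerStep, if that behavior is desired.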
OutputNames(Boolean)
Retrieves the names of the user outputs for either the training model or the eval model.
Declaration
public List<string> OutputNames(bool training)
Parameters
Type | Name | Description |
---|---|---|
Boolean | training | true to request the training model output names; false to request the eval model output names. |
Returns
Type | Description |
---|---|
List<String> | Names of the user outputs of the requested model. |
RegisterLinearLRScheduler(Int64, Int64, Single)
Registers a linear learning rate scheduler for the training session.
Registers a linear learning rate scheduler that decays the learning rate by a linearly updated multiplicative factor, from the initial learning rate set on the training session down to 0. The decay is performed after the initial warm-up phase, during which the learning rate is linearly incremented from 0 to the initial learning rate provided.
Declaration
public void RegisterLinearLRScheduler(long warmupStepCount, long totalStepCount, float initialLearningRate)
Parameters
Type | Name | Description |
---|---|---|
Int64 | warmupStepCount | Number of warmup steps |
Int64 | totalStepCount | Number of total steps |
Single | initialLearningRate | Initial learning rate |
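The warm-up and decay behavior described above can be written out as a small pure function. This sketch is one reading of that description, not the library's implementation: `LinearSchedule` is a hypothetical class, and the clamp to 0 past `totalStepCount` and the guards against zero-length phases are assumptions.

```csharp
using System;

public static class LinearSchedule
{
    // Multiplicative factor applied to the initial learning rate at a given step.
    public static float Factor(long step, long warmupStepCount, long totalStepCount)
    {
        if (step < warmupStepCount)
        {
            // Warm-up phase: ramp linearly from 0 toward 1.
            return (float)step / Math.Max(1, warmupStepCount);
        }

        // Decay phase: ramp linearly from 1 down to 0 at totalStepCount.
        float remaining = totalStepCount - step;
        return Math.Max(0f, remaining / Math.Max(1, totalStepCount - warmupStepCount));
    }

    // Learning rate at a given step for the supplied initial learning rate.
    public static float LearningRate(long step, long warmupStepCount,
                                     long totalStepCount, float initialLearningRate)
    {
        return initialLearningRate * Factor(step, warmupStepCount, totalStepCount);
    }
}
```

With 10 warm-up steps and 100 total steps, the factor is 0 at step 0, 0.5 at step 5, 1 at step 10, 0.5 at step 55, and 0 at step 100.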
SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
Takes a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as determined necessary to update the learning rate being used by the training session.
Note: A valid predefined learning rate scheduler must first be registered before this function is invoked.
Declaration
public void SchedulerStep()
SetLearningRate(Single)
Sets the learning rate for this training session.
This function allows users to set the learning rate for the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this function with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.
Note: This function does not set the initial learning rate that may be needed by the predefined learning rate schedulers. To set the initial learning rate for a learning rate scheduler, see RegisterLinearLRScheduler.
Declaration
public void SetLearningRate(float learningRate)
Parameters
Type | Name | Description |
---|---|---|
Single | learningRate | Desired learning rate to be set. |
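To illustrate the custom-scheduler use case mentioned above, here is a small step-decay schedule whose result can be handed to SetLearningRate. Everything in it is hypothetical (the `StepDecaySchedule` class, its parameters, and the choice of step decay); the only library call it targets is `SetLearningRate(float)`, passed in as a delegate so the schedule itself stays independent of the session.

```csharp
using System;

public static class StepDecaySchedule
{
    // Learning rate after `step` steps: baseLearningRate * decay^(step / stepsPerDecay).
    // Integer division means the rate drops once every stepsPerDecay steps.
    public static float LearningRateAt(long step, float baseLearningRate,
                                       float decay, long stepsPerDecay)
    {
        long drops = step / Math.Max(1, stepsPerDecay);
        return (float)(baseLearningRate * Math.Pow(decay, drops));
    }

    // Pushes the scheduled rate to any setter, for example session.SetLearningRate.
    public static void Apply(long step, float baseLearningRate, float decay,
                             long stepsPerDecay, Action<float> setLearningRate)
    {
        setLearningRate(LearningRateAt(step, baseLearningRate, decay, stepsPerDecay));
    }
}
```

Inside a training loop this might be called as `StepDecaySchedule.Apply(step, 0.1f, 0.5f, 100, session.SetLearningRate);` before each optimizer step, with no predefined scheduler registered on the session.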
ToBuffer(Boolean)
Returns a contiguous buffer that holds a copy of all training state parameters.
Declaration
public OrtValue ToBuffer(bool onlyTrainable)
Parameters
Type | Name | Description |
---|---|---|
Boolean | onlyTrainable | Whether to only copy trainable parameters or to copy all parameters. |
Returns
Type | Description |
---|---|
OrtValue | Contiguous buffer holding a copy of the requested parameters. |
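ToBuffer and FromBuffer can be paired to snapshot parameters and restore them later. The sketch below is illustrative only: `ParameterSnapshot.RunAndRestore` is a hypothetical helper, and it assumes the same `onlyTrainable` flag should be used for both calls and that the returned OrtValue should be disposed of by the caller.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

public static class ParameterSnapshot
{
    // Runs `experiment`, then restores the trainable parameters to their prior values.
    public static void RunAndRestore(TrainingSession session, Action experiment)
    {
        // Copy of the trainable parameters as one contiguous buffer.
        using (OrtValue snapshot = session.ToBuffer(true))
        {
            try
            {
                experiment();
            }
            finally
            {
                // Load the saved values back, using the same flag as for ToBuffer.
                session.FromBuffer(snapshot, true);
            }
        }
    }
}
```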
TrainStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.
Declaration
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
Parameters
Type | Name | Description |
---|---|---|
RunOptions | options | Specify RunOptions for step. |
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model. |
Returns
Type | Description |
---|---|
IDisposableReadOnlyCollection<DisposableNamedOnnxValue> | Output tensors in a collection of DisposableNamedOnnxValue. The caller must dispose of the output. |
TrainStep(RunOptions, IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.
Declaration
public void TrainStep(RunOptions options, IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type | Name | Description |
---|---|---|
RunOptions | options | Specify RunOptions for step. |
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model. |
IReadOnlyCollection<FixedBufferOnnxValue> | outputValues | Specify a collection of FixedBufferOnnxValue that indicates the output values of the training model. |
TrainStep(IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.
Declaration
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
Parameters
Type | Name | Description |
---|---|---|
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model. |
Returns
Type | Description |
---|---|
IDisposableReadOnlyCollection<DisposableNamedOnnxValue> | Output tensors in a collection of DisposableNamedOnnxValue. The caller must dispose of the output. |
TrainStep(IReadOnlyCollection<FixedBufferOnnxValue>, IReadOnlyCollection<FixedBufferOnnxValue>)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function.
Declaration
public void TrainStep(IReadOnlyCollection<FixedBufferOnnxValue> inputValues, IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
Parameters
Type | Name | Description |
---|---|---|
IReadOnlyCollection<FixedBufferOnnxValue> | inputValues | Specify a collection of FixedBufferOnnxValue that indicates the input values to the training model. |
IReadOnlyCollection<FixedBufferOnnxValue> | outputValues | Specify a collection of FixedBufferOnnxValue that indicates the output values of the training model. |
TrainStep(IReadOnlyCollection<OrtValue>)
This function performs a training step that computes the outputs of the training model and the gradients of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model that was provided to the training session. The TrainStep method is equivalent to running forward propagation and backward propagation in a single step. The computed gradients are stored inside the training session state so that they can later be consumed by the OptimizerStep function. The gradients can be lazily reset by invoking the LazyResetGrad function. Example usage:
using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
List<OrtValue> inputValues = new List<OrtValue> { x, label };
using (var loss = trainingSession.TrainStep(inputValues))
{
// process output values
}
Declaration
public IDisposableReadOnlyCollection<OrtValue> TrainStep(IReadOnlyCollection<OrtValue> inputValues)
Parameters
Type | Name | Description |
---|---|---|
IReadOnlyCollection<OrtValue> | inputValues | Specify a collection of OrtValue that indicates the input values to the training model. |
Returns
Type | Description |
---|---|
IDisposableReadOnlyCollection<OrtValue> | Output tensors in a collection of OrtValue. The caller must dispose of the output. |