ORTTrainingSession

Objective-C

@interface ORTTrainingSession : NSObject

Swift

class ORTTrainingSession : NSObject

Trainer class that provides methods to train, evaluate and optimize ONNX models.

The training session requires four training artifacts:

  1. Training ONNX model
  2. Evaluation ONNX model (optional)
  3. Optimizer ONNX model
  4. Checkpoint directory

The onnxruntime-training Python utility can be used to generate the training artifacts above.
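The sketch below shows one way an app might collect these artifacts before creating a session. It is a minimal sketch and not part of the ONNX Runtime API:

- `TrainingArtifactPaths`, `ArtifactError` and `locateArtifacts(in:)` are hypothetical names.
- The file names (`checkpoint`, `training_model.onnx`, `eval_model.onnx`, `optimizer_model.onnx`) are assumed defaults of the Python utility and may need to change for your export script.
- Only the evaluation model is treated as optional, matching the list above.

```swift
import Foundation

/// Hypothetical container for the four artifact locations listed above.
struct TrainingArtifactPaths {
    let checkpoint: String
    let trainModel: String
    let evalModel: String?      // nil when no evaluation model was exported
    let optimizerModel: String
}

/// Hypothetical error raised when a required artifact is absent.
enum ArtifactError: Error {
    case missing(String)
}

/// Looks for the artifacts in `directory`, using file names that are an
/// assumption about the Python utility's defaults.
func locateArtifacts(in directory: URL,
                     fileManager: FileManager = .default) throws -> TrainingArtifactPaths {
    func path(_ name: String) -> String {
        directory.appendingPathComponent(name).path
    }
    // Required artifacts must exist (fileExists also accepts directories).
    func required(_ name: String) throws -> String {
        let candidate = path(name)
        guard fileManager.fileExists(atPath: candidate) else {
            throw ArtifactError.missing(name)
        }
        return candidate
    }
    let checkpoint = try required("checkpoint")
    let trainModel = try required("training_model.onnx")
    let optimizerModel = try required("optimizer_model.onnx")
    // The evaluation model is optional.
    let evalPath = path("eval_model.onnx")
    return TrainingArtifactPaths(
        checkpoint: checkpoint,
        trainModel: trainModel,
        evalModel: fileManager.fileExists(atPath: evalPath) ? evalPath : nil,
        optimizerModel: optimizerModel)
}
```

The returned paths correspond to the trainModelPath, evalModelPath and optimizerModelPath arguments of the initializer. The checkpoint path would first be loaded into an ORTCheckpoint.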

Available since 1.16.

Note

This class is only available when the training APIs are enabled.
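  • Unavailable. Use initWithEnv:sessionOptions:checkpoint:trainModelPath:evalModelPath:optimizerModelPath:error: instead.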

    Declaration

    Objective-C

    - (instancetype)init NS_UNAVAILABLE;
  • Creates a training session from the training artifacts that can be used to begin or resume training.

    The initializer instantiates the training session from the provided environment and session options. The session can be used to begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of the training session, which are moved to the device specified in the session options if needed.

    Note

    A training session created with a checkpoint state uses that state to store the entire training state, including the model parameters, their gradients, the optimizer states and the properties. The training session keeps a strong (owning) pointer to the checkpoint state.

    Declaration

    Objective-C

    - (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env
                          sessionOptions:
                              (nullable ORTSessionOptions *)sessionOptions
                              checkpoint:(nonnull ORTCheckpoint *)checkpoint
                          trainModelPath:(nonnull NSString *)trainModelPath
                           evalModelPath:(nullable NSString *)evalModelPath
                      optimizerModelPath:(nullable NSString *)optimizerModelPath
                                   error:(NSError *_Nullable *_Nullable)error;

    Swift

    init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws

    Parameters

    env

    The ORTEnv instance to use for the training session.

    sessionOptions

    The optional ORTSessionOptions to use for the training session.

    checkpoint

    The training state used as the starting point for training.

    trainModelPath

    The path to the training ONNX model.

    evalModelPath

    The optional path to the evaluation ONNX model.

    optimizerModelPath

    The path to the optimizer ONNX model used to perform gradient descent.

    error

    Optional error information set if an error occurs.

    Return Value

    The instance, or nil if an error occurs.

  • Performs a training step, which is equivalent to a forward and backward propagation in a single step.

    The training step computes the outputs of the training model and the gradients of the trainable parameters for the given input values. It uses the training model that was provided to the training session. The computed gradients are stored inside the training session state so that they can later be consumed by the optimizerStep method. The gradients can be lazily reset by calling the lazyResetGrad method.

    Declaration

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        trainStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                           error:(NSError *_Nullable *_Nullable)error;

    Swift

    func trainStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    Parameters

    inputs

    The input values to the training model.

    error

    Optional error information set if an error occurs.

    Return Value

    The output values of the training model.

  • Performs an evaluation step that computes the outputs of the evaluation model for the given inputs. The evaluation step uses the evaluation model that was provided to the training session.

    Declaration

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        evalStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func evalStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    Parameters

    inputs

    The input values to the evaluation model.

    error

    Optional error information set if an error occurs.

    Return Value

    The output values of the evaluation model.

  • Lazily resets the gradients of all trainable parameters to zero.

    Calling this method sets the internal state of the training session such that the gradients of the trainable parameters in the ORTCheckpoint will be scheduled to be reset just before the new gradients are computed on the next invocation of the trainStep method.

    Declaration

    Objective-C

    - (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func lazyResetGrad() throws

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the gradients were successfully scheduled to be reset, NO otherwise.
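    The toy model below illustrates the lazy semantics using a plain Swift gradient array. It is not how ONNX Runtime stores gradients. `GradientBuffer` is a hypothetical type. It also assumes, as the description implies, that gradients from successive trainStep calls accumulate until a scheduled reset is applied.

    ```swift
    /// Toy model of the gradient bookkeeping described above (illustration only).
    struct GradientBuffer {
        private(set) var gradients: [Float]
        private var resetPending = false

        init(parameterCount: Int) {
            gradients = Array(repeating: 0, count: parameterCount)
        }

        /// Analogue of lazyResetGrad: schedules a reset but zeroes nothing yet.
        mutating func lazyResetGrad() {
            resetPending = true
        }

        /// Analogue of the gradient half of trainStep: a pending reset is
        /// applied first, then the new gradients are added to what is stored.
        mutating func accumulate(_ newGradients: [Float]) {
            precondition(newGradients.count == gradients.count, "shape mismatch")
            if resetPending {
                gradients = Array(repeating: 0, count: gradients.count)
                resetPending = false
            }
            for i in gradients.indices {
                gradients[i] += newGradients[i]
            }
        }
    }
    ```

    Under this reading, calling lazyResetGrad after every optimizerStep gives ordinary per-batch updates. Calling it only every k batches accumulates gradients over k batches.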

  • Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next trainStep method call.

    Declaration

    Objective-C

    - (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func optimizerStep() throws

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the optimizer step was performed successfully, NO otherwise.
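    The optimizer model defines the actual update rule, and an exported optimizer is often AdamW rather than SGD. The sketch below uses plain stochastic gradient descent (w = w - lr * g) only because it is the simplest rule to show. `sgdStep` is a hypothetical helper that operates on Swift arrays, not on ORTValue.

    ```swift
    /// Plain SGD: moves each parameter against its gradient, scaled by the
    /// learning rate. A stand-in for whatever rule the optimizer model encodes.
    func sgdStep(parameters: inout [Float], gradients: [Float], learningRate: Float) {
        precondition(parameters.count == gradients.count, "shape mismatch")
        for i in parameters.indices {
            parameters[i] -= learningRate * gradients[i]
        }
    }
    ```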

  • Returns the names of the user inputs for the training model, which can be associated with the ORTValue inputs passed to trainStep.

    Declaration

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainInputNames() throws -> [String]

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    The names of the user inputs for the training model.
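    Because trainStep takes a plain array, the values must be supplied in the order of the input names. The helper below builds that array from a dictionary keyed by input name. It is a hypothetical utility, not part of the API. It assumes that trainStep (and likewise evalStep with getEvalInputNames) matches values to inputs by position, in the order the names are returned. It is generic so that `Value` can be ORTValue in an app.

    ```swift
    /// Hypothetical error type for orderedInputs(names:values:).
    enum InputOrderingError: Error, Equatable {
        case missingInput(String)
        case unexpectedInputs([String])
    }

    /// Arranges values keyed by input name into the positional order given by
    /// `names` (for example, the result of getTrainInputNames).
    func orderedInputs<Value>(names: [String], values: [String: Value]) throws -> [Value] {
        // Reject values whose names the model does not declare.
        let unexpected = Set(values.keys).subtracting(names)
        guard unexpected.isEmpty else {
            throw InputOrderingError.unexpectedInputs(unexpected.sorted())
        }
        // Emit one value per declared input, in declaration order.
        return try names.map { (name: String) throws -> Value in
            guard let value = values[name] else {
                throw InputOrderingError.missingInput(name)
            }
            return value
        }
    }
    ```

    With a session this might read `try session.trainStep(withInputValues: orderedInputs(names: session.getTrainInputNames(), values: batch))`, where `batch` is a hypothetical `[String: ORTValue]` dictionary.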

  • Returns the names of the user inputs for the evaluation model, which can be associated with the ORTValue inputs passed to evalStep.

    Declaration

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalInputNames() throws -> [String]

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    The names of the user inputs for the evaluation model.

  • Returns the names of the user outputs for the training model, which can be associated with the ORTValue outputs returned by trainStep.

    Declaration

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainOutputNames() throws -> [String]

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    The names of the user outputs for the training model.
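    Conversely, the array returned by trainStep can be turned into a dictionary keyed by output name. `namedOutputs(names:values:)` is a hypothetical helper. It returns nil when the counts differ rather than guessing a pairing.

    ```swift
    /// Pairs output names (e.g. from getTrainOutputNames) with the values
    /// returned by the corresponding step; returns nil if the counts differ.
    func namedOutputs<Value>(names: [String], values: [Value]) -> [String: Value]? {
        guard names.count == values.count else { return nil }
        // Graph output names are expected to be unique; duplicates would trap here.
        return Dictionary(uniqueKeysWithValues: zip(names, values))
    }
    ```

    For example, `try namedOutputs(names: session.getTrainOutputNames(), values: session.trainStep(withInputValues: inputs))` would let the loss be looked up by name instead of by index.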

  • Returns the names of the user outputs for the evaluation model, which can be associated with the ORTValue outputs returned by evalStep.

    Declaration

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalOutputNames() throws -> [String]

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    The names of the user outputs for the evaluation model.

  • Registers a linear learning rate scheduler for the training session.

    The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training. The decrease is performed by multiplying the current learning rate by a linearly updated factor. Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase.

    Declaration

    Objective-C

    - (BOOL)
        registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError *_Nullable *_Nullable)
                                                         error;

    Swift

    func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws

    Parameters

    warmupStepCount

    The number of steps over which the learning rate is linearly warmed up.

    totalStepCount

    The total number of steps to perform the linear decay.

    initialLr

    The initial learning rate.

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the scheduler was registered successfully, NO otherwise.
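    The schedule described above can be written as a small pure function. This is handy for checking a choice of warmupStepCount and totalStepCount before registering the scheduler. The sketch rests on an assumption: it uses the common linear warmup/decay formula (the one in Hugging Face's get_linear_schedule_with_warmup), where the factor multiplies initialLr, the decay starts after warmup, and totalStepCount is the step at which the rate reaches zero. ONNX Runtime's exact step counting may differ. `linearScheduleLR` is a hypothetical name.

    ```swift
    /// Learning rate at `step` under an assumed linear warmup + linear decay.
    func linearScheduleLR(step: Int64, warmupStepCount: Int64,
                          totalStepCount: Int64, initialLr: Float) -> Float {
        let factor: Float
        if step < warmupStepCount {
            // Warmup: the factor rises linearly from 0 to 1.
            factor = Float(step) / Float(max(1, warmupStepCount))
        } else {
            // Decay: the factor falls linearly from 1 to 0 and is clamped at 0.
            factor = max(0, Float(totalStepCount - step)
                            / Float(max(1, totalStepCount - warmupStepCount)))
        }
        return initialLr * factor
    }
    ```

    Evaluating it for step = 0, 1, 2, … traces the curve the registered scheduler is described as following.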

  • Update the learning rate based on the registered learning rate scheduler.

    Performs a scheduler step that updates the learning rate that is being used by the training session. This function should typically be called before invoking the optimizer step for each round, or as necessary to update the learning rate being used by the training session.

    Note

    A valid predefined learning rate scheduler must be registered before this method is invoked.

    Declaration

    Objective-C

    - (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func schedulerStep() throws

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the scheduler step was performed successfully, NO otherwise.
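    Following the order suggested here, one pass over the data might call trainStep, then schedulerStep before optimizerStep, then lazyResetGrad. Placing lazyResetGrad after the optimizer step is an assumption based on common usage; this page does not state it. `runEpoch` is a hypothetical helper. Its closures stand in for the session's methods, so the ordering can be tested without a model.

    ```swift
    /// Hypothetical driver for one pass over `batches`; the closures would wrap
    /// the corresponding ORTTrainingSession methods.
    func runEpoch<Batch>(
        batches: [Batch],
        trainStep: (Batch) throws -> Void,
        schedulerStep: () throws -> Void,
        optimizerStep: () throws -> Void,
        lazyResetGrad: () throws -> Void
    ) rethrows {
        for batch in batches {
            try trainStep(batch)   // forward + backward; gradients kept in the session
            try schedulerStep()    // update the learning rate before the optimizer runs
            try optimizerStep()    // consume the stored gradients
            try lazyResetGrad()    // schedule a gradient reset before the next trainStep
        }
    }
    ```

    With a real session, each batch could be an `[ORTValue]` array. The closures could then be `{ _ = try session.trainStep(withInputValues: $0) }`, `session.schedulerStep`, `session.optimizerStep` and `session.lazyResetGrad`.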

  • Returns the current learning rate being used by the training session.

    Declaration

    Objective-C

    - (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func getLearningRate() throws -> Float

    Parameters

    error

    Optional error information set if an error occurs.

    Return Value

    The current learning rate or 0.0f if an error occurs.

  • Sets the learning rate being used by the training session.

    The current learning rate is maintained by the training session and can be overwritten by invoking this method with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.

    Note

    This method does not set the initial learning rate that the predefined learning rate schedulers may need. To set the initial learning rate for a learning rate scheduler, use the registerLinearLRScheduler method.

    Declaration

    Objective-C

    - (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;

    Swift

    func setLearningRate(_ lr: Float) throws

    Parameters

    lr

    The learning rate to be used by the training session.

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the learning rate was set successfully, NO otherwise.
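    As an example of a custom schedule, the hypothetical function below implements step decay: the rate is multiplied by `factor` every `dropEvery` epochs. Its result would be passed to setLearningRate at the start of each epoch, with no predefined scheduler registered, as the text above requires. The schedule is illustrative only and is not provided by ONNX Runtime.

    ```swift
    /// Step decay: the rate is multiplied by `factor` once every `dropEvery` epochs.
    func stepDecayLR(baseLr: Float, epoch: Int, dropEvery: Int, factor: Float) -> Float {
        precondition(epoch >= 0 && dropEvery > 0, "epoch must be >= 0 and dropEvery > 0")
        var lr = baseLr
        for _ in 0..<(epoch / dropEvery) {
            lr *= factor  // one decay per completed interval of dropEvery epochs
        }
        return lr
    }
    ```

    For instance, an epoch loop might call `try session.setLearningRate(stepDecayLR(baseLr: 0.1, epoch: epoch, dropEvery: 10, factor: 0.5))`.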

  • Loads the training session model parameters from a contiguous buffer.

    Declaration

    Objective-C

    - (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func fromBuffer(with buffer: ORTValue) throws

    Parameters

    buffer

    Contiguous buffer to load the parameters from.

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the parameters were loaded successfully, NO otherwise.

  • Returns a contiguous buffer that holds a copy of the training state parameters, either all of them or only the trainable ones.

    Declaration

    Objective-C

    - (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable
                                           error:
                                               (NSError *_Nullable *_Nullable)error;

    Swift

    func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue

    Parameters

    onlyTrainable

    If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer that holds all the parameters.

    error

    Optional error information set if an error occurs.

    Return Value

    A contiguous buffer that holds a copy of the requested training state parameters.
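    The buffer is an ORTValue. To inspect or persist the parameters in Swift, its raw bytes can be reinterpreted as Float values. The helpers below are a sketch that assumes the parameters are 32-bit floats in host byte order. `floats(from:)` and `floatData(from:)` are hypothetical names. Converting between Data and ORTValue (for example through ORTValue's tensor data accessor) is left out so that the example runs without ONNX Runtime.

    ```swift
    import Foundation

    /// Reinterprets raw tensor bytes as Float32 values (host byte order assumed).
    func floats(from data: Data) -> [Float] {
        precondition(data.count % MemoryLayout<Float>.stride == 0,
                     "byte count is not a whole number of floats")
        var result = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        // Copy the bytes rather than rebinding memory, which avoids alignment issues.
        _ = result.withUnsafeMutableBufferPointer { data.copyBytes(to: $0) }
        return result
    }

    /// Inverse of floats(from:): packs Float32 values into contiguous bytes.
    func floatData(from values: [Float]) -> Data {
        values.withUnsafeBufferPointer { Data(buffer: $0) }
    }
    ```

    The round trip is lossless, so parameters taken from toBuffer can be stored as Data and later wrapped in an ORTValue again for fromBuffer.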

  • Exports the training session model that can be used for inference.

    If the training session was provided with an evaluation model, it can generate an inference model once it knows the inference graph outputs. The provided graph output names are used to prune the evaluation model so that the inference model’s outputs match them. The exported model is saved at the provided path and can be used for inference with ORTSession.

    Note

    The method reloads the eval model from the path provided to the initializer and expects this path to be valid.

    Declaration

    Objective-C

    - (BOOL)
        exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath
                             graphOutputNames:
                                 (nonnull NSArray<NSString *> *)graphOutputNames
                                        error:(NSError *_Nullable *_Nullable)error;

    Swift

    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws

    Parameters

    inferenceModelPath

    The path at which to save the serialized inference model.

    graphOutputNames

    The names of the outputs that are needed in the inference model.

    error

    Optional error information set if an error occurs.

    Return Value

    YES if the inference model was exported successfully, NO otherwise.