ORTTrainingSession
Objective-C
@interface ORTTrainingSession : NSObject

Swift
class ORTTrainingSession : NSObject

Trainer class that provides methods to train, evaluate, and optimize ONNX models.
The training session uses four training artifacts:
- Training ONNX model
- Evaluation ONNX model (optional)
- Optimizer ONNX model
- Checkpoint directory
The onnxruntime-training Python utility can be used to generate these training artifacts.
Available since 1.16.
Note
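This class is only available when the training APIs are enabled.

Unavailable. Use initWithEnv:sessionOptions:checkpoint:trainModelPath:evalModelPath:optimizerModelPath:error: instead.

Declaration
Objective-C
- (instancetype)init NS_UNAVAILABLE;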

Creates a training session from the training artifacts that can be used to begin or resume training.

The initializer instantiates the training session based on the provided env and session options. The session can then begin or resume training from the given checkpoint state. The checkpoint state represents the parameters of the training session, which will be moved to the device specified in the session options if needed.

Note
A training session created with a checkpoint state uses this state to store the entire training state (including the model parameters, their gradients, the optimizer states, and the properties). The training session keeps a strong (owning) pointer to the checkpoint state.

Declaration
Objective-C
- (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env sessionOptions:(nullable ORTSessionOptions *)sessionOptions checkpoint:(nonnull ORTCheckpoint *)checkpoint trainModelPath:(nonnull NSString *)trainModelPath evalModelPath:(nullable NSString *)evalModelPath optimizerModelPath:(nullable NSString *)optimizerModelPath error:(NSError *_Nullable *_Nullable)error;
Swift
init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws

Parameters
env: The ORTEnv instance to use for the training session.
sessionOptions: The optional ORTSessionOptions to use for the training session.
checkpoint: The training state used as the starting point for training.
trainModelPath: The path to the training ONNX model.
evalModelPath: The path to the evaluation ONNX model (optional).
optimizerModelPath: The path to the optimizer ONNX model used to perform gradient descent.
error: Optional error information set if an error occurs.

Return Value
The instance, or nil if an error occurs.
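The initializer needs one path per artifact, and only some of them are optional. The sketch below resolves those paths from a single artifacts directory. It fails early if the mandatory checkpoint or training model is missing. The struct name `TrainingArtifacts` is hypothetical. The default file names are also assumptions: they match what the onnxruntime-training Python utility typically writes, but check your own output directory.

```swift
import Foundation

/// Locations of the training artifacts inside one output directory.
/// Hypothetical helper: the default file names are assumptions and should be
/// adjusted to whatever the artifact generation step actually produced.
struct TrainingArtifacts {
    let checkpointPath: String
    let trainModelPath: String
    let evalModelPath: String?       // nil when no eval model file is present
    let optimizerModelPath: String?  // nil when no optimizer model file is present

    enum LookupError: Error {
        case missing(String)
    }

    init(directory: URL,
         checkpointName: String = "checkpoint",
         trainModelName: String = "training_model.onnx",
         evalModelName: String = "eval_model.onnx",
         optimizerModelName: String = "optimizer_model.onnx") throws {
        let fileManager = FileManager.default
        func path(_ name: String) -> String {
            directory.appendingPathComponent(name).path
        }
        func pathIfPresent(_ name: String) -> String? {
            let candidate = path(name)
            return fileManager.fileExists(atPath: candidate) ? candidate : nil
        }
        // The checkpoint and the training model are mandatory for the initializer.
        guard let checkpoint = pathIfPresent(checkpointName) else {
            throw LookupError.missing(checkpointName)
        }
        guard let trainModel = pathIfPresent(trainModelName) else {
            throw LookupError.missing(trainModelName)
        }
        checkpointPath = checkpoint
        trainModelPath = trainModel
        // Optional artifacts resolve to nil instead of failing.
        evalModelPath = pathIfPresent(evalModelName)
        optimizerModelPath = pathIfPresent(optimizerModelName)
    }
}
```

With the paths resolved, the session might be created as `try ORTTrainingSession(env: env, sessionOptions: nil, checkpoint: try ORTCheckpoint(path: artifacts.checkpointPath), trainModelPath: artifacts.trainModelPath, evalModelPath: artifacts.evalModelPath, optimizerModelPath: artifacts.optimizerModelPath)`. The `ORTCheckpoint(path:)` initializer is assumed from the checkpoint API; it is not documented on this page.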

Performs a training step, which is equivalent to running forward and backward propagation in a single step.

The training step computes the outputs of the training model and the gradients of the trainable parameters for the given input values, based on the training model that was provided to the training session. The computed gradients are stored inside the training session state so they can later be consumed by optimizerStep. The gradients can be lazily reset by calling the lazyResetGrad method.

Parameters
inputs: The input values to the training model.
error: Optional error information set if an error occurs.

Return Value
The output values of the training model.

Performs an evaluation step that computes the outputs of the evaluation model for the given inputs. The eval step is performed based on the evaluation model that was provided to the training session.

Parameters
inputs: The input values to the eval model.
error: Optional error information set if an error occurs.

Return Value
The output values of the eval model.
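A common use of the eval step is averaging a loss over a validation set. The sketch below takes the eval call and the loss extraction as closures, so it does not depend on the exact Swift spelling of evalStep. That spelling is not shown on this page; `evalStep(withInputValues:)` is assumed. The function name `meanEvalLoss` and the `lossOf` helper are hypothetical.

```swift
/// Runs the eval model over `batches` and returns the mean loss.
/// `evalStep` stands in for the session's eval step (Swift spelling assumed);
/// `lossOf` is a hypothetical helper that reads a scalar loss from the outputs.
func meanEvalLoss<Value>(
    batches: [[Value]],
    evalStep: ([Value]) throws -> [Value],
    lossOf: ([Value]) throws -> Float
) throws -> Float {
    precondition(!batches.isEmpty, "need at least one batch")
    var total: Float = 0
    for batch in batches {
        // Evaluation runs the eval model only; no optimizer step follows.
        total += try lossOf(evalStep(batch))
    }
    return total / Float(batches.count)
}
```

With a real session this might be called as `try meanEvalLoss(batches: batches, evalStep: session.evalStep(withInputValues:), lossOf: readScalarLoss)`, where `readScalarLoss` is a hypothetical function that reads the loss tensor's data.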

Resets the gradients of all trainable parameters to zero lazily. Calling this method sets the internal state of the training session such that the gradients of the trainable parameters in the ORTCheckpoint are scheduled to be reset just before new gradients are computed on the next invocation of the trainStep method.

Declaration
Objective-C
- (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;
Swift
func lazyResetGrad() throws

Parameters
error: Optional error information set if an error occurs.

Return Value
YES if the gradients are set to reset successfully, NO otherwise.

Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed based on the optimizer model that was provided to the training session. The updated parameters are stored inside the training state so that they can be used by the next trainStep method call.

Declaration
Objective-C
- (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func optimizerStep() throws

Parameters
error: Optional error information set if an error occurs.

Return Value
YES if the optimizer step was performed successfully, NO otherwise.
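The train step, optimizer step, and lazy gradient reset combine into the basic training round. The sketch below runs one epoch with them. It is written against a hypothetical protocol, `TrainingRoundSession`, that mirrors the Swift signatures on this page; the `trainStep(withInputValues:)` spelling is assumed because its declaration is not shown. This way the loop can be exercised with a mock.

```swift
/// The three calls a basic training round needs.
/// Hypothetical protocol mirroring this page's Swift signatures
/// (the trainStep spelling is assumed).
protocol TrainingRoundSession {
    associatedtype Value
    func trainStep(withInputValues inputs: [Value]) throws -> [Value]
    func optimizerStep() throws
    func lazyResetGrad() throws
}

/// Runs one pass over `batches` and returns the outputs of every train step.
func trainOneEpoch<S: TrainingRoundSession>(
    _ session: S,
    batches: [[S.Value]]
) throws -> [[S.Value]] {
    var outputs: [[S.Value]] = []
    outputs.reserveCapacity(batches.count)
    for batch in batches {
        // Forward and backward pass; gradients are kept in the session state.
        let result = try session.trainStep(withInputValues: batch)
        // Consume the stored gradients to update the trainable parameters.
        try session.optimizerStep()
        // Schedule the gradients to be zeroed before the next trainStep,
        // so they do not accumulate across batches.
        try session.lazyResetGrad()
        outputs.append(result)
    }
    return outputs
}
```

The intended usage is `extension ORTTrainingSession: TrainingRoundSession {}` followed by `try trainOneEpoch(session, batches: batches)`. That conformance is an assumption that relies on the imported method names matching the protocol.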

Returns the names of the user inputs for the training model, which correspond to the ORTValue inputs provided to trainStep.

Declaration
Objective-C
- (nullable NSArray<NSString *> *)getTrainInputNamesWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getTrainInputNames() throws -> [String]

Parameters
error: Optional error information set if an error occurs.

Return Value
The names of the user inputs for the training model.

Returns the names of the user inputs for the evaluation model, which correspond to the ORTValue inputs provided to evalStep.

Declaration
Objective-C
- (nullable NSArray<NSString *> *)getEvalInputNamesWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getEvalInputNames() throws -> [String]

Parameters
error: Optional error information set if an error occurs.

Return Value
The names of the user inputs for the evaluation model.

Returns the names of the user outputs for the training model, which correspond to the ORTValue outputs returned by trainStep.

Declaration
Objective-C
- (nullable NSArray<NSString *> *)getTrainOutputNamesWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getTrainOutputNames() throws -> [String]

Parameters
error: Optional error information set if an error occurs.

Return Value
The names of the user outputs for the training model.

Returns the names of the user outputs for the evaluation model, which correspond to the ORTValue outputs returned by evalStep.

Declaration
Objective-C
- (nullable NSArray<NSString *> *)getEvalOutputNamesWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getEvalOutputNames() throws -> [String]

Parameters
error: Optional error information set if an error occurs.

Return Value
The names of the user outputs for the evaluation model.
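The name getters above are useful for pairing positional values with names. The sketch below assumes, as the descriptions suggest, that the i-th name corresponds to the i-th value. It builds a dictionary of outputs and orders named inputs to match the session's input names. The function and error names are hypothetical.

```swift
/// Errors raised when names and values cannot be paired.
enum NameMappingError: Error, Equatable {
    case countMismatch(names: Int, values: Int)
    case duplicateName(String)
    case missingInput(String)
}

/// Pairs output values with names, e.g. from getTrainOutputNames() and trainStep.
/// Assumes names and values share the same order.
func namedOutputs<Value>(names: [String], values: [Value]) throws -> [String: Value] {
    guard names.count == values.count else {
        throw NameMappingError.countMismatch(names: names.count, values: values.count)
    }
    var result: [String: Value] = [:]
    for (name, value) in zip(names, values) {
        // A repeated name would silently overwrite an earlier value, so reject it.
        if result.updateValue(value, forKey: name) != nil {
            throw NameMappingError.duplicateName(name)
        }
    }
    return result
}

/// Orders named inputs to match getTrainInputNames() or getEvalInputNames().
func orderedInputs<Value>(names: [String], from named: [String: Value]) throws -> [Value] {
    try names.map { (name: String) throws -> Value in
        guard let value = named[name] else {
            throw NameMappingError.missingInput(name)
        }
        return value
    }
}
```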

Registers a linear learning rate scheduler for the training session. The scheduler gradually decreases the learning rate from the initial value to zero over the course of training. The decrease is performed by multiplying the current learning rate by a linearly updated factor. Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase.

Declaration
Objective-C
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount totalStepCount:(int64_t)totalStepCount initialLr:(float)initialLr error:(NSError *_Nullable *_Nullable)error;
Swift
func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws

Parameters
warmupStepCount: The number of steps to perform the linear warmup.
totalStepCount: The total number of steps to perform the linear decay.
initialLr: The initial learning rate.
error: Optional error information set if an error occurs.

Return Value
YES if the scheduler was registered successfully, NO otherwise.
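The warmup-then-decay shape described above can be sketched as a pure function, which is handy for predicting or plotting the schedule. This assumes the common form of such schedulers:

- During warmup, the factor is step / warmupStepCount.
- After warmup, it is (totalStepCount - step) / (totalStepCount - warmupStepCount), clamped at zero.

The factor is applied here to initialLr. ONNX Runtime's internal formula may differ in detail, and `linearScheduleLR` is a hypothetical name.

```swift
/// Learning rate at `step` for an assumed warmup-then-linear-decay schedule.
func linearScheduleLR(step: Int64, warmupStepCount: Int64,
                      totalStepCount: Int64, initialLr: Float) -> Float {
    precondition(step >= 0 && warmupStepCount >= 0 && totalStepCount >= warmupStepCount)
    let factor: Float
    if step < warmupStepCount {
        // Warmup: ramp linearly from 0 up to initialLr.
        factor = Float(step) / Float(max(1, warmupStepCount))
    } else {
        // Decay: ramp linearly from initialLr down to 0 at totalStepCount.
        let remaining = max(0, totalStepCount - step)
        factor = Float(remaining) / Float(max(1, totalStepCount - warmupStepCount))
    }
    return initialLr * factor
}
```

For example, with warmupStepCount 10, totalStepCount 100, and initialLr 0.1, this sketch gives 0.05 at step 5, 0.1 at step 10, and 0.05 at step 55.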

Updates the learning rate based on the registered learning rate scheduler. This performs a scheduler step that updates the learning rate used by the training session. It should typically be called before invoking the optimizer step in each round, or as necessary to update the learning rate used by the training session.

Note
A valid predefined learning rate scheduler must first be registered before invoking this method.

Declaration
Objective-C
- (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func schedulerStep() throws

Parameters
error: Optional error information set if an error occurs.

Return Value
YES if the scheduler step was performed successfully, NO otherwise.
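The sketch below shows where schedulerStep fits in a round, following the recommendation above to call it before the optimizer step. It also records the learning rate each optimizer step will use. It assumes a scheduler has already been registered, for example with `try session.registerLinearLRScheduler(withWarmupStepCount: 10, totalStepCount: 100, initialLr: 0.01)` (illustrative values). The protocol is hypothetical and mirrors this page's Swift signatures; the trainStep spelling is assumed.

```swift
/// Calls needed to drive a registered learning rate scheduler.
/// Hypothetical protocol mirroring this page's Swift signatures.
protocol ScheduledTrainingSession {
    associatedtype Value
    func trainStep(withInputValues inputs: [Value]) throws -> [Value]
    func schedulerStep() throws
    func optimizerStep() throws
    func lazyResetGrad() throws
    func getLearningRate() throws -> Float
}

/// Runs one round per batch and returns the learning rate used in each round.
/// Assumes a predefined scheduler was registered before this is called.
func trainWithScheduler<S: ScheduledTrainingSession>(
    _ session: S,
    batches: [[S.Value]]
) throws -> [Float] {
    var learningRates: [Float] = []
    for batch in batches {
        _ = try session.trainStep(withInputValues: batch)
        // Update the learning rate before the optimizer step, as documented above.
        try session.schedulerStep()
        learningRates.append(try session.getLearningRate())
        try session.optimizerStep()
        try session.lazyResetGrad()
    }
    return learningRates
}
```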

Returns the current learning rate being used by the training session.

Declaration
Objective-C
- (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getLearningRate() throws -> Float

Parameters
error: Optional error information set if an error occurs.

Return Value
The current learning rate, or 0.0f if an error occurs.

Sets the learning rate being used by the training session. The current learning rate is maintained by the training session and can be overwritten by invoking this method with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant learning rate to be used throughout the training session.

Note
It does not set the initial learning rate that may be needed by the predefined learning rate schedulers. To set the initial learning rate for learning rate schedulers, use the registerLinearLRScheduler method.

Declaration
Objective-C
- (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;
Swift
func setLearningRate(_ lr: Float) throws

Parameters
lr: The learning rate to be used by the training session.
error: Optional error information set if an error occurs.

Return Value
YES if the learning rate was set successfully, NO otherwise.
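setLearningRate is described as the hook for a custom scheduler. The sketch below uses step decay, a swapped-in example schedule that is not one of ONNX Runtime's predefined schedulers. The base rate is multiplied by `gamma` once every `stepSize` epochs and pushed into the session through a closure. Per the description above, this should not be combined with a registered scheduler. The function names are hypothetical.

```swift
/// Step decay: baseLr multiplied by gamma once per completed `stepSize` epochs.
func stepDecayLR(epoch: Int, baseLr: Float, gamma: Float, stepSize: Int) -> Float {
    precondition(epoch >= 0 && stepSize > 0)
    let decays = epoch / stepSize
    return (0..<decays).reduce(baseLr) { lr, _ in lr * gamma }
}

/// Pushes the step-decay rate for `epoch` into the session.
/// `setLearningRate` stands in for the session's setLearningRate(_:).
func applyStepDecay(epoch: Int, baseLr: Float, gamma: Float, stepSize: Int,
                    setLearningRate: (Float) throws -> Void) rethrows {
    try setLearningRate(stepDecayLR(epoch: epoch, baseLr: baseLr, gamma: gamma, stepSize: stepSize))
}
```

With a real session, the call at the start of each epoch might be `try applyStepDecay(epoch: epoch, baseLr: 0.1, gamma: 0.5, stepSize: 10, setLearningRate: session.setLearningRate)`. The numeric values are illustrative.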

Loads the training session model parameters from a contiguous buffer.

Declaration
Objective-C
- (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer error:(NSError *_Nullable *_Nullable)error;
Swift
func fromBuffer(with buffer: ORTValue) throws

Parameters
buffer: Contiguous buffer to load the parameters from.
error: Optional error information set if an error occurs.

Return Value
YES if the parameters were loaded successfully, NO otherwise.

Returns a contiguous buffer that holds a copy of the training state parameters: either all of them, or only the trainable ones.

Declaration
Objective-C
- (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable error:(NSError *_Nullable *_Nullable)error;
Swift
func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue

Parameters
onlyTrainable: If YES, returns a buffer that holds only the trainable parameters; otherwise returns a buffer that holds all the parameters.
error: Optional error information set if an error occurs.

Return Value
A contiguous buffer that holds a copy of the requested training state parameters.
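Because toBuffer returns a copy and fromBuffer loads one back, the pair can keep the best parameters seen during training, for example by lowest eval loss, and restore them at the end. The class name `BestParameterSnapshot` is hypothetical. The buffer type is generic so the logic can be tested without a session.

```swift
/// Keeps a copy of the parameters that achieved the lowest loss so far.
/// `copyOut` and `copyIn` stand in for toBuffer(withTrainable:) and fromBuffer(with:).
final class BestParameterSnapshot<Buffer> {
    private(set) var bestLoss: Float = .infinity
    private(set) var snapshot: Buffer?
    private let copyOut: () throws -> Buffer
    private let copyIn: (Buffer) throws -> Void

    init(copyOut: @escaping () throws -> Buffer,
         copyIn: @escaping (Buffer) throws -> Void) {
        self.copyOut = copyOut
        self.copyIn = copyIn
    }

    /// Copies the current parameters if `loss` beats the best loss so far.
    /// Returns true when a new snapshot was taken.
    @discardableResult
    func record(loss: Float) throws -> Bool {
        guard loss < bestLoss else { return false }
        // Copy first, so a failed copy leaves the previous best untouched.
        snapshot = try copyOut()
        bestLoss = loss
        return true
    }

    /// Loads the best snapshot back into the session, if one was taken.
    func restoreBest() throws {
        if let snapshot = snapshot {
            try copyIn(snapshot)
        }
    }
}
```

With a real session this could be created as `BestParameterSnapshot(copyOut: { try session.toBuffer(withTrainable: false) }, copyIn: { try session.fromBuffer(with: $0) })`. Copying all parameters rather than only the trainable ones is a conservative choice here.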

Exports the training session model so that it can be used for inference. If the training session was provided with an eval model, the training session can generate an inference model if it knows the inference graph outputs. The provided graph output names are used to prune the eval model so that the inference model's outputs align with them. The exported model is saved at the provided path and can be used for inference with ORTSession.

Note
The method reloads the eval model from the path provided to the initializer and expects this path to be valid.

Declaration
Objective-C
- (BOOL)exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath graphOutputNames:(nonnull NSArray<NSString *> *)graphOutputNames error:(NSError *_Nullable *_Nullable)error;
Swift
func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws

Parameters
inferenceModelPath: The path where the inference model is serialized.
graphOutputNames: The names of the outputs that are needed in the inference model.
error: Optional error information set if an error occurs.

Return Value
YES if the inference model was exported successfully, NO otherwise.
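Since the export reloads the eval model from the initializer's path, a cheap preflight check can turn a confusing failure into a clear one. The sketch below checks that the eval model file still exists before calling the export. It also rejects an empty output list, on the assumption that an inference model with no outputs is never what the caller wants (ONNX Runtime may report this itself). The function and error names are hypothetical.

```swift
import Foundation

/// Errors detected before the export call is attempted.
enum ExportPreflightError: Error, Equatable {
    case evalModelMissing(String)
    case noGraphOutputs
}

/// Checks the documented precondition, then calls `export`, which stands in
/// for exportModelForInference(withOutputPath:graphOutputNames:).
func exportForInference(evalModelPath: String,
                        outputPath: String,
                        graphOutputNames: [String],
                        export: (String, [String]) throws -> Void) throws {
    // The export reloads the eval model from the path given to the initializer.
    guard FileManager.default.fileExists(atPath: evalModelPath) else {
        throw ExportPreflightError.evalModelMissing(evalModelPath)
    }
    // Assumption: an empty output list is a caller mistake worth catching early.
    guard !graphOutputNames.isEmpty else {
        throw ExportPreflightError.noGraphOutputs
    }
    try export(outputPath, graphOutputNames)
}
```

With a real session, the last argument could be `session.exportModelForInference(withOutputPath:graphOutputNames:)`.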