ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_c_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// This file contains the training c apis.
5
6#pragma once
7#include <stdbool.h>
8#include "onnxruntime_c_api.h"
9
/* Handle types used by the training API. ORT_RUNTIME_CLASS (from onnxruntime_c_api.h) declares
 * them; the index lists them as `struct OrtTrainingSession OrtTrainingSession` and
 * `struct OrtCheckpointState OrtCheckpointState`. Every API entry below only takes pointers to them. */
104ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
105ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
106
114
125
/** Loads a checkpoint state from the file at checkpoint_path into *checkpoint_state. */
143 ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
144 _Outptr_ OrtCheckpointState** checkpoint_state);
145
/** Saves the given checkpoint state to a checkpoint file on disk at checkpoint_path.
 *  include_optimizer_state: presumably selects whether optimizer state is written as well — confirm. */
159 ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
160 const bool include_optimizer_state);
161
163
166
/** Creates a training session that can be used to begin or resume training.
 *  The training, eval and optimizer models are given as file paths; checkpoint_state is
 *  passed _Inout_ and the new session is written to *out. The _Outptr_result_maybenull_
 *  annotation means *out may be NULL on return (presumably on failure — confirm). */
191 ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
192 _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
193 _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
194 _Outptr_result_maybenull_ OrtTrainingSession** out);
195
/** Same as CreateTrainingSession, except that the train, eval and optimizer models are
 *  supplied as in-memory buffers (data pointer + length) instead of file paths. */
211 ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
212 _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
213 _In_ const void* train_model_data, size_t train_data_length,
214 _In_ const void* eval_model_data, size_t eval_data_length,
215 _In_ const void* optim_model_data, size_t optim_data_length,
216 _Outptr_result_maybenull_ OrtTrainingSession** out);
217
219
222
/** Retrieves the number of user outputs in the training model into *out. */
234 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
235
/** Retrieves the number of user outputs in the eval model into *out. */
247 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
248
/** Retrieves the name of the user output at the given index in the training model into *output.
 *  The string is presumably allocated with `allocator` and must be freed by the caller
 *  through the same allocator — confirm. */
262 ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
263
/** Retrieves the name of the user output at the given index in the eval model into *output.
 *  As with the training-model variant, the string is presumably allocated with `allocator`
 *  and owned by the caller — confirm. */
277 ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
278
280
283
/** Resets the gradients of all trainable parameters to zero lazily. */
295 ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
296
/** Computes the outputs of the training model and the gradients of the trainable parameters
 *  for the given inputs. `inputs` holds inputs_len values; `outputs` holds outputs_len slots,
 *  all of which the call writes (_Inout_updates_all_). */
318 ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
319 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
320 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
321
/** Computes the outputs of the eval model for the given inputs. Same input/output array
 *  convention as TrainStep, but the session is taken as const. */
337 ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
338 _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
339 _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
340
/** Sets the learning rate for this training session. */
359 ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
360
/** Gets the current learning rate for this training session into *learning_rate.
 *  NOTE(review): this getter takes a non-const _Inout_ session, unlike the other query
 *  functions; confirm whether it mutates the session. */
373 ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
374
/** Performs the weight updates for the trainable parameters using the optimizer model. */
389 ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
390 _In_opt_ const OrtRunOptions* run_options);
391
/** Registers a linear learning rate scheduler for the training session, configured by
 *  warmup_step_count, total_step_count and initial_lr. */
407 ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
408 _In_ const int64_t total_step_count, _In_ const float initial_lr);
409
/** Updates the learning rate based on the registered learning rate scheduler. */
423 ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
424
426
429
/** Retrieves the size of all the parameters into *out.
 *  trainable_only: presumably restricts the result to trainable parameters — confirm. */
442 ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
443
/** Copies all parameters into the contiguous buffer held by parameters_buffer.
 *  trainable_only: presumably restricts the copy to trainable parameters — confirm. */
460 ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
461 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
462
/** Copies parameter values from the contiguous buffer held by parameters_buffer into the
 *  training state. NOTE(review): parameters_buffer is the source of this copy yet is annotated
 *  _Inout_ and non-const; confirm whether it is modified. */
481 ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
482 _Inout_ OrtValue* parameters_buffer, bool trainable_only);
483
485
488
/* Release entries generated by ORT_CLASS_RELEASE (defined in onnxruntime_c_api.h, not shown here);
 * presumably they declare ReleaseTrainingSession / ReleaseCheckpointState — confirm. */
495 ORT_CLASS_RELEASE(TrainingSession);
496
504 ORT_CLASS_RELEASE(CheckpointState);
505
507
510
/** Exports a model that can be used for inferencing to inference_model_path.
 *  graph_output_names holds graph_outputs_len names — presumably the outputs the exported
 *  graph should expose; confirm. */
527 ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
528 _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
529 _In_reads_(graph_outputs_len) const char* const* graph_output_names);
530
532
535
/** Sets the seed used for random number generation in ONNX Runtime. It takes no session
 *  argument, so it is not bound to a particular OrtTrainingSession (presumably process-wide — confirm). */
545 ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
546
548
551
/** Retrieves the number of user inputs in the training model into *out. */
562 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
563
/** Retrieves the number of user inputs in the eval model into *out. */
575 ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
576
/** Retrieves the name of the user input at the given index in the training model into *output.
 *  NOTE(review): allocator is annotated _In_ here but _Inout_ in the ...OutputName
 *  functions of this struct; the annotations should probably agree. */
590 ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
591 _In_ OrtAllocator* allocator, _Outptr_ char** output);
592
/** Retrieves the name of the user input at the given index in the eval model into *output.
 *  Same allocator-annotation note as the training-model variant above. */
606 ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
607 _In_ OrtAllocator* allocator, _Outptr_ char** output);
608
610
613
/** Adds the given property to the checkpoint state, or updates it if already present.
 *  property_type is an OrtPropertyType (OrtIntProperty, OrtFloatProperty or OrtStringProperty)
 *  describing what property_value points to — presumably int64_t / float / C string; confirm. */
628 ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
629 _In_ const char* property_name, _In_ enum OrtPropertyType property_type,
630 _In_ void* property_value);
631
/** Gets the property value associated with property_name from the checkpoint state.
 *  Its type is reported through *property_type and its value through *property_value; the value
 *  is presumably allocated with `allocator` and owned by the caller — confirm. */
646 ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
647 _In_ const char* property_name, _Inout_ OrtAllocator* allocator,
648 _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
649
651
654
/** Loads a checkpoint state from checkpoint_buffer (num_bytes long) into *checkpoint_state. */
672 ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
673 _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
674
/** Retrieves the type and shape information of the parameter named parameter_name into
 *  *parameter_type_and_shape (an OrtTensorTypeAndShapeInfo; presumably released by the caller — confirm). */
687 ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
688 _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
689
/** Updates the data associated with the model parameter parameter_name in the checkpoint state
 *  from the given OrtValue. */
704 ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
705 _In_ const char* parameter_name, _In_ OrtValue* parameter);
706
/** Gets the data associated with the model parameter parameter_name from the checkpoint state,
 *  returned as an OrtValue through *parameter; `allocator` presumably backs that value — confirm. */
722 ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
723 _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
724 _Outptr_ OrtValue** parameter);
725
727};
728
730
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:288
struct OrtRunOptions OrtRunOptions
Definition onnxruntime_c_api.h:286
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:292
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:285
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:280
struct OrtTrainingSession OrtTrainingSession
Definition onnxruntime_training_c_api.h:104
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
OrtPropertyType
Type of property to be added to or returned from the OrtCheckpointState.
Definition onnxruntime_training_c_api.h:109
@ OrtIntProperty
Definition onnxruntime_training_c_api.h:110
@ OrtStringProperty
Definition onnxruntime_training_c_api.h:112
@ OrtFloatProperty
Definition onnxruntime_training_c_api.h:111
Memory allocation interface.
Definition onnxruntime_c_api.h:320
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122
OrtStatus * CreateTrainingSessionFromBuffer(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training. This api provides a way to lo...
OrtStatus * CopyBufferToParameters(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training stat...
OrtStatus * EvalStep(const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs for the eval model for the given inputs.
OrtStatus * LazyResetGrad(OrtTrainingSession *session)
Reset the gradients of all trainable parameters to zero lazily.
OrtStatus * TrainingSessionGetEvalModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the eval model.
OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training.
OrtStatus * TrainingSessionGetTrainingModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the training model.
OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a file on disk into checkpoint_state.
OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
Export a model that can be used for inferencing.
OrtStatus * GetLearningRate(OrtTrainingSession *sess, float *learning_rate)
Gets the current learning rate for this training session.
OrtStatus * TrainingSessionGetEvalModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the eval model.
OrtStatus * TrainingSessionGetEvalModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the eval model.
OrtStatus * RegisterLinearLRScheduler(OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
Registers a linear learning rate scheduler for the training session.
OrtStatus * CopyParametersToBuffer(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
OrtStatus * GetParameterTypeAndShape(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape)
Retrieves the type and shape information of the parameter associated with the given parameter name.
OrtStatus * UpdateParameter(OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
OrtStatus * SetLearningRate(OrtTrainingSession *sess, float learning_rate)
Sets the learning rate for this training session.
OrtStatus * TrainingSessionGetTrainingModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the training model.
OrtStatus * TrainingSessionGetTrainingModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the training model.
OrtStatus * TrainingSessionGetTrainingModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the training model.
OrtStatus * TrainingSessionGetEvalModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user output at the given index in the eval model.
OrtStatus * AddProperty(OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
Adds or updates the given property to/in the checkpoint state.
OrtStatus * SchedulerStep(OrtTrainingSession *sess)
Update the learning rate based on the registered learning rate scheduler.
OrtStatus * GetParametersSize(OrtTrainingSession *sess, size_t *out, bool trainable_only)
Retrieves the size of all the parameters.
OrtStatus * SetSeed(const int64_t seed)
Sets the seed used for random number generation in ONNX Runtime.
OrtStatus * LoadCheckpointFromBuffer(const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a buffer into checkpoint_state.
OrtStatus * GetProperty(const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
Gets the property value associated with the given name from the checkpoint state.
OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
Save the given state to a checkpoint file on disk.
OrtStatus * GetParameter(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
Performs the weight updates for the trainable parameters using the optimizer model.