Class OrtTrainingSession
- java.lang.Object
  - ai.onnxruntime.OrtTrainingSession
- All Implemented Interfaces:
  java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseable
Wraps an ONNX training model and allows training and inference calls.
Allows inspection of the model's input and output nodes. Produced by an OrtEnvironment.
Most instance methods throw IllegalStateException if they are called after the session has been closed.
Method Summary
Modifier and Type | Method | Description
void | addProperty(java.lang.String name, float value) | Adds a float property to this training session checkpoint.
void | addProperty(java.lang.String name, int value) | Adds an int property to this training session checkpoint.
void | addProperty(java.lang.String name, java.lang.String value) | Adds a String property to this training session checkpoint.
void | close() | Closes this training session.
OrtSession.Result | evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) | Performs a single evaluation step using the supplied inputs.
OrtSession.Result | evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) | Performs a single evaluation step using the supplied inputs.
OrtSession.Result | evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) | Performs a single evaluation step using the supplied inputs.
OrtSession.Result | evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) | Performs a single evaluation step using the supplied inputs.
OrtSession.Result | evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) | Performs a single evaluation step using the supplied inputs.
void | exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) | Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
java.util.Set<java.lang.String> | getEvalInputNames() | Returns an ordered set of the eval model input names.
java.util.Set<java.lang.String> | getEvalOutputNames() | Returns an ordered set of the eval model output names.
float | getFloatProperty(java.lang.String name) | Gets a float property from this training session checkpoint.
int | getIntProperty(java.lang.String name) | Gets an int property from this training session checkpoint.
float | getLearningRate() | Gets the current learning rate for this training session.
java.lang.String | getStringProperty(java.lang.String name) | Gets a String property from this training session checkpoint.
java.util.Set<java.lang.String> | getTrainInputNames() | Returns an ordered set of the train model input names.
java.util.Set<java.lang.String> | getTrainOutputNames() | Returns an ordered set of the train model output names.
void | lazyResetGrad() | Ensures the gradients are reset to zero before the next call to trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>).
void | optimizerStep() | Applies the gradient updates to the trainable parameters using the optimizer model.
void | optimizerStep(OrtSession.RunOptions runOptions) | Applies the gradient updates to the trainable parameters using the optimizer model.
void | registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) | Registers a linear learning rate scheduler with linear warmup.
void | saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) | Saves the training session state into the supplied checkpoint directory.
void | schedulerStep() | Updates the learning rate based on the registered learning rate scheduler.
void | setLearningRate(float learningRate) | Sets the learning rate for the training session.
static void | setSeed(long seed) | Sets the RNG seed used by ONNX Runtime.
OrtSession.Result | trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) | Performs a single step of training, accumulating the gradients.
OrtSession.Result | trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) | Performs a single step of training, accumulating the gradients.
OrtSession.Result | trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) | Performs a single step of training, accumulating the gradients.
OrtSession.Result | trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) | Performs a single step of training, accumulating the gradients.
OrtSession.Result | trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) | Performs a single step of training, accumulating the gradients.
Method Detail
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
Returns an ordered set of the train model input names.
- Returns:
  The training inputs.
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
Returns an ordered set of the train model output names.
- Returns:
  The training outputs.
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
Returns an ordered set of the eval model input names.
- Returns:
  The evaluation inputs.
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
Returns an ordered set of the eval model output names.
- Returns:
  The evaluation outputs.
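Because the four name accessors return ordered sets, their iteration order can be used to pair names positionally with values prepared in the same order. The sketch below is plain Java with no ONNX Runtime dependency; the class InputMaps and its method byName are hypothetical helpers, not part of the library, and the input names in the usage are illustrative.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper, not part of ONNX Runtime: zips an ordered name set
// (e.g. the result of getTrainInputNames()) with values supplied in the same order.
final class InputMaps {
    private InputMaps() {}

    static <T> Map<String, T> byName(Set<String> orderedNames, List<T> values) {
        if (orderedNames.size() != values.size()) {
            throw new IllegalArgumentException(
                "Expected " + orderedNames.size() + " values but got " + values.size());
        }
        // LinkedHashMap keeps insertion order, so the map iterates in the name set's order.
        Map<String, T> inputs = new LinkedHashMap<>();
        Iterator<T> valueIt = values.iterator();
        for (String name : orderedNames) {
            inputs.put(name, valueIt.next());
        }
        return inputs;
    }
}
```

Assuming `tensors` is a `List<OnnxTensor>` built in the model's input order, the result could be passed as `session.trainStep(InputMaps.byName(session.getTrainInputNames(), tensors))`, since a `Map<String, OnnxTensor>` satisfies the `Map<String, ? extends OnnxTensorLike>` parameter.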
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtException
Adds a float property to this training session checkpoint.
- Parameters:
  name - The property name.
  value - The property value.
- Throws:
  OrtException - If the call failed.
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtException
Adds an int property to this training session checkpoint.
- Parameters:
  name - The property name.
  value - The property value.
- Throws:
  OrtException - If the call failed.
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtException
Adds a String property to this training session checkpoint.
- Parameters:
  name - The property name.
  value - The property value.
- Throws:
  OrtException - If the call failed.
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtException
Gets a float property from this training session checkpoint.
- Parameters:
  name - The property name.
- Returns:
  The property value.
- Throws:
  OrtException - If the property does not exist, or is of the wrong type.
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtException
Gets an int property from this training session checkpoint.
- Parameters:
  name - The property name.
- Returns:
  The property value.
- Throws:
  OrtException - If the property does not exist, or is of the wrong type.
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtException
Gets a String property from this training session checkpoint.
- Parameters:
  name - The property name.
- Returns:
  The property value.
- Throws:
  OrtException - If the property does not exist, or is of the wrong type.
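The property methods above form a small typed key-value store: each getter succeeds only if the name exists and was stored with the matching type. The following in-memory sketch mirrors that documented contract in plain Java. The class CheckpointProperties is hypothetical and does not touch ONNX Runtime; it reports failures as IllegalArgumentException where the real methods throw OrtException.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory model of typed checkpoint properties (NOT ONNX Runtime).
// Getters fail if the name is missing or was stored with a different type.
final class CheckpointProperties {
    private final Map<String, Object> values = new HashMap<>();

    void addProperty(String name, float value) { values.put(name, value); }
    void addProperty(String name, int value) { values.put(name, value); }
    void addProperty(String name, String value) { values.put(name, value); }

    float getFloatProperty(String name) { return get(name, Float.class); }
    int getIntProperty(String name) { return get(name, Integer.class); }
    String getStringProperty(String name) { return get(name, String.class); }

    // Looks up a property and checks its stored type; the real API signals
    // these cases with OrtException, modelled here as IllegalArgumentException.
    private <T> T get(String name, Class<T> type) {
        Object value = values.get(name);
        if (value == null) {
            throw new IllegalArgumentException("No property named " + name);
        }
        if (!type.isInstance(value)) {
            throw new IllegalArgumentException("Property " + name + " is a "
                + value.getClass().getSimpleName() + ", not a " + type.getSimpleName());
        }
        return type.cast(value);
    }
}
```

With the real session, the same pattern applies: a value written with `addProperty(name, 3)` should be read back with `getIntProperty(name)`, not `getFloatProperty(name)`.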
-
close
public void close()
- Specified by:
  close in interface java.lang.AutoCloseable
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtException
Saves the training session state into the supplied checkpoint directory.
- Parameters:
  outputPath - Path to a checkpoint directory.
  saveOptimizer - Whether the optimizer state should also be saved.
- Throws:
  OrtException - If the native call failed.
-
lazyResetGrad
public void lazyResetGrad() throws OrtException
Ensures the gradients are reset to zero before the next call to trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>).
Note this is a lazy call: the gradients are cleared as part of running the next trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>) and not before.
- Throws:
  OrtException - If the native call failed.
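The interplay between trainStep (accumulate), optimizerStep (apply) and lazyResetGrad (deferred zeroing) can be illustrated with a toy single-weight model. This is a plain Java sketch, not ONNX Runtime: the class ToyTrainer, its squared-error loss and its plain SGD update are hypothetical stand-ins chosen only to make the documented semantics observable.

```java
// Toy single-parameter stand-in for the train/optimize cycle (NOT ONNX Runtime).
// trainStep accumulates gradients, optimizerStep applies them, and lazyResetGrad
// only sets a flag; the zeroing happens at the start of the next trainStep.
final class ToyTrainer {
    double weight;
    double gradient;            // accumulated dLoss/dWeight
    final double learningRate;
    private boolean resetPending;

    ToyTrainer(double weight, double learningRate) {
        this.weight = weight;
        this.learningRate = learningRate;
    }

    /** Loss 0.5 * (w*x - y)^2; returns the loss and adds its gradient to the buffer. */
    double trainStep(double x, double y) {
        if (resetPending) {     // deferred reset requested by lazyResetGrad()
            gradient = 0.0;
            resetPending = false;
        }
        double error = weight * x - y;
        gradient += error * x;
        return 0.5 * error * error;
    }

    /** Plain SGD update using the accumulated gradient. */
    void optimizerStep() {
        weight -= learningRate * gradient;
    }

    /** Marks the gradients for reset; the buffer is untouched until the next trainStep. */
    void lazyResetGrad() {
        resetPending = true;
    }
}
```

With a real OrtTrainingSession, a per-batch loop commonly calls trainStep, then optimizerStep, then lazyResetGrad; skipping lazyResetGrad for several batches accumulates their gradients into a single update.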
-
setSeed
public static void setSeed(long seed) throws OrtException
Sets the RNG seed used by ONNX Runtime.
Note this setting is global across OrtTrainingSession instances.
- Parameters:
  seed - The RNG seed.
- Throws:
  OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
  inputs - The inputs (must include both the features and the target).
- Returns:
  All outputs produced by the training step.
- Throws:
  OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
  inputs - The inputs (must include both the features and the target).
  runOptions - Run options for controlling this specific call.
- Returns:
  All outputs produced by the training step.
- Throws:
  OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
  inputs - The inputs (must include both the features and the target).
  requestedOutputs - The requested outputs.
- Returns:
  Requested outputs produced by the training step.
- Throws:
  OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are sorted based on the supplied map traversal order.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
  inputs - The inputs (must include both the features and the target).
  pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
  Requested outputs produced by the training step.
- Throws:
  OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are ordered with the pinned outputs first, then the requested outputs, each following the traversal order of the supplied map or set. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
  inputs - The inputs (must include both the features and the target).
  requestedOutputs - The requested outputs which ORT will allocate.
  pinnedOutputs - The requested outputs which the user has allocated.
  runOptions - Run options for controlling this specific call.
- Returns:
  Requested outputs produced by the training step.
- Throws:
  OrtException - If the native call failed.
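The ordering and overlap rules for combined requested and pinned outputs (documented identically for the four-argument evalStep) can be expressed as a small name-resolution routine. This plain Java sketch only reproduces the documented contract; the class OutputOrder and its resolve method are hypothetical and are not how ONNX Runtime implements it internally.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper mirroring the documented contract (NOT ONNX Runtime code):
// pinned outputs first, in map iteration order, then requested outputs, in set
// iteration order; a name present in both collections is rejected.
final class OutputOrder {
    private OutputOrder() {}

    static List<String> resolve(Set<String> requested, Map<String, ?> pinned) {
        List<String> order = new ArrayList<>(pinned.keySet());
        for (String name : requested) {
            if (pinned.containsKey(name)) {
                throw new IllegalArgumentException(
                    "Output '" + name + "' is both requested and pinned");
            }
            order.add(name);
        }
        return order;
    }
}
```

Using insertion-ordered collections such as LinkedHashMap and LinkedHashSet for the arguments makes the resulting output order predictable; a HashSet or HashMap gives no such guarantee.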
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
  inputs - The model inputs.
- Returns:
  All model outputs.
- Throws:
  OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
  inputs - The model inputs.
  runOptions - Run options for controlling this specific call.
- Returns:
  All model outputs.
- Throws:
  OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
  inputs - The model inputs.
  requestedOutputs - The requested output names.
- Returns:
  The requested outputs.
- Throws:
  OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are sorted based on the supplied map traversal order.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
  inputs - The inputs to score.
  pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
  The requested outputs.
- Throws:
  OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are ordered with the pinned outputs first, then the requested outputs, each following the traversal order of the supplied map or set. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
  inputs - The inputs to score.
  requestedOutputs - The requested outputs which ORT will allocate.
  pinnedOutputs - The requested outputs which the user has allocated.
  runOptions - Run options for controlling this specific call.
- Returns:
  The requested outputs.
- Throws:
  OrtException - If the native call failed.
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtException
Sets the learning rate for the training session.
This should be used only when no learning rate scheduler is registered with the session. It is not used to set the initial learning rate of a scheduler; that value is passed to registerLinearLRScheduler instead.
- Parameters:
  learningRate - The learning rate.
- Throws:
  OrtException - If the call failed.
-
getLearningRate
public float getLearningRate() throws OrtException
Gets the current learning rate for this training session.
- Returns:
  The current learning rate.
- Throws:
  OrtException - If the call failed.
-
optimizerStep
public void optimizerStep() throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
- Throws:
  OrtException - If the native call failed.
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
The run options can be used to control logging and to terminate the call early.
- Parameters:
  runOptions - Options for controlling the model execution.
- Throws:
  OrtException - If the native call failed.
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException
Registers a linear learning rate scheduler with linear warmup.
- Parameters:
  warmupSteps - The number of steps over which the learning rate increases from zero to initialLearningRate.
  totalSteps - The total number of steps this scheduler operates over.
  initialLearningRate - The maximum learning rate.
- Throws:
  OrtException - If the native call failed.
-
schedulerStep
public void schedulerStep() throws OrtException
Updates the learning rate based on the registered learning rate scheduler.
- Throws:
  OrtException - If the native call failed.
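To make the three scheduler parameters concrete, the sketch below computes a linear-warmup, linear-decay learning rate in plain Java. ASSUMPTION: it uses the common formula (multiplier step/warmupSteps during warmup, then (totalSteps - step)/(totalSteps - warmupSteps) clamped at zero), which is consistent with the parameter descriptions above but is not confirmed against the ONNX Runtime sources. The class LinearWarmupSchedule is hypothetical.

```java
// Hypothetical linear-warmup / linear-decay schedule (NOT ONNX Runtime code).
// ASSUMPTION: the exact ramp follows the common formula described in the lead-in.
final class LinearWarmupSchedule {
    private final long warmupSteps;
    private final long totalSteps;
    private final float initialLearningRate; // peak rate reached at the end of warmup
    private long step;                       // number of schedulerStep() calls so far

    LinearWarmupSchedule(long warmupSteps, long totalSteps, float initialLearningRate) {
        if (warmupSteps < 0 || totalSteps < warmupSteps) {
            throw new IllegalArgumentException("Require 0 <= warmupSteps <= totalSteps");
        }
        this.warmupSteps = warmupSteps;
        this.totalSteps = totalSteps;
        this.initialLearningRate = initialLearningRate;
    }

    /** Learning rate for the current step. */
    float learningRate() {
        double multiplier;
        if (step < warmupSteps) {
            // Warmup: ramp linearly from 0 towards 1.
            multiplier = (double) step / Math.max(1L, warmupSteps);
        } else {
            // Decay: ramp linearly from 1 down to 0 at totalSteps, then stay at 0.
            multiplier = Math.max(0.0,
                (double) (totalSteps - step) / Math.max(1L, totalSteps - warmupSteps));
        }
        return (float) (initialLearningRate * multiplier);
    }

    /** Advances one step, analogous to schedulerStep(). */
    void schedulerStep() {
        step++;
    }
}
```

With warmupSteps = 2, totalSteps = 6 and initialLearningRate = 1.0f, this sketch yields 0.0, 0.5, 1.0, 0.75, 0.5, 0.25 and then 0.0 for every later step.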
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtException
Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
Note that this method reloads the evaluation model from the path provided to the training session, so that path must still be valid.
- Parameters:
  outputPath - The path to write out the inference model.
  outputNames - The names of the output nodes.
- Throws:
  OrtException - If the native call failed.