Class OrtTrainingSession
java.lang.Object
- ai.onnxruntime.OrtTrainingSession
All Implemented Interfaces:
java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseable
Wraps an ONNX training model and allows training and inference calls. Allows the inspection of the model's input and output nodes. Produced by an OrtEnvironment.
Most instance methods throw IllegalStateException if they are called after the session has been closed.
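The closed-session rule can be pictured with a small stand-in. This is a hypothetical sketch that does not call ONNX Runtime: `ToyTrainingSession` and its `learningRate()` method are illustrative names that only mimic the documented contract (an `AutoCloseable` whose instance methods throw `IllegalStateException` after `close()`). It also shows the try-with-resources pattern, which closes the session automatically.

```java
/**
 * Hypothetical stand-in for the closed-session contract. It does not call
 * ONNX Runtime; it only models "instance methods throw IllegalStateException
 * once close() has been called".
 */
final class ToyTrainingSession implements AutoCloseable {
    private boolean closed = false;
    private final float learningRate = 1e-3f;

    // Each instance method checks the closed flag before doing any work.
    private void checkClosed() {
        if (closed) {
            throw new IllegalStateException("Trying to use a closed session");
        }
    }

    float learningRate() {
        checkClosed();
        return learningRate;
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        closed = true;
    }

    public static void main(String[] args) {
        ToyTrainingSession session = new ToyTrainingSession();
        // try-with-resources calls close() when the block exits.
        try (ToyTrainingSession s = session) {
            System.out.println("lr = " + s.learningRate());
        }
        // Any later call on the closed session is rejected.
        try {
            session.learningRate();
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```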
-
Method Summary
Modifier and Type, Method, Description:
- void addProperty(java.lang.String name, float value)
  Adds a float property to this training session checkpoint.
- void addProperty(java.lang.String name, int value)
  Adds an int property to this training session checkpoint.
- void addProperty(java.lang.String name, java.lang.String value)
  Adds a String property to this training session checkpoint.
- void close()
- OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)
  Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
  Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)
  Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
  Performs a single evaluation step using the supplied inputs.
- OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
  Performs a single evaluation step using the supplied inputs.
- void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)
  Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
- java.util.Set<java.lang.String> getEvalInputNames()
  Returns an ordered set of the eval model input names.
- java.util.Set<java.lang.String> getEvalOutputNames()
  Returns an ordered set of the eval model output names.
- float getFloatProperty(java.lang.String name)
  Gets a float property from this training session checkpoint.
- int getIntProperty(java.lang.String name)
  Gets an int property from this training session checkpoint.
- float getLearningRate()
  Gets the current learning rate for this training session.
- java.lang.String getStringProperty(java.lang.String name)
  Gets a String property from this training session checkpoint.
- java.util.Set<java.lang.String> getTrainInputNames()
  Returns an ordered set of the train model input names.
- java.util.Set<java.lang.String> getTrainOutputNames()
  Returns an ordered set of the train model output names.
- void lazyResetGrad()
  Ensures the gradients are reset to zero before the next call to trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>).
- void optimizerStep()
  Applies the gradient updates to the trainable parameters using the optimizer model.
- void optimizerStep(OrtSession.RunOptions runOptions)
  Applies the gradient updates to the trainable parameters using the optimizer model.
- void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)
  Registers a linear learning rate scheduler with linear warmup.
- void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)
  Saves the training session state to the supplied checkpoint directory.
- void schedulerStep()
  Updates the learning rate based on the registered learning rate scheduler.
- void setLearningRate(float learningRate)
  Sets the learning rate for the training session.
- static void setSeed(long seed)
  Sets the RNG seed used by ONNX Runtime.
- OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)
  Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
  Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)
  Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
  Performs a single step of training, accumulating the gradients.
- OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
  Performs a single step of training, accumulating the gradients.
-
Method Detail
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
Returns an ordered set of the train model input names.
- Returns:
- The training inputs.
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
Returns an ordered set of the train model output names.
- Returns:
- The training outputs.
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
Returns an ordered set of the eval model input names.
- Returns:
- The evaluation inputs.
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
Returns an ordered set of the eval model output names.
- Returns:
- The evaluation outputs.
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtException
Adds a float property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtException
Adds an int property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtException
Adds a String property to this training session checkpoint.
- Parameters:
name - The property name.
value - The property value.
- Throws:
OrtException - If the call failed.
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtException
Gets a float property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist, or is of the wrong type.
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtException
Gets an int property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist, or is of the wrong type.
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtException
Gets a String property from this training session checkpoint.
- Parameters:
name - The property name.
- Returns:
- The property value.
- Throws:
OrtException - If the property does not exist, or is of the wrong type.
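The typed-property rules (a getter fails when the name is missing or was stored with a different type) can be modeled with a plain map. This is a hypothetical sketch, not the ONNX Runtime implementation: `ToyCheckpointProperties` stands in for the checkpoint, and `PropertyException` stands in for `OrtException`. To keep the sketch short, `PropertyException` is unchecked, whereas the real `OrtException` is a checked exception. The property names `epoch` and `run` are placeholders.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for OrtException (unchecked here for brevity). */
class PropertyException extends RuntimeException {
    PropertyException(String message) {
        super(message);
    }
}

/** Hypothetical model of typed checkpoint properties (float, int, String). */
class ToyCheckpointProperties {
    private final Map<String, Object> properties = new HashMap<>();

    void addProperty(String name, float value) { properties.put(name, value); }

    void addProperty(String name, int value) { properties.put(name, value); }

    void addProperty(String name, String value) { properties.put(name, value); }

    float getFloatProperty(String name) { return get(name, Float.class); }

    int getIntProperty(String name) { return get(name, Integer.class); }

    String getStringProperty(String name) { return get(name, String.class); }

    // Fails if the name is unknown or the stored value has a different type.
    private <T> T get(String name, Class<T> type) {
        Object value = properties.get(name);
        if (value == null) {
            throw new PropertyException("No property named " + name);
        }
        if (!type.isInstance(value)) {
            throw new PropertyException(name + " is not a " + type.getSimpleName());
        }
        return type.cast(value);
    }

    public static void main(String[] args) {
        ToyCheckpointProperties props = new ToyCheckpointProperties();
        props.addProperty("epoch", 3);
        props.addProperty("run", "baseline");
        System.out.println("epoch = " + props.getIntProperty("epoch"));
        try {
            props.getFloatProperty("epoch"); // stored as an int, so this fails
        } catch (PropertyException e) {
            System.out.println("error: " + e.getMessage());
        }
    }
}
```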
-
close
public void close()
- Specified by:
close in interface java.lang.AutoCloseable
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtException
Saves the training session state to the supplied checkpoint directory.
- Parameters:
outputPath - Path to a checkpoint directory.
saveOptimizer - Whether the optimizer state should also be saved.
- Throws:
OrtException - If the native call failed.
-
lazyResetGrad
public void lazyResetGrad() throws OrtException
Ensures the gradients are reset to zero before the next call to trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>).
Note that this is a lazy call: the gradients are cleared as part of running the next trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>), and not before.
- Throws:
OrtException - If the native call failed.
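Lazy resetting can be modeled with a pending flag. This hypothetical sketch (`ToyGradientBuffer` is not part of ONNX Runtime) shows the observable behavior described above: gradients accumulated by successive training steps are still present after `lazyResetGrad()`, and are only zeroed when the next `trainStep` starts.

```java
import java.util.Arrays;

/**
 * Hypothetical model of lazy gradient reset. trainStep accumulates
 * gradients; lazyResetGrad only marks them for zeroing at the start of
 * the next trainStep instead of clearing them immediately.
 */
class ToyGradientBuffer {
    private final double[] gradients;
    private boolean resetPending = false;

    ToyGradientBuffer(int size) {
        gradients = new double[size];
    }

    /** Accumulates one step's gradients, clearing first if a reset is pending. */
    void trainStep(double[] stepGradients) {
        if (resetPending) {
            Arrays.fill(gradients, 0.0);
            resetPending = false;
        }
        for (int i = 0; i < gradients.length; i++) {
            gradients[i] += stepGradients[i];
        }
    }

    /** Only records the request; nothing is zeroed yet. */
    void lazyResetGrad() {
        resetPending = true;
    }

    double[] gradients() {
        return gradients.clone();
    }

    public static void main(String[] args) {
        ToyGradientBuffer buffer = new ToyGradientBuffer(2);
        buffer.trainStep(new double[] {1.0, 2.0});
        buffer.trainStep(new double[] {0.5, 0.5}); // accumulates to [1.5, 2.5]
        buffer.lazyResetGrad();
        System.out.println(Arrays.toString(buffer.gradients())); // still [1.5, 2.5]
        buffer.trainStep(new double[] {3.0, 4.0}); // cleared first, then [3.0, 4.0]
        System.out.println(Arrays.toString(buffer.gradients()));
    }
}
```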
-
setSeed
public static void setSeed(long seed) throws OrtException
Sets the RNG seed used by ONNX Runtime.
Note that this setting is global across all OrtTrainingSession instances.
- Parameters:
seed - The RNG seed.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
- Returns:
- All outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
runOptions - Run options for controlling this specific call.
- Returns:
- All outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
- Parameters:
inputs - The inputs (must include both the features and the target).
requestedOutputs - The requested outputs.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are sorted based on the supplied map traversal order.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs (must include both the features and the target).
pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single step of training, accumulating the gradients.
The outputs are sorted based on the supplied set traversal order, with pinned outputs first, then requested outputs. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs (must include both the features and the target).
requestedOutputs - The requested outputs which ORT will allocate.
pinnedOutputs - The requested outputs which the user has allocated.
runOptions - Run options for controlling this specific call.
- Returns:
- Requested outputs produced by the training step.
- Throws:
OrtException - If the native call failed.
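The ordering rule for this overload (evalStep states the same rule) can be expressed as a small helper. The sketch is hypothetical: `OutputOrder.order` only computes the documented name order and the overlap check, and the output names `loss`, `logits` and `embeddings` are placeholders. Passing a `LinkedHashSet` and a `LinkedHashMap` is one way to make the traversal order, and therefore the result order, predictable.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical helper reproducing the documented output ordering. */
final class OutputOrder {
    private OutputOrder() {}

    /** Pinned names first (map order), then requested names (set order). */
    static List<String> order(Set<String> requestedOutputs, Map<String, ?> pinnedOutputs) {
        List<String> names = new ArrayList<>(pinnedOutputs.keySet());
        for (String name : requestedOutputs) {
            // A name may be either requested or pinned, never both.
            if (pinnedOutputs.containsKey(name)) {
                throw new IllegalArgumentException(
                        "Output " + name + " is both requested and pinned");
            }
            names.add(name);
        }
        return names;
    }

    public static void main(String[] args) {
        // Linked collections make the traversal order predictable.
        Set<String> requested = new LinkedHashSet<>(Arrays.asList("loss", "logits"));
        Map<String, Object> pinned = new LinkedHashMap<>();
        pinned.put("embeddings", new float[8]);
        System.out.println(order(requested, pinned)); // [embeddings, loss, logits]
    }
}
```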
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
- Returns:
- All model outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
runOptions - Run options for controlling this specific call.
- Returns:
- All model outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
- Parameters:
inputs - The model inputs.
requestedOutputs - The requested output names.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are sorted based on the supplied map traversal order.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs to score.
pinnedOutputs - The requested outputs which the user has allocated.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
Performs a single evaluation step using the supplied inputs.
The outputs are sorted based on the supplied set traversal order, with pinned outputs first, then requested outputs. An IllegalArgumentException is thrown if the same output name appears in both the requested outputs and the pinned outputs.
Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.
- Parameters:
inputs - The inputs to score.
requestedOutputs - The requested outputs which ORT will allocate.
pinnedOutputs - The requested outputs which the user has allocated.
runOptions - Run options for controlling this specific call.
- Returns:
- The requested outputs.
- Throws:
OrtException - If the native call failed.
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtException
Sets the learning rate for the training session.
This should only be used when no learning rate scheduler is registered with the session. It does not set the initial learning rate for LR schedulers.
- Parameters:
learningRate - The learning rate.
- Throws:
OrtException - If the call failed.
-
getLearningRate
public float getLearningRate() throws OrtException
Gets the current learning rate for this training session.
- Returns:
- The current learning rate.
- Throws:
OrtException - If the call failed.
-
optimizerStep
public void optimizerStep() throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
- Throws:
OrtException - If the native call failed.
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
Applies the gradient updates to the trainable parameters using the optimizer model.
The run options can be used to control logging and to terminate the call early.
- Parameters:
runOptions - Options for controlling the model execution.
- Throws:
OrtException - If the native call failed.
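A common way to combine these calls is trainStep, then optimizerStep, then lazyResetGrad, once per batch. The sketch below is a hypothetical pure-Java stand-in that makes this order visible without ONNX Runtime: `ToyLinearTrainer` fits `y = w * x` by gradient descent on the mean squared error, and its method names only mirror those of `OrtTrainingSession`. The data, learning rate and step count are arbitrary example values.

```java
/**
 * Hypothetical stand-in that fits y = w * x by gradient descent on the
 * mean squared error. Method names mirror OrtTrainingSession so that the
 * call order is visible; no ONNX Runtime code is involved.
 */
final class ToyLinearTrainer {
    private final double learningRate;
    private double weight = 0.0;
    private double gradient = 0.0;
    private boolean resetPending = false;

    ToyLinearTrainer(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Accumulates d(MSE)/dw for one batch and returns the batch loss. */
    double trainStep(double[] xs, double[] ys) {
        if (resetPending) {
            gradient = 0.0;
            resetPending = false;
        }
        double loss = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double error = weight * xs[i] - ys[i];
            loss += error * error / xs.length;
            gradient += 2.0 * error * xs[i] / xs.length;
        }
        return loss;
    }

    /** Applies the accumulated gradient to the trainable parameter. */
    void optimizerStep() {
        weight -= learningRate * gradient;
    }

    /** Requests that gradients be zeroed at the start of the next trainStep. */
    void lazyResetGrad() {
        resetPending = true;
    }

    double weight() {
        return weight;
    }

    public static void main(String[] args) {
        double[] xs = {1.0, 2.0, 3.0};
        double[] ys = {2.0, 4.0, 6.0}; // generated from y = 2x
        ToyLinearTrainer trainer = new ToyLinearTrainer(0.05);
        for (int step = 0; step < 100; step++) {
            double loss = trainer.trainStep(xs, ys); // forward and backward pass
            trainer.optimizerStep();                  // update the parameter
            trainer.lazyResetGrad();                  // clear before the next step
            if (step % 25 == 0) {
                System.out.printf("step %d loss %.6f%n", step, loss);
            }
        }
        System.out.printf("learned weight %.4f%n", trainer.weight());
    }
}
```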
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException
Registers a linear learning rate scheduler with linear warmup.
- Parameters:
warmupSteps - The number of steps over which the learning rate increases from zero to initialLearningRate.
totalSteps - The total number of steps this scheduler operates over.
initialLearningRate - The maximum learning rate, reached at the end of warmup.
- Throws:
OrtException - If the native call failed.
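The parameter descriptions fix the warmup phase but not what happens after it. The sketch below assumes the common linear-warmup, linear-decay shape: the rate rises from zero to `initialLearningRate` over `warmupSteps`, then falls linearly to zero at `totalSteps`. That decay shape is an assumption to check against the native scheduler. `LinearWarmupSchedule` and `rateAt` are hypothetical names, and each call to `schedulerStep()` would correspond to advancing `step` by one.

```java
/**
 * Hypothetical linear warmup + linear decay schedule.
 * Assumes: rate rises linearly from 0 to initialLearningRate over
 * warmupSteps, then decays linearly to 0 at totalSteps and stays there.
 */
final class LinearWarmupSchedule {
    private final long warmupSteps;
    private final long totalSteps;
    private final double initialLearningRate;

    LinearWarmupSchedule(long warmupSteps, long totalSteps, double initialLearningRate) {
        if (warmupSteps < 0 || totalSteps < warmupSteps) {
            throw new IllegalArgumentException("Require 0 <= warmupSteps <= totalSteps");
        }
        this.warmupSteps = warmupSteps;
        this.totalSteps = totalSteps;
        this.initialLearningRate = initialLearningRate;
    }

    /** Learning rate at the given zero-based step. */
    double rateAt(long step) {
        if (step < 0) {
            throw new IllegalArgumentException("step must be non-negative");
        }
        if (step < warmupSteps) {
            // Warmup: linear ramp from 0 towards the peak rate.
            return initialLearningRate * step / warmupSteps;
        }
        if (step >= totalSteps) {
            return 0.0;
        }
        // Decay: linear fall from the peak rate to 0 at totalSteps.
        long remaining = totalSteps - step;
        return initialLearningRate * remaining / (totalSteps - warmupSteps);
    }

    public static void main(String[] args) {
        // Example values: 10 warmup steps, 110 total steps, peak rate 0.1.
        LinearWarmupSchedule schedule = new LinearWarmupSchedule(10, 110, 0.1);
        for (long step : new long[] {0, 5, 10, 60, 110}) {
            System.out.printf("step %d -> lr %.4f%n", step, schedule.rateAt(step));
        }
    }
}
```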
-
schedulerStep
public void schedulerStep() throws OrtException
Updates the learning rate based on the registered learning rate scheduler.
- Throws:
OrtException - If the native call failed.
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtException
Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
Note that this method reloads the evaluation model from the path provided to the training session, so that path must still be valid.
- Parameters:
outputPath - The path to write the inference model to.
outputNames - The names of the output nodes.
- Throws:
OrtException - If the native call failed.
-