Class OrtTrainingSession

  • All Implemented Interfaces:
    java.lang.AutoCloseable

    public final class OrtTrainingSession
    extends java.lang.Object
    implements java.lang.AutoCloseable
    Wraps an ONNX training model and allows training and inference calls.

    Allows inspection of the model's input and output nodes. Instances are created by an OrtEnvironment.

    Most instance methods throw IllegalStateException if they are called after the session has been closed.
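In practice this means a session is best scoped with try-with-resources, so that it is closed once, after its last use. The sketch below models that contract with a hypothetical stand-in class, GuardedSession, rather than the real OrtTrainingSession (which needs the native library and model files). The learning rate value it returns is a made-up placeholder.

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for a native-backed session; NOT the ORT class.
// Models the documented contract: instance methods throw
// IllegalStateException once close() has been called.
final class GuardedSession implements AutoCloseable {
    private final AtomicBoolean closed = new AtomicBoolean(false);

    // Every instance method calls this guard first.
    private void checkOpen() {
        if (closed.get()) {
            throw new IllegalStateException("Session is closed");
        }
    }

    float getLearningRate() {
        checkOpen();
        return 0.01f; // made-up placeholder value
    }

    boolean isClosed() {
        return closed.get();
    }

    @Override
    public void close() {
        closed.set(true);
    }
}

final class GuardedSessionDemo {
    private GuardedSessionDemo() {}

    // Uses the session inside try-with-resources, then shows that a
    // reference kept past the block can no longer be used.
    static boolean failsAfterClose() {
        GuardedSession escaped = null;
        try (GuardedSession session = new GuardedSession()) {
            session.getLearningRate(); // fine while open
            escaped = session;
        } // close() runs here
        try {
            escaped.getLearningRate();
            return false;
        } catch (IllegalStateException expected) {
            return true;
        }
    }
}
```

Letting a session reference escape the try block, as the demo does on purpose, is exactly the pattern that triggers the IllegalStateException.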

    • Method Summary

      Modifier and Type Method Description
      void addProperty(java.lang.String name, float value)
      Adds a float property to this training session checkpoint.
      void addProperty(java.lang.String name, int value)
      Adds an int property to this training session checkpoint.
      void addProperty(java.lang.String name, java.lang.String value)
      Adds a String property to this training session checkpoint.
      void close()
      OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs)
      Performs a single evaluation step using the supplied inputs.
      OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
      Performs a single evaluation step using the supplied inputs.
      OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs)
      Performs a single evaluation step using the supplied inputs.
      OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
      Performs a single evaluation step using the supplied inputs.
      OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
      Performs a single evaluation step using the supplied inputs.
      void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)
      Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.
      java.util.Set<java.lang.String> getEvalInputNames()
      Returns an ordered set of the eval model input names.
      java.util.Set<java.lang.String> getEvalOutputNames()
      Returns an ordered set of the eval model output names.
      float getFloatProperty(java.lang.String name)
      Gets a float property from this training session checkpoint.
      int getIntProperty(java.lang.String name)
      Gets an int property from this training session checkpoint.
      float getLearningRate()
      Gets the current learning rate for this training session.
      java.lang.String getStringProperty(java.lang.String name)
      Gets a String property from this training session checkpoint.
      java.util.Set<java.lang.String> getTrainInputNames()
      Returns an ordered set of the train model input names.
      java.util.Set<java.lang.String> getTrainOutputNames()
      Returns an ordered set of the train model output names.
      void lazyResetGrad()
      Ensures the gradients are reset to zero before the next call to trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>).
      void optimizerStep()
      Applies the gradient updates to the trainable parameters using the optimizer model.
      void optimizerStep(OrtSession.RunOptions runOptions)
      Applies the gradient updates to the trainable parameters using the optimizer model.
      void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)
      Registers a linear learning rate scheduler with linear warmup.
      void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)
      Saves the training session state to the supplied checkpoint directory.
      void schedulerStep()
      Updates the learning rate based on the registered learning rate scheduler.
      void setLearningRate(float learningRate)
      Sets the learning rate for the training session.
      static void setSeed(long seed)
      Sets the RNG seed used by ONNX Runtime.
      OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs)
      Performs a single step of training, accumulating the gradients.
      OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
      Performs a single step of training, accumulating the gradients.
      OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs)
      Performs a single step of training, accumulating the gradients.
      OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
      Performs a single step of training, accumulating the gradients.
      OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
      Performs a single step of training, accumulating the gradients.
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Method Detail

      • getTrainInputNames

        public java.util.Set<java.lang.String> getTrainInputNames()
        Returns an ordered set of the train model input names.
        Returns:
        The training inputs.
      • getTrainOutputNames

        public java.util.Set<java.lang.String> getTrainOutputNames()
        Returns an ordered set of the train model output names.
        Returns:
        The training outputs.
      • getEvalInputNames

        public java.util.Set<java.lang.String> getEvalInputNames()
        Returns an ordered set of the eval model input names.
        Returns:
        The evaluation inputs.
      • getEvalOutputNames

        public java.util.Set<java.lang.String> getEvalOutputNames()
        Returns an ordered set of the eval model output names.
        Returns:
        The evaluation outputs.
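Because these getters return ordered sets, they can be used to check an input map before calling trainStep or evalStep. The helper below, InputCheck.missingInputs, is hypothetical and not part of ONNX Runtime. It assumes the expected names come from getTrainInputNames() or getEvalInputNames() and reports any that are absent from the supplied map, in model order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper, not part of ONNX Runtime. Reports which expected
// input names (e.g. from getTrainInputNames()) are missing from the map
// that would be passed to trainStep or evalStep, in the set's order.
final class InputCheck {
    private InputCheck() {}

    static List<String> missingInputs(Set<String> expected, Map<String, ?> supplied) {
        List<String> missing = new ArrayList<>();
        for (String name : expected) {
            if (!supplied.containsKey(name)) {
                missing.add(name);
            }
        }
        return missing;
    }
}
```

An empty result only means every expected name is present; it does not check tensor shapes or element types.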
      • addProperty

        public void addProperty(java.lang.String name,
                                float value)
                         throws OrtException
        Adds a float property to this training session checkpoint.
        Parameters:
        name - The property name.
        value - The property value.
        Throws:
        OrtException - If the call failed.
      • addProperty

        public void addProperty(java.lang.String name,
                                int value)
                         throws OrtException
        Adds an int property to this training session checkpoint.
        Parameters:
        name - The property name.
        value - The property value.
        Throws:
        OrtException - If the call failed.
      • addProperty

        public void addProperty(java.lang.String name,
                                java.lang.String value)
                         throws OrtException
        Adds a String property to this training session checkpoint.
        Parameters:
        name - The property name.
        value - The property value.
        Throws:
        OrtException - If the call failed.
      • getFloatProperty

        public float getFloatProperty(java.lang.String name)
                               throws OrtException
        Gets a float property from this training session checkpoint.
        Parameters:
        name - The property name.
        Returns:
        The property value.
        Throws:
        OrtException - If the property does not exist, or is of the wrong type.
      • getIntProperty

        public int getIntProperty(java.lang.String name)
                           throws OrtException
        Gets an int property from this training session checkpoint.
        Parameters:
        name - The property name.
        Returns:
        The property value.
        Throws:
        OrtException - If the property does not exist, or is of the wrong type.
      • getStringProperty

        public java.lang.String getStringProperty(java.lang.String name)
                                           throws OrtException
        Gets a String property from this training session checkpoint.
        Parameters:
        name - The property name.
        Returns:
        The property value.
        Throws:
        OrtException - If the property does not exist, or is of the wrong type.
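The add/get pairs behave like a typed key-value store saved with the checkpoint, a natural place for values such as the current epoch. The hypothetical PropertyBag class below models only the documented lookup rules (a missing name or a type mismatch is an error). It throws IllegalArgumentException where the real methods throw OrtException, and the property names used with it are made up.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of a checkpoint property store; NOT the ORT implementation.
// A typed getter fails if the name is missing or the stored value has a
// different type, mirroring the documented failure modes.
final class PropertyBag {
    private final Map<String, Object> values = new HashMap<>();

    void addProperty(String name, float value) { values.put(name, value); }
    void addProperty(String name, int value) { values.put(name, value); }
    void addProperty(String name, String value) { values.put(name, value); }

    float getFloatProperty(String name) { return get(name, Float.class); }
    int getIntProperty(String name) { return get(name, Integer.class); }
    String getStringProperty(String name) { return get(name, String.class); }

    // Shared lookup: checks existence first, then the stored type.
    private <T> T get(String name, Class<T> type) {
        Object v = values.get(name);
        if (v == null) {
            throw new IllegalArgumentException("No property named " + name);
        }
        if (!type.isInstance(v)) {
            throw new IllegalArgumentException(
                "Property " + name + " is not a " + type.getSimpleName());
        }
        return type.cast(v);
    }
}
```

Reading an int property through getFloatProperty is rejected rather than converted, which is the "wrong type" case in the Throws clauses above.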
      • close

        public void close()
        Specified by:
        close in interface java.lang.AutoCloseable
      • saveCheckpoint

        public void saveCheckpoint(java.nio.file.Path outputPath,
                                   boolean saveOptimizer)
                            throws OrtException
        Saves the training session state to the supplied checkpoint directory.
        Parameters:
        outputPath - Path to a checkpoint directory.
        saveOptimizer - Whether to also save the optimizer state.
        Throws:
        OrtException - If the native call failed.
      • setSeed

        public static void setSeed(long seed)
                            throws OrtException
        Sets the RNG seed used by ONNX Runtime.

        Note: this setting is global, so it applies to all OrtTrainingSession instances.

        Parameters:
        seed - The RNG seed.
        Throws:
        OrtException - If the native call failed.
      • trainStep

        public OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs)
                                    throws OrtException
        Performs a single step of training, accumulating the gradients.
        Parameters:
        inputs - The inputs (must include both the features and the target).
        Returns:
        All outputs produced by the training step.
        Throws:
        OrtException - If the native call failed.
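Since trainStep accumulates gradients rather than applying them, the parameters only change when optimizerStep is called, and lazyResetGrad defers zeroing until the next trainStep. The sketch below models that cycle with a single scalar weight, a squared-error loss and plain SGD. The AccumulatingTrainer class and its arithmetic are illustrative assumptions, not how ONNX Runtime computes gradients or applies its optimizer model.

```java
// Hypothetical one-weight model of the trainStep / optimizerStep /
// lazyResetGrad cycle; NOT ONNX Runtime code. Loss is (w*x - y)^2 and
// the "optimizer" is plain SGD.
final class AccumulatingTrainer {
    double weight;
    double grad;                  // gradient accumulated since the last reset
    private boolean resetPending; // set by lazyResetGrad()
    private final double learningRate;

    AccumulatingTrainer(double initialWeight, double learningRate) {
        this.weight = initialWeight;
        this.learningRate = learningRate;
    }

    // Adds this sample's gradient to the running total; returns the loss.
    double trainStep(double x, double y) {
        if (resetPending) { // the deferred reset happens here
            grad = 0.0;
            resetPending = false;
        }
        double error = weight * x - y;
        grad += 2.0 * error * x; // d/dw of (w*x - y)^2
        return error * error;
    }

    // Applies the accumulated gradient; in this sketch it does not clear it.
    void optimizerStep() {
        weight -= learningRate * grad;
    }

    // Marks the gradient for zeroing before the next trainStep.
    void lazyResetGrad() {
        resetPending = true;
    }
}
```

Calling trainStep twice before optimizerStep sums both gradients, which is how gradient accumulation over several mini-batches works; forgetting lazyResetGrad would keep adding stale gradients into every later update.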
      • trainStep

        public OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        Performs a single step of training, accumulating the gradients.
        Parameters:
        inputs - The inputs (must include both the features and the target).
        runOptions - Run options for controlling this specific call.
        Returns:
        All outputs produced by the training step.
        Throws:
        OrtException - If the native call failed.
      • trainStep

        public OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs)
                                    throws OrtException
        Performs a single step of training, accumulating the gradients.
        Parameters:
        inputs - The inputs (must include both the features and the target).
        requestedOutputs - The requested outputs.
        Returns:
        Requested outputs produced by the training step.
        Throws:
        OrtException - If the native call failed.
      • trainStep

        public OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                           java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs)
                                    throws OrtException
        Performs a single step of training, accumulating the gradients.

        The outputs are returned in the traversal order of the supplied pinnedOutputs map.

        Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

        Parameters:
        inputs - The inputs (must include both the features and the target).
        pinnedOutputs - The requested outputs which the user has allocated.
        Returns:
        Requested outputs produced by the training step.
        Throws:
        OrtException - If the native call failed.
      • trainStep

        public OrtSession.Result trainStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs,
                                           java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        Performs a single step of training, accumulating the gradients.

        The outputs are returned with the pinned outputs first, in the traversal order of pinnedOutputs, followed by the requested outputs, in the traversal order of requestedOutputs. An IllegalArgumentException is thrown if the same output name appears in both requestedOutputs and pinnedOutputs.

        Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

        Parameters:
        inputs - The inputs (must include both the features and the target).
        requestedOutputs - The requested outputs which ORT will allocate.
        pinnedOutputs - The requested outputs which the user has allocated.
        runOptions - Run options for controlling this specific call.
        Returns:
        Requested outputs produced by the training step.
        Throws:
        OrtException - If the native call failed.
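The ordering rule can be stated in a few lines of plain Java. OutputOrder.resolve below is a hypothetical helper, not an ORT method: it lists pinned names first in the map's traversal order, then requested names in the set's traversal order, and rejects a name that appears in both. Since HashMap and HashSet have no defined traversal order, LinkedHashMap and LinkedHashSet are a reasonable choice whenever the output order matters.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical helper, NOT ORT code: computes the output order described for
// the four-argument trainStep/evalStep overloads.
final class OutputOrder {
    private OutputOrder() {}

    static List<String> resolve(Set<String> requested, Map<String, ?> pinned) {
        // Pinned outputs first, in the map's traversal order.
        List<String> order = new ArrayList<>(pinned.keySet());
        // Then requested outputs, in the set's traversal order.
        for (String name : requested) {
            if (pinned.containsKey(name)) {
                throw new IllegalArgumentException(
                    "Output '" + name + "' is both requested and pinned");
            }
            order.add(name);
        }
        return order;
    }
}
```

The same rule is stated for the matching evalStep overload further down.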
      • evalStep

        public OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs)
                                   throws OrtException
        Performs a single evaluation step using the supplied inputs.
        Parameters:
        inputs - The model inputs.
        Returns:
        All model outputs.
        Throws:
        OrtException - If the native call failed.
      • evalStep

        public OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        Performs a single evaluation step using the supplied inputs.
        Parameters:
        inputs - The model inputs.
        runOptions - Run options for controlling this specific call.
        Returns:
        All model outputs.
        Throws:
        OrtException - If the native call failed.
      • evalStep

        public OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs)
                                   throws OrtException
        Performs a single evaluation step using the supplied inputs.
        Parameters:
        inputs - The model inputs.
        requestedOutputs - The requested output names.
        Returns:
        The requested outputs.
        Throws:
        OrtException - If the native call failed.
      • evalStep

        public OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                          java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs)
                                   throws OrtException
        Performs a single evaluation step using the supplied inputs.

        The outputs are returned in the traversal order of the supplied pinnedOutputs map.

        Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

        Parameters:
        inputs - The inputs to score.
        pinnedOutputs - The requested outputs which the user has allocated.
        Returns:
        The requested outputs.
        Throws:
        OrtException - If the native call failed.
      • evalStep

        public OrtSession.Result evalStep(java.util.Map<java.lang.String, ? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs,
                                          java.util.Map<java.lang.String, ? extends OnnxValue> pinnedOutputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        Performs a single evaluation step using the supplied inputs.

        The outputs are returned with the pinned outputs first, in the traversal order of pinnedOutputs, followed by the requested outputs, in the traversal order of requestedOutputs. An IllegalArgumentException is thrown if the same output name appears in both requestedOutputs and pinnedOutputs.

        Note: pinned outputs are not owned by the OrtSession.Result object, and are not closed when the result object is closed.

        Parameters:
        inputs - The inputs to score.
        requestedOutputs - The requested outputs which ORT will allocate.
        pinnedOutputs - The requested outputs which the user has allocated.
        runOptions - Run options for controlling this specific call.
        Returns:
        The requested outputs.
        Throws:
        OrtException - If the native call failed.
      • setLearningRate

        public void setLearningRate(float learningRate)
                             throws OrtException
        Sets the learning rate for the training session.

        Use this only when no learning rate scheduler is registered with the session. It does not set the initial learning rate of a scheduler; that is the initialLearningRate argument of registerLinearLRScheduler.

        Parameters:
        learningRate - The learning rate.
        Throws:
        OrtException - If the call failed.
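Without a scheduler, any learning rate schedule has to be driven by hand through setLearningRate. The StepDecay class below is a hypothetical example of such a rule (multiply the base rate by a fixed factor every few epochs); a caller would pass its result to setLearningRate at the start of each epoch. Neither the class nor the decay rule is part of ONNX Runtime.

```java
// Hypothetical step-decay rule, NOT an ONNX Runtime feature: the rate is
// multiplied by `factor` once every `epochsPerDrop` epochs. A caller would
// pass the result to setLearningRate(...) at the start of each epoch, and
// only when no learning rate scheduler is registered.
final class StepDecay {
    private StepDecay() {}

    static float rateForEpoch(float baseRate, float factor, int epochsPerDrop, int epoch) {
        if (epochsPerDrop <= 0 || epoch < 0) {
            throw new IllegalArgumentException("epochsPerDrop must be > 0 and epoch >= 0");
        }
        int drops = epoch / epochsPerDrop; // completed decay intervals
        return baseRate * (float) Math.pow(factor, drops);
    }
}
```

With a base rate of 0.1, a factor of 0.5 and a drop every 10 epochs, epochs 0 to 9 use 0.1, epochs 10 to 19 use 0.05, and so on.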
      • getLearningRate

        public float getLearningRate()
                              throws OrtException
        Gets the current learning rate for this training session.
        Returns:
        The current learning rate.
        Throws:
        OrtException - If the call failed.
      • optimizerStep

        public void optimizerStep()
                           throws OrtException
        Applies the gradient updates to the trainable parameters using the optimizer model.
        Throws:
        OrtException - If the native call failed.
      • optimizerStep

        public void optimizerStep(OrtSession.RunOptions runOptions)
                           throws OrtException
        Applies the gradient updates to the trainable parameters using the optimizer model.

        The run options can be used to control logging and to terminate the call early.

        Parameters:
        runOptions - Options for controlling the model execution.
        Throws:
        OrtException - If the native call failed.
      • registerLinearLRScheduler

        public void registerLinearLRScheduler(long warmupSteps,
                                              long totalSteps,
                                              float initialLearningRate)
                                       throws OrtException
        Registers a linear learning rate scheduler with linear warmup.
        Parameters:
        warmupSteps - The number of steps over which the learning rate increases from zero to initialLearningRate.
        totalSteps - The total number of steps this scheduler operates over.
        initialLearningRate - The peak learning rate, reached at the end of warmup.
        Throws:
        OrtException - If the native call failed.
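The parameter descriptions above fix the warmup phase: the rate rises from zero to initialLearningRate over warmupSteps. The sketch below additionally assumes the rate then decays linearly to zero at totalSteps, the shape used by common "linear schedule with warmup" implementations; verify the post-warmup behavior against the ONNX Runtime version in use. LinearWarmupSchedule.rateAt is a hypothetical helper for reasoning about the values, not an ORT API.

```java
// Sketch of a linear-warmup, linear-decay schedule; NOT ORT code.
// Assumption: after warmup the rate falls linearly to zero at totalSteps.
final class LinearWarmupSchedule {
    private LinearWarmupSchedule() {}

    static float rateAt(long step, long warmupSteps, long totalSteps, float initialLearningRate) {
        float multiplier;
        if (step < warmupSteps) {
            // Warmup: ramp from 0 up to initialLearningRate.
            multiplier = (float) step / Math.max(1L, warmupSteps);
        } else {
            // Decay: ramp from initialLearningRate down to 0 at totalSteps,
            // clamped at 0 for steps past the end.
            multiplier = Math.max(0.0f,
                (float) (totalSteps - step) / Math.max(1L, totalSteps - warmupSteps));
        }
        return initialLearningRate * multiplier;
    }
}
```

For warmupSteps = 10, totalSteps = 110 and an initial rate of 1.0, this gives 0.5 at step 5, the full 1.0 at step 10, 0.5 again at step 60 and 0 from step 110 onward.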
      • schedulerStep

        public void schedulerStep()
                           throws OrtException
        Updates the learning rate based on the registered learning rate scheduler.
        Throws:
        OrtException - If the native call failed.
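Taken together, one plausible per-update order is trainStep, optimizerStep, lazyResetGrad, then schedulerStep (the schedulerStep call only makes sense once a scheduler has been registered). The sketch below encodes that order, with optional gradient accumulation over several batches, against a hypothetical TrainingLoopSession interface and a recording test double. Both are assumptions made to keep the example self-contained; the real methods take inputs and return results.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical interface with the same method names as the session's
// training-loop calls; inputs and results are left out to keep the sketch
// self-contained. NOT an ONNX Runtime type.
interface TrainingLoopSession {
    void trainStep();
    void optimizerStep();
    void lazyResetGrad();
    void schedulerStep();
}

final class TrainingLoop {
    private TrainingLoop() {}

    // Accumulates gradients over `accumulationSteps` batches, then applies
    // one update. A final partial group of batches is also applied.
    static void run(TrainingLoopSession session, int batches, int accumulationSteps) {
        if (accumulationSteps <= 0) {
            throw new IllegalArgumentException("accumulationSteps must be positive");
        }
        for (int i = 0; i < batches; i++) {
            session.trainStep(); // compute loss and accumulate gradients
            boolean groupDone = (i + 1) % accumulationSteps == 0 || i == batches - 1;
            if (groupDone) {
                session.optimizerStep(); // apply accumulated gradients
                session.lazyResetGrad(); // zero them before the next trainStep
                session.schedulerStep(); // advance the learning rate schedule
            }
        }
    }
}

// Test double that records the call sequence instead of training.
final class RecordingSession implements TrainingLoopSession {
    final List<String> calls = new ArrayList<>();

    @Override public void trainStep() { calls.add("train"); }
    @Override public void optimizerStep() { calls.add("optimizer"); }
    @Override public void lazyResetGrad() { calls.add("reset"); }
    @Override public void schedulerStep() { calls.add("scheduler"); }
}
```

With accumulationSteps = 1 this reduces to one optimizer update per batch; larger values trade update frequency for a larger effective batch.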
      • exportModelForInference

        public void exportModelForInference(java.nio.file.Path outputPath,
                                            java.lang.String[] outputNames)
                                     throws OrtException
        Exports the evaluation model as a model suitable for inference, setting the desired nodes as output nodes.

        Note that this method reloads the evaluation model from the path provided to the training session, and this path must still be valid.

        Parameters:
        outputPath - The path to write out the inference model.
        outputNames - The names of the output nodes.
        Throws:
        OrtException - If the native call failed.
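Because the export re-reads the evaluation model from disk, a cheap pre-flight check can give a clearer error than a failed native call. ExportPreflight.check below is hypothetical: the caller must pass the same evaluation model path that was used to create the session. Refusing to write the inference model over that file is a conservative choice of this sketch, not a documented ORT restriction.

```java
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical pre-flight check, NOT an ONNX Runtime API. exportModelForInference
// re-reads the eval model from the path the session was created with, so check
// that file still exists before calling it.
final class ExportPreflight {
    private ExportPreflight() {}

    static void check(Path evalModelPath, Path outputPath) {
        if (!Files.isRegularFile(evalModelPath)) {
            throw new IllegalStateException("Eval model no longer found at " + evalModelPath);
        }
        // Conservative choice of this sketch: never overwrite the file the
        // export reads from.
        Path in = evalModelPath.toAbsolutePath().normalize();
        Path out = outputPath.toAbsolutePath().normalize();
        if (in.equals(out)) {
            throw new IllegalArgumentException("outputPath must differ from the eval model path");
        }
    }
}
```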