Source code for onnxruntime.training.api.optimizer

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import annotations

import os
from typing import TYPE_CHECKING

from onnxruntime.capi import _pybind_state as C

if TYPE_CHECKING:
    from onnxruntime.training.api.module import Module


[docs]class Optimizer: """Class that provides methods to update the model parameters based on the computed gradients. Args: optimizer_uri: The path to the optimizer model. model: The module to be trained. """ def __init__(self, optimizer_uri: str | os.PathLike, module: Module): self._optimizer = C.Optimizer( os.fspath(optimizer_uri), module._state._state, module._device, module._session_options )
[docs] def step(self) -> None: """Updates the model parameters based on the computed gradients. This method updates the model parameters by taking a step in the direction of the computed gradients. The optimizer used depends on the optimizer model provided. """ self._optimizer.optimizer_step()
[docs] def set_learning_rate(self, learning_rate: float) -> None: """Sets the learning rate for the optimizer. Args: learning_rate: The learning rate to be set. """ self._optimizer.set_learning_rate(learning_rate)
[docs] def get_learning_rate(self) -> float: """Gets the current learning rate of the optimizer. Returns: The current learning rate. """ return self._optimizer.get_learning_rate()