Source code for onnxruntime.training.api.checkpoint_state

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import annotations

import os

import numpy as np

from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue


class Parameter:
    """Class that represents a model parameter

    This class represents a model parameter and provides access to its data,
    gradient and other properties. This class is not expected to be
    instantiated directly. Instead, it is returned by the `CheckpointState`
    object.

    Args:
        parameter: The C.Parameter object that holds the underlying parameter data.
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, parameter: C.Parameter, state: C.CheckpointState):
        self._parameter = parameter
        # The state is kept because copying new data into the parameter
        # (see the `data` setter) needs it as an argument.
        self._state = state

    @property
    def name(self) -> str:
        """The name of the parameter"""
        return self._parameter.name

    @property
    def data(self) -> np.ndarray:
        """The data of the parameter"""
        return self._parameter.data.numpy()

    @data.setter
    def data(self, value: np.ndarray) -> None:
        """Sets the data of the parameter

        The numpy array is wrapped in an OrtValue and handed to the
        underlying parameter's `copy_from` together with the checkpoint state.

        Args:
            value: The new value of the parameter as a numpy array
        """
        new_value = OrtValue.ortvalue_from_numpy(value)
        self._parameter.copy_from(self._state, new_value._ortvalue)

    @property
    def grad(self) -> np.ndarray | None:
        """The gradient of the parameter

        Returns:
            The gradient as a numpy array, or None when the underlying
            gradient reports that it has no value.
        """
        gradient = self._parameter.grad
        if not gradient.has_value():
            return None
        return gradient.numpy()

    @property
    def requires_grad(self) -> bool:
        """Whether or not the parameter requires its gradient to be computed"""
        return self._parameter.requires_grad

    def __repr__(self) -> str:
        """Returns a string representation of the parameter"""
        return f"Parameter(name={self.name}, requires_grad={self.requires_grad})"
class Parameters:
    """Class that holds all the model parameters

    This class holds all the model parameters and provides access to them.
    This class is not expected to be instantiated directly. Instead, it is
    returned by the `CheckpointState`'s parameters attribute.
    This class behaves like a dictionary and provides access to the
    parameters by name.

    Args:
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, state: C.CheckpointState):
        self._state = state

    def __getitem__(self, name: str) -> Parameter:
        """Gets the parameter associated with the given name

        Searches for the name in the parameters of the checkpoint state.

        Args:
            name: The name of the parameter

        Returns:
            The value of the parameter

        Raises:
            KeyError: If the parameter is not found
        """
        if name not in self:
            raise KeyError(f"Parameter {name} not found.")

        return Parameter(self._state.get_parameter(name), self._state)

    def __setitem__(self, name: str, value: np.ndarray) -> None:
        """Sets the parameter value for the given name

        Searches for the name in the parameters of the checkpoint state.
        If the name is found in parameters, the value is updated.

        Args:
            name: The name of the parameter
            value: The value of the parameter as a numpy array

        Raises:
            KeyError: If the parameter is not found
        """
        if name not in self:
            raise KeyError(f"Parameter {name} not found.")

        self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue)

    def __contains__(self, name: str) -> bool:
        """Checks if the parameter exists in the state

        Args:
            name: The name of the parameter

        Returns:
            True if the name is a parameter, False otherwise
        """
        return self._state.has_parameter(name)

    def __iter__(self):
        """Returns an iterator over the parameters

        Yields:
            `(name, Parameter)` pairs, one per parameter name reported by the
            checkpoint state. Note that, unlike a plain dict, iteration yields
            pairs rather than bare names.
        """
        for parameter_name in self._state.parameter_names():
            yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state)

    def __repr__(self) -> str:
        """Returns a string representation of the parameters

        The representation is the repr of the list of parameter names.
        """
        # `parameter_names()` returns a list; `__repr__` must return a str,
        # otherwise `repr(obj)` raises TypeError.
        return repr(self._state.parameter_names())

    def __len__(self) -> int:
        """Returns the number of parameters"""
        return len(self._state.parameter_names())
class Properties:
    """Class that holds the properties stored in the checkpoint state

    This class behaves like a dictionary and provides access to the
    properties of the checkpoint state by name. Property values are
    int, float or str.

    Args:
        state: The C.CheckpointState object that holds the underlying session state.
    """

    def __init__(self, state: C.CheckpointState):
        self._state = state

    def __getitem__(self, name: str) -> int | float | str:
        """Gets the property associated with the given name

        Searches for the name in the properties of the checkpoint state.

        Args:
            name: The name of the property

        Returns:
            The value of the property

        Raises:
            KeyError: If the property is not found
        """
        if name not in self:
            raise KeyError(f"Property {name} not found.")

        return self._state.get_property(name)

    def __setitem__(self, name: str, value: int | float | str) -> None:
        """Sets the property value for the given name

        Searches for the name in the properties of the checkpoint state.
        The value is added or updated in the properties.

        Args:
            name: The name of the property
            value: The value of the property
                   Properties only support int, float and str values.
        """
        self._state.add_property(name, value)

    def __contains__(self, name: str) -> bool:
        """Checks if the property exists in the state

        Args:
            name: The name of the property

        Returns:
            True if the name is a property, False otherwise
        """
        return self._state.has_property(name)

    def __iter__(self):
        """Returns an iterator over the properties

        Yields:
            `(name, value)` pairs, one per property name reported by the
            checkpoint state.
        """
        for property_name in self._state.property_names():
            yield property_name, self._state.get_property(property_name)

    def __repr__(self) -> str:
        """Returns a string representation of the properties

        The representation is the repr of the list of property names.
        """
        # `property_names()` returns a list; `__repr__` must return a str,
        # otherwise `repr(obj)` raises TypeError.
        return repr(self._state.property_names())

    def __len__(self) -> int:
        """Returns the number of properties"""
        return len(self._state.property_names())
class CheckpointState:
    """Class that holds the state of the training session

    The state covers the training session's information such as the model
    parameters, its gradients, the optimizer state and user defined
    properties.

    Instances are meant to be created through
    `CheckpointState.load_checkpoint`.

    Args:
        state: The C.Checkpoint state object that holds the underlying session state.
    """

    def __init__(self, state: C.CheckpointState):
        is_native_state = isinstance(state, C.CheckpointState)
        if not is_native_state:
            raise TypeError(f"Invalid argument for CheckpointState received {type(state)}")

        self._state = state
        # Both views wrap the same underlying state object.
        self._parameters = Parameters(state)
        self._properties = Properties(state)

    @classmethod
    def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState:
        """Loads the checkpoint state from the checkpoint file

        Args:
            checkpoint_uri: The path to the checkpoint file.

        Returns:
            CheckpointState: The checkpoint state object.
        """
        checkpoint_path = os.fspath(checkpoint_uri)
        native_state = C.load_checkpoint(checkpoint_path)
        return cls(native_state)

    @classmethod
    def save_checkpoint(
        cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False
    ) -> None:
        """Saves the checkpoint state to the checkpoint file

        Args:
            state: The checkpoint state object.
            checkpoint_uri: The path to the checkpoint file.
            include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file.
        """
        # Read the wrapped state before converting the path so the arguments
        # are evaluated in the same order as a single inline call would.
        native_state = state._state
        checkpoint_path = os.fspath(checkpoint_uri)
        C.save_checkpoint(native_state, checkpoint_path, include_optimizer_state)

    @property
    def parameters(self) -> Parameters:
        """Returns the model parameters from the checkpoint state"""
        return self._parameters

    @property
    def properties(self) -> Properties:
        """Returns the properties from the checkpoint state"""
        return self._properties