ONNX Runtime Training’s ORTModule offers a high-performance training engine for models defined using PyTorch.
ORTModule is designed to accelerate the training of large models without changing the model definition. The only change to the training script is a single line that wraps the model in ORTModule.
With the ORTModule wrapper in place, ONNX Runtime runs the model's forward and backward passes using an optimized ONNX computation graph that is exported automatically from the PyTorch model.
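Because ORTModule only changes how the forward and backward passes are executed, the rest of a training step can stay as it is. The sketch below assumes a conventional PyTorch training loop. `train_step` and its parameter names are hypothetical and are not part of torch-ort. The function body is the same whether `model` is a plain `torch.nn.Module` or that module wrapped in `ORTModule`.

```python
def train_step(model, inputs, targets, loss_fn, optimizer):
    """Run one optimization step (hypothetical helper, not a torch-ort API).

    The code is identical for a plain nn.Module and for ORTModule(model).
    """
    optimizer.zero_grad()
    # Forward pass: when model is an ORTModule, this runs the exported ONNX graph.
    outputs = model(inputs)
    # The loss is computed outside the wrapped model, so it stays ordinary PyTorch.
    loss = loss_fn(outputs, targets)
    # Backward pass: gradients for the wrapped model come from ONNX Runtime.
    loss.backward()
    # Parameter update by a regular PyTorch optimizer (assumed here).
    optimizer.step()
    return loss.item()
```

For example, a loop could create `loss_fn = torch.nn.CrossEntropyLoss()` and `optimizer = torch.optim.SGD(model.parameters(), lr=0.01)` once, after the model is wrapped, and then call `train_step` for each batch. The call itself does not change when the `ORTModule` wrap is added.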
In this example, we will go over how to use ORT to train a model with PyTorch.
```shell
# Install the torch-ort and onnxruntime-training Python packages
pip install torch-ort
# Configure onnxruntime-training to work with the user's PyTorch installation
python -m torch_ort.configure
```
Note: This installs the default versions of the torch-ort and onnxruntime-training packages, which are mapped to specific versions of the CUDA libraries. Refer to the install options on onnxruntime.ai.
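After running the install commands, a quick check can confirm that the packages are visible to the Python interpreter used for training. This is a minimal sketch that relies only on the standard library (`importlib.metadata`, Python 3.8+). The helper names are hypothetical, and the sketch assumes the pip distribution names `torch-ort` and `onnxruntime-training` from the install step. It does not check whether the CUDA versions match.

```python
"""Post-install sanity check (a sketch; not part of torch-ort itself)."""
import importlib.util
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a pip distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_ort_training_install():
    """Report which of the packages from the install step are present."""
    report = {
        # pip distribution names (assumed from the install commands)
        "torch-ort": installed_version("torch-ort"),
        "onnxruntime-training": installed_version("onnxruntime-training"),
        "torch": installed_version("torch"),
    }
    # The import name differs from the pip name (torch_ort vs. torch-ort).
    # find_spec locates the package without importing it.
    report["torch_ort importable"] = importlib.util.find_spec("torch_ort") is not None
    return report


if __name__ == "__main__":
    for name, value in check_ort_training_install().items():
        print(f"{name}: {'not installed' if value is None else value}")
```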
- Add ORTModule to the training script:
```diff
+ from torch_ort import ORTModule
  . . .
- model = build_model() # User's PyTorch model
+ model = ORTModule(build_model())
```
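When the same script must also run on machines without torch-ort, or when comparing ORT against plain PyTorch, the one-line wrap can be put behind a switch. This is a hypothetical helper, not part of torch-ort. It assumes only the `from torch_ort import ORTModule` import and the `ORTModule(model)` constructor shown in the diff.

```python
import warnings


def wrap_with_ort(model, enabled=True):
    """Return ORTModule(model) when enabled and torch-ort is importable, else model.

    Hypothetical convenience helper; not a torch-ort API.
    """
    if not enabled:
        return model
    try:
        # Imported lazily so the script still runs where torch-ort is not installed.
        from torch_ort import ORTModule
    except ImportError:
        warnings.warn("torch-ort is not installed; training with plain PyTorch.")
        return model
    return ORTModule(model)
```

A script could then use `model = wrap_with_ort(build_model(), enabled=use_ort)`, where `use_ort` is a hypothetical command-line flag. The model definition is unchanged either way, which matches ORTModule's single-line design.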