Train the Model on the Device#

Once the training artifacts are generated, the model can be trained on the device using the onnxruntime training python API.

The expected training artifacts are:

  1. The training onnx model

  2. The checkpoint state

  3. The optimizer onnx model

  4. The eval onnx model (optional)

Sample usage:

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)