Prepare for training
Before training can start on edge devices, the training artifacts must be generated in an offline step.
These artifacts include:
- The training ONNX model
- The checkpoint state
- The optimizer ONNX model
- The eval ONNX model (optional)
It is assumed that a forward-only ONNX model is already available. If you are using PyTorch, this model can be generated by exporting the PyTorch model with the torch.onnx.export() API.
Note
If you use PyTorch to export the model, pass the following export arguments so that training artifact generation succeeds:
- export_params: True
- do_constant_folding: False
- training: torch.onnx.TrainingMode.TRAINING
Once the forward-only ONNX model is available, the training artifacts can be generated using the onnxruntime.training.artifacts.generate_artifacts() API.
Sample usage:
import onnx
from onnxruntime.training import artifacts

# Load the forward-only ONNX model
model = onnx.load(path_to_forward_only_onnx_model)

# Generate the training artifacts
artifacts.generate_artifacts(
    model,
    requires_grad=["parameters", "needing", "gradients"],
    frozen_params=["parameters", "not", "needing", "gradients"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory=path_to_output_artifact_directory,
)
Custom Loss
If a custom loss is needed, the user can provide a custom loss function to the onnxruntime.training.artifacts.generate_artifacts() API.
This is done by inheriting from the onnxruntime.training.onnxblock.Block class and implementing the build method.
The following example shows how to implement a custom loss function:
Suppose we want to use a custom loss function with a model that generates two outputs. The custom loss function must apply a loss function to each of the outputs and compute a weighted average of the results. Mathematically:
loss = 0.4 * mse_loss1(output1, target1) + 0.6 * mse_loss2(output2, target2)
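To make the formula concrete, here is a small NumPy reference computation of the same weighted loss. It assumes each MSE term uses mean reduction (assumed here to match the default of onnxblock.loss.MSELoss); the function names are illustrative only. A reference like this can help sanity-check the loss value that the generated training model reports on the same data.

```python
import numpy as np


def mse_loss(output, target):
    """Mean squared error with mean reduction."""
    diff = np.asarray(output, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    return float(np.mean(diff ** 2))


def weighted_average_loss(output1, target1, output2, target2, w1=0.4, w2=0.6):
    """loss = w1 * mse(output1, target1) + w2 * mse(output2, target2)."""
    return w1 * mse_loss(output1, target1) + w2 * mse_loss(output2, target2)


# mse1 = (1^2 + 2^2) / 2 = 2.5 and mse2 = (3 - 1)^2 = 4.0
loss = weighted_average_loss([1.0, 2.0], [0.0, 0.0], [3.0], [1.0])
print(loss)  # 3.4, i.e. 0.4 * 2.5 + 0.6 * 4.0
```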
Because this is a custom loss function, it is not available as a member of the LossType enum.
Instead, we build it with onnxblock.
import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts

# Define a custom loss block that takes in two inputs
# and performs a weighted average of the losses from these
# two inputs.
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()

    def build(self, loss_input_name1, loss_input_name2):
        # The build method defines how the block should be stacked on top of
        # loss_input_name1 and loss_input_name2.
        # Returns the weighted average of the two losses.
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")),
        )

my_custom_loss = WeightedAverageLoss()

# Load the ONNX model
model_path = "model.onnx"
base_model = onnx.load(model_path)

# Define the parameters that need their gradients computed
requires_grad = ["weight1", "bias1", "weight2", "bias2"]
frozen_params = ["weight3", "bias3"]

# Now, we can invoke generate_artifacts with this custom loss function
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=my_custom_loss,
    optimizer=artifacts.OptimType.AdamW,
)

# Successful completion of the above call will generate 4 files in the current working directory,
# one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, optimizer_model.onnx)
Advanced Usage
onnxblock is a library that can be used to build complex ONNX models by stacking simple blocks on top of each other. Building a custom loss function, as shown above, is one example of this.
onnxblock also provides a way to build a custom forward-only or training (forward + backward) ONNX model through the onnxruntime.training.onnxblock.ForwardBlock and onnxruntime.training.onnxblock.TrainingBlock classes, respectively. These blocks inherit from the base onnxruntime.training.onnxblock.Block class and provide additional functionality for building inference and training models.