PyTorch and ONNX Flow for Training

Goals

  • Learn how to retrain a model using PyTorch

  • Learn how to export a trained model to ONNX

  • Learn how to quantize an ONNX model to run inference on the NPU

References

Ryzen AI Software Platform

Vitis AI Execution Provider

CIFAR10


This notebook is not currently supported on the Linux release of Riallto.

Running this re-training notebook generates model files that overwrite the existing trained and quantized model files in the onnx folder.

If you want to keep the existing model files, rename them before running the notebook.

The following model files will be written:

  1. The ResNet-50 model trained on the CIFAR-10 dataset, in PyTorch format: onnx/resnet_trained_for_cifar10.pt

  2. The same trained model in ONNX format: onnx/resnet_trained_for_cifar10.onnx

  3. The quantized version of the trained model in ONNX format: onnx/resnet.qdq.U8S8.onnx
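If you want to keep the existing files, a small helper can copy them aside before you run the notebook. This is a minimal sketch, assuming the notebook's relative onnx folder; the function name `back_up_model_files` and the `.bak` suffix are our own illustrative choices, not part of Riallto. It copies rather than renames, so the originals stay usable until this notebook overwrites them.

```python
from pathlib import Path
import shutil

# File names written by this notebook (relative to the onnx folder)
MODEL_FILES = [
    "resnet_trained_for_cifar10.pt",
    "resnet_trained_for_cifar10.onnx",
    "resnet.qdq.U8S8.onnx",
]


def back_up_model_files(models_dir="onnx", suffix=".bak"):
    """Copy each existing model file to <name><suffix> and return the backup paths."""
    backups = []
    for name in MODEL_FILES:
        src = Path(models_dir) / name
        if src.exists():
            dst = src.with_name(src.name + suffix)
            shutil.copy2(src, dst)  # copy, so the original file stays in place
            backups.append(dst)
    return backups
```

Calling `back_up_model_files()` in a cell before Step 2 returns the list of backup files it created (an empty list if none of the model files exist yet).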


Step 1: Import Packages

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.datasets import CIFAR10

import vai_q_onnx
from onnxruntime.quantization import CalibrationDataReader, QuantType, QuantFormat
import random

Step 2: Prepare the Model

Let us retrain the pre-trained ResNet-50 model from torchvision using the CIFAR-10 dataset.

We use transfer learning: the model starts from its default ImageNet weights and is then fine-tuned on CIFAR-10.

Make sure that the CIFAR-10 dataset has already been downloaded to the onnx/data folder. For the steps, refer to the previous notebook.
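Because the training cell creates the dataset with `download=False`, torchvision's `CIFAR10` class raises an error if the data is missing. A quick check before training can save a failed run. This is a minimal sketch, assuming the standard layout torchvision extracts (a `cifar-10-batches-py` folder containing five training batches, one test batch, and `batches.meta`); the helper name `cifar10_is_downloaded` is ours. It only checks that the files exist, whereas torchvision also verifies their checksums.

```python
from pathlib import Path

# Files torchvision extracts from the CIFAR-10 Python archive
CIFAR10_FILES = ["batches.meta", "test_batch"] + [f"data_batch_{i}" for i in range(1, 6)]


def cifar10_is_downloaded(data_dir):
    """Return True if every expected CIFAR-10 file exists under data_dir."""
    base = Path(data_dir) / "cifar-10-batches-py"
    return all((base / name).is_file() for name in CIFAR10_FILES)


print(cifar10_is_downloaded(Path("onnx") / "data"))
```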

Load the model for re-training using transfer learning

By default, the ResNet-50 model pre-trained on the 1,000-class ImageNet dataset has a fully connected (FC) output layer of size 1,000: it produces a 1,000-dimensional vector in which each element corresponds to one ImageNet class.

For transfer learning, we start from the pre-trained ImageNet weights (ResNet50_Weights.DEFAULT) and replace the model's FC layer with a small classifier head: a linear layer with 2,048 input features and 64 output features, followed by a ReLU activation, and a second linear layer with 64 input features and 10 output features. This adapts ResNet-50 into a classifier for the 10 CIFAR-10 classes. Because the optimizer receives all model parameters, the whole network, not only the new head, is fine-tuned during training.
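To see how much smaller the new head is than the original classifier, we can count parameters: a linear layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights plus `n_out` biases, and the ReLU has none. A short plain-Python sketch (the helper name `linear_params` is ours):

```python
def linear_params(n_in, n_out):
    """Number of weights plus biases in a fully connected (linear) layer."""
    return n_in * n_out + n_out


# Original ImageNet classifier: Linear(2048, 1000)
original_fc = linear_params(2048, 1000)

# Replacement head: Linear(2048, 64) -> ReLU -> Linear(64, 10); ReLU has no parameters
new_head = linear_params(2048, 64) + linear_params(64, 10)

print(f"Original FC layer: {original_fc:,} parameters")  # 2,049,000
print(f"New head:          {new_head:,} parameters")     # 131,786
```

In PyTorch, the same count for the loaded model can be obtained with `sum(p.numel() for p in resnet.fc.parameters())`.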

# License 1 (see end of notebook)

def load_resnet_model():
    weights = ResNet50_Weights.DEFAULT
    resnet = resnet50(weights=weights)
    resnet.fc = torch.nn.Sequential(torch.nn.Linear(2048, 64), torch.nn.ReLU(inplace=True), torch.nn.Linear(64, 10))
    return resnet


# For updating learning rate
def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

Model re-training

Define the model and CIFAR-10 dataset directories:

models_dir = ".\\onnx"
data_dir = ".\\onnx\\data"

Each epoch runs 500 steps with a batch size of 100, i.e., it covers all 50,000 images in the training set once.

Each epoch takes approximately 10 minutes to complete. You can vary the number of epochs to improve the accuracy of the model.
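The step numbers in the training log follow directly from the dataset and batch sizes. A short sketch of the arithmetic, using CIFAR-10's 50,000 training and 10,000 test images and the batch size of 100 set in `prepare_model`; the DataLoaders keep a final partial batch by default, hence the ceiling division:

```python
import math

train_images = 50_000  # CIFAR-10 training set
test_images = 10_000   # CIFAR-10 test set
batch_size = 100       # as in prepare_model's DataLoaders
log_every = 100        # prepare_model prints the loss every 100 steps

steps_per_epoch = math.ceil(train_images / batch_size)
logged_steps = [s for s in range(1, steps_per_epoch + 1) if s % log_every == 0]
test_batches = math.ceil(test_images / batch_size)

print(steps_per_epoch)  # 500
print(logged_steps)     # [100, 200, 300, 400, 500]
print(test_batches)     # 100
```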

After training, we save the model, export it to ONNX, and then quantize it.

# License 1 (see end of notebook)

def prepare_model(num_epochs=0):
    # Seed everything to 0
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyper-parameters
    learning_rate = 0.001

    # Image preprocessing modules
    transform = transforms.Compose(
        [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
    )

    # CIFAR-10 dataset
    train_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=transform, download=False)
    test_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=transforms.ToTensor())

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

    model = load_resnet_model().to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    curr_lr = learning_rate
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print(
                    "Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(
                        epoch + 1, num_epochs, i + 1, total_step, loss.item()
                    )
                )
        # Decay learning rate
        if (epoch + 1) % 20 == 0:
            curr_lr /= 3
            update_lr(optimizer, curr_lr)

    # Test the model
    model.eval()
    if num_epochs:
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / total
            print("Accuracy of the model on the test images: {} %".format(accuracy))
    return model
# Run training
model = prepare_model(num_epochs=1)
Epoch [1/1], Step [100/500] Loss: 1.0127
Epoch [1/1], Step [200/500] Loss: 0.9769
Epoch [1/1], Step [300/500] Loss: 0.8371
Epoch [1/1], Step [400/500] Loss: 0.4864
Epoch [1/1], Step [500/500] Loss: 0.7878
Accuracy of the model on the test images: 77.61 %
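`prepare_model` divides the learning rate by 3 after every 20th epoch, so with `num_epochs=1`, as in the run above, the decay never triggers. A plain-Python sketch of the resulting schedule (the function name `lr_for_epoch` is ours and simply restates the loop's logic):

```python
def lr_for_epoch(epoch, base_lr=0.001, step=20, factor=3):
    """Learning rate used during 0-based `epoch` under prepare_model's step decay."""
    return base_lr / factor ** (epoch // step)


# Print the 1-based epoch number and its learning rate around the decay points
for epoch in (0, 19, 20, 39, 40):
    print(epoch + 1, f"{lr_for_epoch(epoch):.6f}")
```

PyTorch's built-in `torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=1/3)`, stepped once per epoch, would produce the same schedule (up to floating-point rounding) without the manual `update_lr` helper.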

Save the trained PyTorch model by running the following cell:

model.to("cpu")
model_path = f"{models_dir}/resnet_trained_for_cifar10.pt"
torch.save(model, model_path)

After running the cell above, the following file is created:

  • The ResNet-50 model trained on CIFAR-10, in PyTorch format: onnx/resnet_trained_for_cifar10.pt


Step 3: Convert Model to ONNX Format

Run the following cell to save the trained model as an ONNX model:

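def save_onnx_model(model):
    model.eval()  # export in inference mode
    # Dummy input with the shape of one CIFAR-10 image: (batch, channels, height, width)
    dummy_inputs = torch.randn(1, 3, 32, 32)
    input_names = ['input']
    output_names = ['output']
    # Mark dimension 0 as dynamic so the exported model accepts any batch size
    # (the calibration data reader later feeds batches of 16)
    dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    onnx_model_path = f"{models_dir}/resnet_trained_for_cifar10.onnx"
    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_model_path,
        export_params=True,  # store the trained weights inside the ONNX file
        opset_version=13,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )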
# Save model
save_onnx_model(model)

After the export completes, the following file is created:

  • The ResNet-50 model trained on CIFAR-10, in ONNX format: onnx/resnet_trained_for_cifar10.onnx

Visualize the ONNX model

Netron is a viewer for neural network, deep learning and machine learning models. The figure below was generated and adapted using Netron.

Note that the figure shows the default model. If you have modified or re-trained your model, open it in Netron to generate a graph of your model.

[Figure: ResNet-50 ONNX model, generated with Netron]

Step 4: Quantize the Model

Quantizing AI models from floating point to 8-bit integers reduces the compute and memory footprint required for inference. For model quantization, you can use either the Vitis AI quantizer or Microsoft Olive. This example uses the Vitis AI ONNX quantizer workflow.

The quantizer generates a model in the QDQ quant format, with UInt8 activations and Int8 weights. After the run completes, the quantized model is saved to onnx/resnet.qdq.U8S8.onnx.

For more information on representation of quantized ONNX models (e.g., QDQ quant format, UInt8 activation type and Int8 weight type) see here
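In the QDQ format, each quantized tensor appears in the graph as a QuantizeLinear/DequantizeLinear pair. The NumPy sketch below mimics that round trip for the two data types used here: signed int8 for weights and unsigned uint8 for (post-ReLU, non-negative) activations. The scales and zero points are made-up values chosen for illustration; in the real flow the quantizer chooses them during calibration.

```python
import numpy as np


def quantize_linear(x, scale, zero_point, dtype):
    """Like ONNX QuantizeLinear: round half to even, add zero point, saturate."""
    info = np.iinfo(dtype)
    q = np.round(x / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)


def dequantize_linear(q, scale, zero_point):
    """Like ONNX DequantizeLinear: (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale


# Weights: signed int8 (QInt8), symmetric, zero point 0
w = np.array([-0.5, -0.1, 0.0, 0.26, 0.9], dtype=np.float32)
w_q = quantize_linear(w, scale=2**-7, zero_point=0, dtype=np.int8)

# Activations after ReLU are non-negative: unsigned uint8 (QUInt8)
a = np.array([0.0, 0.3, 1.2, 10.0], dtype=np.float32)
a_q = quantize_linear(a, scale=2**-5, zero_point=0, dtype=np.uint8)

print(w_q, dequantize_linear(w_q, 2**-7, 0))
print(a_q, dequantize_linear(a_q, 2**-5, 0))  # 10.0 saturates at 255
```

The dequantized values differ slightly from the originals (rounding error), and values outside the representable range saturate, which is why the choice of scale matters.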

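The vai_q_onnx.quantize_static function applies static quantization to the model: the ranges of the activations are determined ahead of time from calibration data rather than at inference time.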

from onnxruntime.quantization import QuantFormat, QuantType
import vai_q_onnx

vai_q_onnx.quantize_static(
    input_model_path,
    output_model_path,
    dr,
    quant_format=QuantFormat.QDQ,
    calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    enable_dpu=True, 
    extra_options={'ActivationSymmetric': True} 
)

The parameters of this function are:

  • input_model_path: (String) The file path of the model to be quantized.

  • output_model_path: (String) The file path where the quantized model will be saved.

  • dr: (CalibrationDataReader or None) Calibration data reader that enumerates the calibration data and produces inputs for the original model. In this example, a 1,000-image subset of the CIFAR-10 training set is used for calibration.

  • quant_format: (QuantFormat) The quantization format of the model. In this example, it is set to QuantFormat.QDQ.

  • calibrate_method: (PowerOfTwoMethod) The calibration method. In this example, it is set to vai_q_onnx.PowerOfTwoMethod.MinMSE to apply power-of-2 scale quantization.

  • activation_type: (QuantType) Data type of activation tensors after quantization. In this example, it is set to QuantType.QUInt8 (quantized unsigned 8-bit integer).

  • weight_type: (QuantType) Data type of weight tensors after quantization. In this example, it is set to QuantType.QInt8 (quantized signed 8-bit integer).

  • enable_dpu: (Boolean) If True, the quantizer generates a model optimized for DPU computations.

  • extra_options: (Dict) Additional quantization options. In this example, 'ActivationSymmetric': True makes the calibration range of the activations symmetric.
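Power-of-two calibration restricts every scale to the form 2**-k; the usual motivation is that rescaling by a power of two can be implemented with bit shifts on the accelerator. As we understand it, the MinMSE variant picks the exponent whose quantize-dequantize error on the calibration data is smallest. The NumPy sketch below illustrates that idea for a single symmetric int8 tensor. It is a simplified illustration, not the vai_q_onnx implementation, and the searched exponent range is an arbitrary choice.

```python
import numpy as np


def qdq_int8(x, scale):
    """Symmetric int8 quantize-dequantize round trip (zero point 0)."""
    return np.clip(np.round(x / scale), -128, 127) * scale


def pow2_min_mse_scale(x, k_range=range(0, 16)):
    """Return the scale 2**-k (k in k_range) with the lowest quantization MSE on x."""
    best_scale, best_mse = None, np.inf
    for k in k_range:
        scale = 2.0 ** -k
        mse = float(np.mean((x - qdq_int8(x, scale)) ** 2))
        if mse < best_mse:
            best_scale, best_mse = scale, mse
    return best_scale, best_mse


# Calibration-like data spread uniformly over [-1, 1]
x = np.linspace(-1.0, 1.0, 1001)
scale, mse = pow2_min_mse_scale(x)
print(f"chosen scale = 2**{int(np.log2(scale))}, MSE = {mse:.2e}")
```

A larger scale rounds more coarsely, while a smaller one clips the tails of the distribution, so the search settles on the exponent that balances the two errors.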

Run the following cell to define the calibration data reader (resnet_calibration_reader):

# License 2 (see end of notebook)

class CIFAR10DataSet:
    def __init__(
        self,
        data_dir,
        **kwargs,
    ):
        super().__init__()
        self.train_path = data_dir
        self.vld_path = data_dir
        self.setup("fit")

    def setup(self, stage: str):
        transform = transforms.Compose(
            [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()]
        )
        self.train_dataset = CIFAR10(root=self.train_path, train=True, transform=transform, download=False)
        self.val_dataset = CIFAR10(root=self.vld_path, train=True, transform=transform, download=False)


class PytorchResNetDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        input_data = sample[0]
        label = sample[1]
        return input_data, label


def create_dataloader(data_dir, batch_size):
    cifar10_dataset = CIFAR10DataSet(data_dir)
    _, val_set = torch.utils.data.random_split(cifar10_dataset.val_dataset, [49000, 1000])
    benchmark_dataloader = DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True)
    return benchmark_dataloader


class ResnetCalibrationDataReader(CalibrationDataReader):
    def __init__(self, data_dir: str, batch_size: int = 16):
        super().__init__()
        self.iterator = iter(create_dataloader(data_dir, batch_size))

    def get_next(self) -> dict:
        # Return the next batch, or None once the calibration data is exhausted
        try:
            images, labels = next(self.iterator)
        except StopIteration:
            return None
        return {"input": images.numpy()}


def resnet_calibration_reader(data_dir, batch_size=16):
    return ResnetCalibrationDataReader(data_dir, batch_size=batch_size)
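The quantizer calls get_next() repeatedly and stops when it returns None; each returned dict maps the model's input name ("input", set during the ONNX export) to a NumPy batch. The sketch below reproduces that protocol with random arrays so it runs without the dataset. `ArrayCalibrationReader` is a hypothetical name; in real code it would subclass onnxruntime.quantization.CalibrationDataReader, as ResnetCalibrationDataReader does. It also shows how many batches the reader above yields: 1,000 calibration images with batch_size=16 and drop_last=True give 62 full batches (992 images).

```python
import numpy as np


class ArrayCalibrationReader:
    """Hypothetical reader following the CalibrationDataReader get_next() protocol."""

    def __init__(self, images, batch_size=16, input_name="input"):
        # Keep only full batches, like DataLoader(..., drop_last=True)
        n_batches = len(images) // batch_size
        self.batches = iter(
            images[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)
        )
        self.input_name = input_name

    def get_next(self):
        # Return None (instead of raising) once the data is exhausted
        batch = next(self.batches, None)
        return None if batch is None else {self.input_name: batch}


rng = np.random.default_rng(0)
images = rng.random((1000, 3, 32, 32), dtype=np.float32)  # stand-in for CIFAR-10 images
reader = ArrayCalibrationReader(images)

n = 0
while (feed := reader.get_next()) is not None:
    assert feed["input"].shape == (16, 3, 32, 32)
    n += 1
print("calibration batches:", n)
```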

Run the following cell to quantize and save the model:

# License 2 (see end of notebook)

# `input_model_path` is the path to the original, unquantized ONNX model.
input_model_path = "onnx/resnet_trained_for_cifar10.onnx"

# `output_model_path` is the path where the quantized model will be saved.
output_model_path = "onnx/resnet.qdq.U8S8.onnx"

# `calibration_dataset_path` is the path to the dataset used for calibration during quantization.
calibration_dataset_path = "onnx/data/"

# `dr` (data reader) is an instance of ResnetCalibrationDataReader, a utility class that
# reads the calibration dataset and prepares it for the quantization process.
dr = resnet_calibration_reader(calibration_dataset_path)

vai_q_onnx.quantize_static(
    input_model_path,
    output_model_path,
    dr,
    quant_format=QuantFormat.QDQ,
    calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    enable_dpu=True,
    extra_options={'ActivationSymmetric': True} 
)
print('Calibrated and quantized model saved at:', output_model_path)
Finding optimal threshold for each tensor using PowerOfTwoMethod.MinMSE algorithm ...
Calibrated and quantized model saved at: onnx/resnet.qdq.U8S8.onnx

After the quantization completes, the following file is created:

  • The quantized ResNet-50 model trained on CIFAR-10, in ONNX format: onnx/resnet.qdq.U8S8.onnx


Step 5: Deploy the Model on NPU for Inference

To run inference using the model generated in this notebook, refer to the PyTorch ONNX Inference notebook.


Licenses

License 1

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

License 2

#################################################################################  
# License
# Ryzen AI is licensed under `MIT License <https://github.com/amd/ryzen-ai-documentation/blob/main/License>`_ . Refer to the `LICENSE File <https://github.com/amd/ryzen-ai-documentation/blob/main/License>`_ for the full license text and copyright notice.

Copyright© 2023 AMD, Inc
SPDX-License-Identifier: MIT