| | |
| |
|
| | from typing import Tuple |
| |
|
| | import argparse |
| | import onnxruntime |
| | import os |
| | import sys |
| | import time |
| | import torch |
| | import torchvision.datasets as datasets |
| | import torchvision.transforms as transforms |
| |
|
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| |
|
# Command-line options. They are parsed at import time because the functions
# below read the module-level ``args`` object directly.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--onnx_model", default="model.onnx", help="Input onnx model")
parser.add_argument(
    "--data_dir",
    default="/workspace/dataset/imagenet",
    help="Directory of dataset")
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    # Only consulted when --ipu is given (passed as the VitisAI provider's
    # "config_file" option).
    help="Path of the config file for setting provider_options.",
)
args = parser.parse_args()
| |
|
class AverageMeter:
    """Keep the latest value of a metric together with its weighted running mean.

    Args:
        name: Label printed before the numbers by ``__str__``.
        fmt: Format spec, including the leading ':', applied to both the
            latest value and the running average when rendered (e.g. ':6.2f').
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set the latest value, average, running sum and sample count to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. "{name} {val:6.2f} ({avg:6.2f})", then fills it in.
        template = "{{name}} {{val{0}}} ({{avg{0}}})".format(self.fmt)
        return template.format(**vars(self))
| |
|
def accuracy(output: torch.Tensor,
             target: torch.Tensor,
             topk: Tuple[int, ...] = (1,)) -> list:
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Model scores. Classes are ranked along dim 1 (``topk`` is
            taken over dim 1), so this is expected to be (batch, num_classes).
        target: Ground-truth class indices, one per row of ``output``.
        topk: The k values to evaluate, e.g. ``(1, 5)``.

    Returns:
        A list with one single-element tensor per entry of ``topk``. Each
        holds the percentage of samples whose target is among the k
        highest-scoring classes. A k at or above the number of classes counts
        every (valid) target as correct, i.e. 100%.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # torch.topk raises when asked for more entries than dim 1 has, so
        # never request more than the number of classes.
        maxk = min(max(topk), output.size(1))

        # pred: (maxk, batch) class indices, best first.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk simply keeps every row (all predictions).
            # reshape copies only if needed: `correct` may be non-contiguous
            # after the transpose above.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
| |
|
def prepare_data_loader(data_dir: str,
                        batch_size: int = 1,
                        workers: int = 8) -> DataLoader:
    """Return an unshuffled ImageNet validation loader rooted at ``data_dir``.

    Each image is resized (``Resize(256)``), center-cropped to 224x224,
    converted to a tensor and normalized with the per-channel mean/std below
    (the usual ImageNet statistics).

    Args:
        data_dir: Directory that must contain a ``validation`` subdirectory in
            ``torchvision.datasets.ImageFolder`` layout (one folder per class).
        batch_size: Number of images per batch.
        workers: Number of subprocesses used for data loading.

    Returns:
        A ``torch.utils.data.DataLoader`` over the validation images.
    """
    valdir = os.path.join(data_dir, 'validation')

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    return DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        # Page-locked host memory only helps host->GPU copies. Without CUDA
        # (e.g. the IPU path) the request is ignored, and recent PyTorch
        # versions warn about it.
        pin_memory=torch.cuda.is_available())
| |
|
def _create_session():
    """Create the ONNX Runtime session for ``args.onnx_model``.

    With ``--ipu`` the VitisAI execution provider is used, configured from
    ``args.provider_config``. Otherwise CUDA is listed first and CPU second.
    """
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    return onnxruntime.InferenceSession(
        args.onnx_model, providers=providers, provider_options=provider_options)


def val_imagenet():
    """Validate the ONNX model given by ``--onnx_model`` on ImageNet.

    All settings come from the module-level ``args``. Shows a progress bar,
    then prints the final Top-1/Top-5 accuracy and the elapsed wall time.

    Returns:
        ``(top1, top5)``: the ``AverageMeter.avg`` accuracies in percent.
        These are single-element tensors once at least one batch has been
        evaluated, and ``0`` for an empty dataset.
    """
    print(f'Current onnx model: {args.onnx_model}')

    ort_session = _create_session()
    # Only the first model input is fed; look its name up once, not per batch.
    input_name = ort_session.get_inputs()[0].name

    val_loader = prepare_data_loader(args.data_dir)

    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    start_time = time.time()
    with torch.no_grad():
        for images, targets in tqdm(val_loader, file=sys.stdout):
            # Batches arrive as NCHW. The model is fed NHWC (per the original
            # transpose). .contiguous() gives ONNX Runtime a C-ordered buffer
            # instead of a strided view; the values are unchanged.
            inputs = images.permute(0, 2, 3, 1).contiguous().numpy()
            outputs = ort_session.run(None, {input_name: inputs})
            # Only the first model output is scored.
            scores = torch.from_numpy(outputs[0])

            acc1, acc5 = accuracy(scores, targets, topk=(1, 5))
            batch = images.size(0)
            top1.update(acc1, batch)
            top5.update(acc5, batch)

    elapsed = time.time() - start_time
    print('Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.format(
        float(top1.avg), float(top5.avg), elapsed))

    return top1.avg, top5.avg
| |
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    val_imagenet()