Description
The torch.operator "onnx.QLinearAdd" op is explicitly marked illegal, so legalization fails when IREE compiles the MLIR generated by torch-mlir. The failure occurs while compiling a ResNet18 INT8 quantized model:
resnet18_int8.mlir:157:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
%153 = torch.operator "onnx.QLinearAdd"(%152, %15, %14, %150, %1, %0, %21, %20) : (!torch.vtensor<[?,64,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>, !torch.vtensor<[?,64,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.none
^
resnet18_int8.mlir:157:12: note: see current operation: %417 = "torch.operator"(%416, %31, %29, %350, %3, %1, %43, %41) <{name = "onnx.QLinearAdd"}> : (!torch.vtensor<[?,64,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>, !torch.vtensor<[?,64,56,56],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.none
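For anyone working on a lowering, the reference semantics of the failing op can be sketched as dequantize, add, requantize. This sketch assumes the ONNX Runtime com.microsoft contrib definition of QLinearAdd, whose operand order (A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point) matches the eight operands in the error above. It also assumes per-tensor scalar scales and zero points, consistent with the rank-0 f32/si8 operands in the signature, and no broadcasting. The function names are hypothetical and do not come from torch-mlir or IREE.

```python
# Hypothetical reference sketch of QLinearAdd semantics (assumed: com.microsoft
# contrib op, per-tensor scalar scale/zero point, int8 data, no broadcasting).

INT8_MIN, INT8_MAX = -128, 127


def dequantize(q, scale, zero_point):
    """Map int8 values back to reals: x = (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


def quantize(x, scale, zero_point):
    """Map reals to int8: round half to even, add zero point, saturate."""
    out = []
    for v in x:
        q = round(v / scale) + zero_point  # Python's round() rounds half to even
        out.append(max(INT8_MIN, min(INT8_MAX, q)))
    return out


def qlinear_add(a, a_scale, a_zp, b, b_scale, b_zp, c_scale, c_zp):
    """C = quantize(dequantize(A) + dequantize(B)), in the op's operand order."""
    if len(a) != len(b):
        raise ValueError("this sketch does not implement broadcasting")
    a_real = dequantize(a, a_scale, a_zp)
    b_real = dequantize(b, b_scale, b_zp)
    summed = [x + y for x, y in zip(a_real, b_real)]
    return quantize(summed, c_scale, c_zp)
```

For example, qlinear_add([10, -20, 100], 0.5, 0, [4, 6, 120], 0.25, 0, 1.0, 0) returns [6, -8, 80]. The middle element shows the real sum -8.5 rounding to the even value -8.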