
RF-DETR Segmentation to ONNX #463

@AICTIControl

Description


Search before asking

Bug

When I train with the rfdetr library, I get excellent predictions, and the masks are smooth, as they should be.

However, when I export my RF-DETR Segmentation (Preview) model to ONNX for edge inference, the detections are still good, but the masks are poor.

As you can see in the export log below, the mask output has shape [1, 2600, 150, 150], i.e. each of the 2600 queries produces a 150×150 mask. I'm convinced this is the problem: the masks computed from the ONNX outputs do fit the object, but their edges are aliased (jagged).

[Image: mask contour from the ONNX model] [Image: mask contour from the PyTorch rfdetr library]

The images above show the contour of a mask for an object detected with the ONNX model (left) and with the base rfdetr library running on PyTorch (right). With this in mind, is it possible to fix the ONNX output so that the exported model produces smooth masks? Is this issue related to the size of the output masks?
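One plausible explanation for the jagged edges, offered as an assumption rather than a diagnosis of rfdetr's code: if the 150×150 mask logits are thresholded first and the binary mask is then enlarged (e.g. with nearest-neighbor resizing), every low-resolution pixel becomes a solid block and the contour turns into a staircase. Interpolating the logits bilinearly to the target size and thresholding afterwards gives smooth edges. The NumPy sketch below compares the two orders on a synthetic circular mask. It assumes the ONNX "Masks" output holds logits (so a threshold of 0 corresponds to a probability of 0.5), and the helper names are hypothetical. In real code, `cv2.resize` with `cv2.INTER_LINEAR` or `torch.nn.functional.interpolate(..., mode="bilinear")` would do the resizing; plain NumPy is used here only to keep the example dependency-free.

```python
import numpy as np


def bilinear_resize(logits: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Bilinearly resize a 2-D array using half-pixel centers
    (similar in spirit to interpolate(mode="bilinear", align_corners=False))."""
    in_h, in_w = logits.shape
    # Map each output pixel center back to a fractional input coordinate.
    ys = np.clip((np.arange(out_h) + 0.5) * in_h / out_h - 0.5, 0, in_h - 1)
    xs = np.clip((np.arange(out_w) + 0.5) * in_w / out_w - 0.5, 0, in_w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None]  # vertical weights, shape (out_h, 1)
    wx = (xs - x0)[None, :]  # horizontal weights, shape (1, out_w)
    top = logits[y0][:, x0] * (1 - wx) + logits[y0][:, x1] * wx
    bottom = logits[y1][:, x0] * (1 - wx) + logits[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy


def mask_upsample_then_threshold(logits, out_h, out_w, threshold=0.0):
    # Interpolate the continuous logits first, then binarize: smooth edges.
    return bilinear_resize(logits, out_h, out_w) > threshold


def mask_threshold_then_upsample(logits, out_h, out_w, threshold=0.0):
    # Binarize at low resolution, then nearest-neighbor upscale: each
    # low-res pixel becomes a solid block, giving staircase edges.
    binary = logits > threshold
    ys = np.arange(out_h) * logits.shape[0] // out_h
    xs = np.arange(out_w) * logits.shape[1] // out_w
    return binary[ys][:, xs]


# Demo: a synthetic 150x150 "mask logit" map that is positive inside a circle.
yy, xx = np.mgrid[0:150, 0:150]
demo_logits = 40.0 - np.hypot(yy - 74.5, xx - 74.5)

smooth = mask_upsample_then_threshold(demo_logits, 600, 600)
blocky = mask_threshold_then_upsample(demo_logits, 600, 600)
print("pixels that differ:", int((smooth != blocky).sum()))
```

Whether the PyTorch path in rfdetr does exactly this is not confirmed here; the sketch only shows why the order of resizing and thresholding matters for edge quality.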

Environment

  1. RF-DETR: 1.3.0
  2. OS: Windows 11
  3. Python: 3.10.16
  4. PyTorch: 2.8.0+cu129
  5. CUDA: 12.9
  6. GPU: NVIDIA RTX 4080 Super

Minimal Reproducible Example

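import rfdetr

# Load the fine-tuned segmentation checkpoint at 384 px input resolution
model = rfdetr.RFDETRSegPreview(pretrain_weights='outputs/checkpoint_best_ema.pth', resolution=384)
# Export the full model (not only the backbone) to ONNX, without graph simplification
model.export(output_dir="onnx", infer_dir=None, simplify=False, backbone_only=False)

Output: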
Exporting model to ONNX format
PyTorch inference output shapes - Boxes: torch.Size([1, 2600, 4]), Labels: torch.Size([1, 2600, 3]), Masks: torch.Size([1, 2600, 150, 150])
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).

Successfully exported ONNX model: onnx\inference_model.onnx
Successfully exported ONNX model to: onnx\inference_model.onnx
ONNX export completed successfully
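The shapes printed in the log suggest three outputs per image: 2600 box predictions, 2600 × 3 class scores, and 2600 masks of 150×150. Below is a minimal NumPy sketch of turning such raw outputs into a short list of detections, assuming DETR-style conventions: sigmoid class scores with top-k taken over all (query, class) pairs, boxes as normalized (cx, cy, w, h), and masks as logits. These conventions are assumptions, not confirmed rfdetr behavior, so they should be checked against the library's own post-processing; the function names are hypothetical.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def top_k_detections(boxes, logits, masks, k=100):
    """Select the k highest-scoring (query, class) pairs for one image.

    Assumed layouts (shapes from the export log, semantics not confirmed):
      boxes:  (1, Q, 4)     normalized (cx, cy, w, h)
      logits: (1, Q, C)     per-class logits, scored with a sigmoid
      masks:  (1, Q, H, W)  per-query mask logits
    """
    scores = sigmoid(logits[0])                # (Q, C)
    flat = scores.ravel()
    k = min(k, flat.size)
    order = np.argsort(-flat)[:k]              # best k pairs, best first
    query_idx, class_idx = np.unravel_index(order, scores.shape)
    return flat[order], class_idx, boxes[0, query_idx], masks[0, query_idx]


def cxcywh_to_xyxy(b: np.ndarray, img_w: int, img_h: int) -> np.ndarray:
    # Convert normalized center/size boxes to absolute corner coordinates.
    cx, cy, w, h = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    return np.stack([(cx - w / 2) * img_w, (cy - h / 2) * img_h,
                     (cx + w / 2) * img_w, (cy + h / 2) * img_h], axis=-1)


# Demo with small random arrays in the same layout (the real outputs are
# (1, 2600, 4), (1, 2600, 3) and (1, 2600, 150, 150)).
rng = np.random.default_rng(0)
boxes = rng.random((1, 50, 4))
logits = rng.normal(size=(1, 50, 3))
masks = rng.normal(size=(1, 50, 16, 16))
scores, classes, sel_boxes, sel_masks = top_k_detections(boxes, logits, masks, k=5)
xyxy = cxcywh_to_xyxy(sel_boxes, img_w=640, img_h=480)
```

The selected masks are still low-resolution logits at this point; resizing them to the image size (ideally before thresholding) is a separate step.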

Additional

No response

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
