8 changes: 7 additions & 1 deletion QEfficient/__init__.py
@@ -11,17 +11,21 @@
import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.patches import apply_torch_patches, is_patched

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient


# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

# Apply patches
# TODO: Find a better way to do this, this is temp. fix.
apply_torch_patches()
Contributor: If we are not enabling subfunction export, do we need to do the monkey patching?

Contributor: If we are not enabling subfunction export then the monkey patching is not required, but applying it does not harm execution; we have checked generation without subfunction export and without the monkey patching, though we could also put a condition around this.



def check_qaic_sdk():
"""Check if QAIC SDK is installed"""
@@ -70,6 +74,8 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"apply_torch_patches",
"is_patched",
]

else:
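Following up on the monkey-patching thread above, here is a minimal sketch of how the patch application could be gated so it only runs when subfunction (ONNX-function) export is requested. The QEFF_USE_ONNX_FUNCTIONS flag is borrowed from the cache_utils change later in this PR; using it to gate the patches is an assumption, not something this PR does.

import os

from QEfficient.utils.patches import apply_torch_patches

# Hypothetical gating: only monkey-patch torch when ONNX-function
# (subfunction) export is requested via an environment flag.
if os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true":
    apply_torch_patches()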
13 changes: 12 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -18,10 +18,13 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
@@ -243,6 +246,11 @@ def _export(
input_names.append(param)

try:
# Initialize the registry with the required custom ops
CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
@@ -252,6 +260,8 @@
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
export_modules_as_functions=decoder_layer_classes,
do_constant_folding=True,
**export_kwargs,
)
logger.info("PyTorch export successful")
@@ -261,6 +271,7 @@
model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
"temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
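For context on the export_modules_as_functions=decoder_layer_classes argument added above, a small self-contained sketch (toy module names, not from this repo) of what the flag does in the TorchScript-based ONNX exporter: every instance of the listed module classes is emitted as an ONNX function instead of being inlined into the main graph. The decoder-layer classes collected by get_decoder_layer_classes_for_export play the same role for the real model.

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyDecoderLayer() for _ in range(2)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

torch.onnx.export(
    ToyModel(),
    (torch.randn(1, 16),),
    "toy_model.onnx",
    opset_version=17,  # export_modules_as_functions requires opset >= 15
    export_modules_as_functions={ToyDecoderLayer},
)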
166 changes: 165 additions & 1 deletion QEfficient/base/onnx_transforms.py
@@ -5,9 +5,12 @@
#
# ----------------------------------------------------------------------------

from typing import Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnx
import onnxslim
import torch
from onnx import ModelProto, external_data_helper, numpy_helper


@@ -99,3 +102,164 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform(OnnxTransform):
"""
Applies onnx-slim transformations on the given ONNX graph.
"""

@classmethod
def apply(
cls,
model: ModelProto,
*,
onnx_base_dir: Optional[str] = None,
**kwargs,
) -> Tuple[ModelProto, bool]:
"""
:param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)
Contributor: If OnnxSlimTransform is called, do we need another onnx_slim_transform = True flag and then check it on line 130? The expectation should be to just apply the onnx-slim transform, right?

Contributor: We can remove it from here. There is a flag called "enable_onnx_slim_transform" that lets users decide whether to enable ONNX Slim. We can add a condition in modeling_auto so that this transform is included in the _onnx_transforms list only when the flag is enabled.

Contributor: I tested this change with GPT-OSS and it fails in the onnxslim transform. Discussed with VB that this doesn't help us much. Let's not add an extra package dependency if it has limited use; let's remove onnxslim.

Commenter: Can you provide the error log? I think there should be a 5% performance gain with onnxslim.

Contributor (author): @inisis thanks. You mean a ~5% gain in performance with onnxslim?

Commenter (quoting a contributor): "Not sure about GPT-OSS, but for Qwen 2.5 VL we observed that onnxslim removes identical nodes, which leads to the creation of dummy nodes."
@abhishek-singh591 Hi, what are those dummy nodes? They should not be created by onnxslim. Removing identical nodes is generally known as CSE; it reduces extra computation, and I think it is very useful.

Contributor: The nodes left behind after removing identical nodes can be considered dummy/orphan nodes. These are nodes that were originally connected to the outputs of the identical nodes but, after CSE, no longer have valid connections. Ideally, CSE should rewire the inputs and outputs properly so that no orphan nodes remain, right?

Commenter: Yes, onnxslim will remove those dummy nodes automatically.

Contributor: Actually, it is not doing that, and we also don't want to delete those nodes. Before removing identical nodes through CSE, it should connect the input of the identity node directly to its output, ensuring the graph remains valid.

Commenter: Really? Can you provide an example? Many thanks.

temp_onnx_path = kwargs.get("temp_onnx_path", None)
Contributor: Can we make this a mandatory argument? Also, onnx_base_dir is unused here.

if not temp_onnx_path:
err_str = "temp_onnx_path is required for onnx-slim transform."
raise RuntimeError(err_str)
if onnx_slim_transform:
transformed = True
slimmed_model = onnxslim.slim(model)
onnx.save(slimmed_model, temp_onnx_path)
return slimmed_model, transformed
return model, transformed
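
On the identity/orphan-node rewiring discussed in the thread above, a rough standalone sketch (not part of this PR) of reconnecting consumers of an Identity node's output directly to its input before the node is dropped, so the graph stays valid after CSE-style cleanup:

import onnx

def rewire_identity_nodes(model: onnx.ModelProto) -> onnx.ModelProto:
    """Reconnect consumers of every Identity output to the Identity input,
    then drop Identity nodes that do not feed a graph output."""
    graph = model.graph
    identity_nodes = [n for n in graph.node if n.op_type == "Identity"]
    # Map Identity output name -> Identity input name.
    replacement = {n.output[0]: n.input[0] for n in identity_nodes}

    for node in graph.node:
        if node.op_type == "Identity":
            continue
        for i, name in enumerate(node.input):
            # Follow chains of Identity nodes back to the original producer.
            while name in replacement:
                name = replacement[name]
            node.input[i] = name

    graph_outputs = {o.name for o in graph.output}
    for n in identity_nodes:
        # Keep Identity nodes whose output is a declared graph output.
        if n.output[0] not in graph_outputs:
            graph.node.remove(n)
    return model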


class CustomOpTransform(OnnxTransform):
"""
Transform to register custom operations and add their function protos to the ONNX model.
"""

# Registry of custom operations
_custom_ops: Dict[str, Tuple[Any, Any]] = {} # op_name -> (func_class, onnxscript_func)

@classmethod
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
"""Register a custom operation."""
cls._custom_ops[op_name] = (func_class, onnxscript_func)

@classmethod
def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
"""
Apply custom op registration and add function protos to the model.

:param model: The ONNX model to transform
:param opset_version: ONNX opset version for symbolic registration
:returns: Transformed model and success flag
"""
transformed = False

# Register all custom op symbolic functions with torch.onnx
for op_name, (func_class, _) in cls._custom_ops.items():
if hasattr(func_class, "symbolic"):
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)

# Add function protos for custom ops that are used in the model
used_protos = cls._get_function_protos_for_model(model)

for proto in used_protos:
# Check if proto already exists to avoid duplicates
proto_name = proto.name
if not any(func.name == proto_name for func in model.functions):
model.functions.append(proto)
transformed = True

return model, transformed

@classmethod
def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
"""Get function protos for custom ops that are actually used in the model."""
used_protos = []

# Get all node op_types in the model
used_op_types = set()
for node in model.graph.node:
used_op_types.add(node.op_type)

# Also check function calls
for func in model.functions:
for node in func.node:
used_op_types.add(node.op_type)

# Check which custom ops are actually used
for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
# Check if the custom op is referenced in the model
if cls._is_custom_op_used(model, op_name, used_op_types):
proto = onnxscript_func.to_function_proto()
used_protos.append(proto)

return used_protos

@classmethod
def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
"""Check if a custom op is used in the model."""
# Check if the op_name appears in node op_types
if op_name in used_op_types:
return True

# Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
if custom_op_pattern in used_op_types:
return True

# Heuristic checks based on op type
if "RMSNorm" in op_name:
# Check if any RMSNorm-related ops are present
return any("RMSNorm" in op_type for op_type in used_op_types)

if "Ctx" in op_name:
# Check if Gather/Scatter operations are present (indicating KV cache usage)
return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)

return False


class RenameFunctionOutputsTransform(OnnxTransform):
"""
Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
"""

@classmethod
def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
Rename function outputs in decoder layer nodes.

:param model: The ONNX model to transform
:returns: Transformed model and boolean indicating whether transform was applied
"""
graph = model.graph
op_type_to_func_map = {func.name: func for func in model.functions}
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
transformed = False
model_graph_outputs = [val.name for val in model.graph.output]
layer_index = 0
for node in graph.node:
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
func = op_type_to_func_map.get(node.op_type)
if func is None:
continue

for i, out_name in enumerate(func.output):
if "_InternalRetainedState" in out_name:
transformed = True
tmp = node.output[i]
if "key" in out_name:
new_name = f"past_key.{layer_index}_RetainedState"
elif "value" in out_name:
new_name = f"past_value.{layer_index}_RetainedState"
node.output[i] = new_name
# Update graph output name if it exists
if tmp in model_graph_outputs:
model.graph.output[model_graph_outputs.index(tmp)].name = new_name
layer_index = layer_index + 1
return model, transformed
41 changes: 29 additions & 12 deletions QEfficient/transformers/cache_utils.py
@@ -6,6 +6,7 @@
# -----------------------------------------------------------------------------


import os
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple

@@ -24,6 +25,30 @@
)


def _get_invalid_idx_value():
"""
Get the appropriate invalid index value for CtxGather operations.

For ONNX export with functions, we use 0 to avoid INT32_MAX constants
that cause issues when functions are inlined at runtime.

Returns:
int: Invalid index value (0 for ONNX functions, INT32_MAX otherwise)
"""
if torch.onnx.is_in_onnx_export():
# Check if ONNX functions are being used
use_onnx_functions = os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true"
if use_onnx_functions:
# For ONNX functions: use 0 to avoid function inlining issues
return 0
else:
# For regular ONNX export: use INT32_MAX as before
return torch.iinfo(torch.int32).max
else:
# For runtime: use 0
return 0


class QEffDynamicLayer(DynamicLayer):
def read_only(self, cache_kwargs):
"""
@@ -45,10 +70,7 @@ def read_only(self, cache_kwargs):
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

@@ -142,10 +164,7 @@ def update(
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
@@ -418,10 +437,8 @@ def update(
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = _get_invalid_idx_value()
print(f"value of INVALID IDX VALUE is {invalid_idx_value}")
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
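A small illustration (not from this PR) of how the _get_invalid_idx_value helper added above behaves; the only assumption is that the helper stays importable from the module shown in this diff.

import os

from QEfficient.transformers.cache_utils import _get_invalid_idx_value

# Outside torch.onnx.export the helper always returns 0.
assert _get_invalid_idx_value() == 0

# During export it returns INT32_MAX by default; setting the flag below makes
# it return 0 instead, so ONNX functions avoid INT32_MAX constants when inlined.
os.environ["QEFF_USE_ONNX_FUNCTIONS"] = "true"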
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -80,7 +80,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
Contributor: Is this change part of ONNX subfunctions?

Contributor (author): No, but it is needed for the correct modeling representation.

factor = config.rope_scaling["factor"]
inv_freq /= factor
self.register_buffer("inv_freq", inv_freq, persistent=False)
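A tiny illustration (not from this PR) of why the added config.rope_scaling is not None check matters: a membership test against None raises a TypeError, so the guard has to short-circuit first.

from types import SimpleNamespace

# Configs without RoPE scaling typically carry rope_scaling=None.
config = SimpleNamespace(rope_scaling=None)

# Without the None check, `"factor" in config.rope_scaling` raises
# "TypeError: argument of type 'NoneType' is not iterable".
scaled = False
if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
    scaled = True
print(scaled)  # False: the guard short-circuits before the membership test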
20 changes: 16 additions & 4 deletions QEfficient/transformers/models/modeling_auto.py
@@ -27,7 +27,13 @@

import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.onnx_transforms import (
CustomOpTransform,
FP16ClipTransform,
OnnxSlimTransform,
RenameFunctionOutputsTransform,
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.text_generation_inference import (
@@ -2111,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
SplitGateUpWeightsTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [
FP16ClipTransform,
CustomOpTransform,
Contributor: Do we need to apply the CustomOpTransform again after export?
RenameFunctionOutputsTransform,
OnnxSlimTransform,
SplitTensorsTransform,
]
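If the reviewers' suggestion to put ONNX Slim behind a user flag were adopted, the list above might instead be assembled conditionally. A rough sketch with an assumed enable_onnx_slim_transform flag (the flag name comes from the review thread; the gating itself is not in this PR):

from QEfficient.base.onnx_transforms import (
    CustomOpTransform,
    FP16ClipTransform,
    OnnxSlimTransform,
    RenameFunctionOutputsTransform,
    SplitTensorsTransform,
)

enable_onnx_slim_transform = False  # hypothetical user-facing flag

_onnx_transforms = [
    FP16ClipTransform,
    CustomOpTransform,
    RenameFunctionOutputsTransform,
]
if enable_onnx_slim_transform:
    _onnx_transforms.append(OnnxSlimTransform)
_onnx_transforms.append(SplitTensorsTransform)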

def __init__(
self,
@@ -2359,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
output_names.append(f"past_{kv}.{i}_RetainedState")
output_names.append(f"past_{kv}.{i}_InternalRetainedState")
Contributor: Why are we renaming it? If we are renaming _RetainedState to _InternalRetainedState, wouldn't the changes also need to be made in text_generation_inference and the other places where we skip these buffers? Even if we are not enabling subfunctions, this would impact regular execution.
else:
# HACK: create common function for this including above if condition code
@@ -2377,7 +2389,7 @@
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
output_names.append(f"past_{kv}.{i}_RetainedState")
output_names.append(f"past_{kv}.{i}_InternalRetainedState")

if self.continuous_batching:
example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)