WIP: Feat: Add ONNX Sub Functions Export Feature #613
Changes from all commits
dfabf37
2cb1708
c16a9eb
02eaaa8
1fce1d6
7f1d431
@@ -5,9 +5,12 @@
 #
 # ----------------------------------------------------------------------------

-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
+import onnx
+import onnxslim
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
@@ -99,3 +102,164 @@ def apply(
                current_file_size = tsize
            external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
        return model, transformed


class OnnxSlimTransform(OnnxTransform):
    """
    Applies onnx-slim transformations on the given ONNX graph.
    """

    @classmethod
    def apply(
        cls,
        model: ModelProto,
        *,
        onnx_base_dir: Optional[str] = None,
        **kwargs,
    ) -> Tuple[ModelProto, bool]:
        """
        :param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
        :param temp_onnx_path: Path to save the slimmed ONNX model.
        """
        transformed = False
        onnx_slim_transform = True  # kwargs.get("enable_onnx_slim_transform", False)

Contributor
If OnnxSlimTransform is called, do we need another flag onnx_slim_transform = True and then a check on line 130? The expectation should be that the onnx-slim transform is simply applied, right?

Contributor
We can remove it from here. There is a flag called "enable_onnx_slim_transform" that lets users decide whether to enable ONNX Slim. We can add a condition in modeling_auto so that this transform is included in the _onnx_transforms list only when the flag is enabled.

Contributor
I tested this change with GPTOSS; it fails in the onnxslim transform. Discussed with VB that this doesn't help us much.

Reply: Can you provide the error log? I think there should be a ~5% performance gain with onnxslim.

Contributor (Author)
@inisis thanks. You mean a ~5% gain in performance with onnxslim?

Reply: @abhishek-singh591 Hi, what are those dummy nodes? They should not be created by onnxslim. The removal of identical nodes is generally known as CSE; it reduces extra computation, and I think it is very useful.

Contributor
The nodes left behind after removing identical nodes can be considered dummy/orphan nodes. These are nodes that were originally connected to the outputs of the identical nodes but, after CSE, no longer have valid connections. Ideally, CSE should rewire the inputs and outputs properly so that no orphan nodes remain, right?

Reply: Yes, onnxslim will remove those dummy nodes automatically.

Contributor
Actually, it is not doing that, and we also do not want to delete those nodes. Before removing identical nodes through CSE, it should connect the input of the identity node directly to its output, ensuring the graph remains valid.

Reply: Really? Can you provide an example? Many thanks.
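
For illustration, a minimal sketch of the rewiring described above (not onnxslim's actual implementation): before a duplicate node is dropped during CSE, consumers of its outputs are pointed at the surviving node's outputs, and an Identity node is folded by connecting its consumers directly to its input, so no orphan nodes or dangling edges remain.

import onnx

def drop_duplicate_node(graph: onnx.GraphProto, keep: onnx.NodeProto, drop: onnx.NodeProto) -> None:
    """Rewire consumers of `drop`'s outputs onto `keep`'s outputs, then remove `drop`."""
    rename = dict(zip(drop.output, keep.output))
    for node in graph.node:
        for i, name in enumerate(node.input):
            if name in rename:
                node.input[i] = rename[name]
    # Graph outputs that referenced the dropped node must be redirected as well.
    for out in graph.output:
        if out.name in rename:
            out.name = rename[out.name]
    graph.node.remove(drop)

def fold_identity_node(graph: onnx.GraphProto, identity: onnx.NodeProto) -> None:
    """Connect consumers of an Identity node's output directly to its input, then remove it."""
    src, dst = identity.input[0], identity.output[0]
    for node in graph.node:
        for i, name in enumerate(node.input):
            if name == dst:
                node.input[i] = src
    for out in graph.output:
        if out.name == dst:
            out.name = src
    graph.node.remove(identity)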

        temp_onnx_path = kwargs.get("temp_onnx_path", None)

Contributor
Can we make this a mandatory argument? Also, onnx_base_dir is unused here.

        if not temp_onnx_path:
            err_str = "temp_onnx_path is required for onnx-slim transform."
            raise RuntimeError(err_str)
        if onnx_slim_transform:
            transformed = True
            slimmed_model = onnxslim.slim(model)
            onnx.save(slimmed_model, temp_onnx_path)
            return slimmed_model, transformed
        return model, transformed


class CustomOpTransform(OnnxTransform):
    """
    Transform to register custom operations and add their function protos to the ONNX model.
    """

    # Registry of custom operations
    _custom_ops: Dict[str, Tuple[Any, Any]] = {}  # op_name -> (func_class, onnxscript_func)

    @classmethod
    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
        """Register a custom operation."""
        cls._custom_ops[op_name] = (func_class, onnxscript_func)

    @classmethod
    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
        """
        Apply custom op registration and add function protos to the model.

        :param model: The ONNX model to transform
        :param opset_version: ONNX opset version for symbolic registration
        :returns: Transformed model and success flag
        """
        transformed = False

        # Register all custom op symbolic functions with torch.onnx
        for op_name, (func_class, _) in cls._custom_ops.items():
            if hasattr(func_class, "symbolic"):
                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)

        # Add function protos for custom ops that are used in the model
        used_protos = cls._get_function_protos_for_model(model)

        for proto in used_protos:
            # Check if proto already exists to avoid duplicates
            proto_name = proto.name
            if not any(func.name == proto_name for func in model.functions):
                model.functions.append(proto)
                transformed = True

        return model, transformed

    @classmethod
    def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
        """Get function protos for custom ops that are actually used in the model."""
        used_protos = []

        # Get all node op_types in the model
        used_op_types = set()
        for node in model.graph.node:
            used_op_types.add(node.op_type)

        # Also check function calls
        for func in model.functions:
            for node in func.node:
                used_op_types.add(node.op_type)

        # Check which custom ops are actually used
        for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
            # Check if the custom op is referenced in the model
            if cls._is_custom_op_used(model, op_name, used_op_types):
                proto = onnxscript_func.to_function_proto()
                used_protos.append(proto)

        return used_protos

    @classmethod
    def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
        """Check if a custom op is used in the model."""
        # Check if the op_name appears in node op_types
        if op_name in used_op_types:
            return True

        # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
        custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
        if custom_op_pattern in used_op_types:
            return True

        # Heuristic checks based on op type
        if "RMSNorm" in op_name:
            # Check if any RMSNorm-related ops are present
            return any("RMSNorm" in op_type for op_type in used_op_types)

        if "Ctx" in op_name:
            # Check if Gather/Scatter operations are present (indicating KV cache usage)
            return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)

        return False


class RenameFunctionOutputsTransform(OnnxTransform):
    """
    Renames function outputs in decoder layers by removing 'Internal' from '_InternalRetainedState' patterns.
    """

    @classmethod
    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
        """
        Rename function outputs in decoder layer nodes.

        :param model: The ONNX model to transform
        :returns: Transformed model and boolean indicating whether transform was applied
        """
        graph = model.graph
        op_type_to_func_map = {func.name: func for func in model.functions}
        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
        transformed = False
        model_graph_outputs = [val.name for val in model.graph.output]
        layer_index = 0
        for node in graph.node:
            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
                func = op_type_to_func_map.get(node.op_type)
                if func is None:
                    continue

                for i, out_name in enumerate(func.output):
                    if "_InternalRetainedState" in out_name:
                        transformed = True
                        tmp = node.output[i]
                        if "key" in out_name:
                            new_name = f"past_key.{layer_index}_RetainedState"
                        elif "value" in out_name:
                            new_name = f"past_value.{layer_index}_RetainedState"
                        node.output[i] = new_name
                        # Update graph output name if it exists
                        if tmp in model_graph_outputs:
                            model.graph.output[model_graph_outputs.index(tmp)].name = new_name
                layer_index = layer_index + 1
        return model, transformed
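
As a rough orientation, a minimal sketch of how the new transforms might be driven on an already exported model; the registration line is only a placeholder, since the actual custom-op classes registered by the PR are not shown in this diff, and the file names are illustrative.

import onnx

from QEfficient.base.onnx_transforms import (
    CustomOpTransform,
    OnnxSlimTransform,
    RenameFunctionOutputsTransform,
)

# Custom ops would normally be registered before export, e.g.:
# CustomOpTransform.register_custom_op("CustomRMSNormFunc", <autograd function class>, <onnxscript function>)

model = onnx.load("exported_model.onnx")
for transform in (CustomOpTransform, RenameFunctionOutputsTransform, OnnxSlimTransform):
    model, applied = transform.apply(
        model,
        opset_version=17,                     # consumed by CustomOpTransform
        temp_onnx_path="model_slimmed.onnx",  # required by OnnxSlimTransform
    )
    print(transform.__name__, "applied:", applied)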
@@ -80,7 +80,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

-        if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:

Contributor
Is this change part of ONNX Sub Functions?

Contributor (Author)
No, but it corrects the modeling representation.

            factor = config.rope_scaling["factor"]
            inv_freq /= factor
        self.register_buffer("inv_freq", inv_freq, persistent=False)
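
The added config.rope_scaling is not None guard matters because many configs define the attribute but leave it set to None, in which case the membership test itself raises. A small illustration:

from types import SimpleNamespace

config = SimpleNamespace(rope_scaling=None)

hasattr(config, "rope_scaling")      # True, so the old check would proceed
# "factor" in config.rope_scaling    # TypeError: argument of type 'NoneType' is not iterable

# The patched condition short-circuits safely:
if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
    factor = config.rope_scaling["factor"]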
@@ -27,7 +27,13 @@
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import (
+    CustomOpTransform,
+    FP16ClipTransform,
+    OnnxSlimTransform,
+    RenameFunctionOutputsTransform,
+    SplitTensorsTransform,
+)
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -2111,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitGateUpWeightsTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [
+        FP16ClipTransform,
+        CustomOpTransform,

Contributor
Do we need to apply the CustomOpTransform again after export?

+        RenameFunctionOutputsTransform,
+        OnnxSlimTransform,
+        SplitTensorsTransform,
+    ]

     def __init__(
         self,
@@ -2359,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
                 dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                output_names.append(f"past_{kv}.{i}_InternalRetainedState")

Contributor
Why are we renaming it? If we rename _RetainedState to _InternalRetainedState, wouldn't the changes also need to be made in text_generation_inference and the other places where we skip these buffers? Even if we are not enabling the subfunction feature, this would impact regular execution.
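
To make the concern concrete, a sketch of a typical suffix-based buffer-skipping check (the exact consumer-side logic may differ): an output exported as _InternalRetainedState no longer matches unless RenameFunctionOutputsTransform renames it back.

def is_retained_state(output_name: str) -> bool:
    # Hypothetical skip check; the real logic in text_generation_inference may differ.
    return output_name.endswith("_RetainedState")

print(is_retained_state("past_key.0_RetainedState"))          # True: skipped as a KV-cache buffer
print(is_retained_state("past_key.0_InternalRetainedState"))  # False: would be treated as a regular output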

         else:
             # HACK: create common function for this including above if condition code
@@ -2377,7 +2389,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
                 dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
-                output_names.append(f"past_{kv}.{i}_RetainedState")
+                output_names.append(f"past_{kv}.{i}_InternalRetainedState")

         if self.continuous_batching:
             example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)

If we are not enabling subfunctions, do we need to do the monkey patching?

Reply: If we are not enabling subfunctions, then the monkey patching is not required, but doing it won't harm execution; we have checked generation without subfunctions while the monkey patching was applied. We can also put a condition around this.
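
A minimal sketch of the condition being discussed, assuming hypothetical names for the subfunction flag and the patched symbols (the PR's actual patching mechanism is not shown in this diff): the patch is applied around export only when subfunction export is enabled, and reverted afterwards.

from contextlib import contextmanager

@contextmanager
def maybe_patched(target, attr, replacement, enabled: bool):
    """Temporarily replace `target.attr` with `replacement` when `enabled` is True."""
    if not enabled:
        yield
        return
    original = getattr(target, attr)
    setattr(target, attr, replacement)
    try:
        yield
    finally:
        setattr(target, attr, original)

# Usage sketch inside export() (names are illustrative):
# with maybe_patched(modeling_llama, "LlamaDecoderLayer", SubFunctionDecoderLayer, enabled=use_onnx_functions):
#     torch.onnx.export(...)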