Skip to content

Commit c16a9eb

Browse files
committed
fix: update, fix the modeling_qeff
Signed-off-by: vbaddi <[email protected]>
1 parent 2cb1708 commit c16a9eb

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

QEfficient/base/onnx_transforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,42 @@ def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set)
222222
return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
223223

224224
return False
225+
226+
227+
class RenameFunctionOutputsTransform(OnnxTransform):
    """
    Renames decoder-layer function outputs by rewriting the
    '_InternalRetainedState' marker in an output name to '_RetainedState'.

    A node is treated as a decoder layer when its name or op_type contains
    one of "DecoderLayer", "Block", or "Layer". Graph-level outputs that
    carried the old name are renamed in lockstep so the model stays valid.
    """

    @classmethod
    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
        """
        Rename function outputs in decoder layer nodes.

        :param model: The ONNX model to transform.
        :returns: Transformed model and boolean indicating whether the
            transform was applied.
        """
        graph = model.graph
        op_type_to_func_map = {func.name: func for func in model.functions}
        decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
        transformed = False
        model_graph_outputs = [val.name for val in model.graph.output]

        for node in graph.node:
            if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
                func = op_type_to_func_map.get(node.op_type)
                if func is None:
                    continue

                for i, out_name in enumerate(func.output):
                    # A node may expose fewer outputs than its function
                    # declares (trailing optional outputs); stay in bounds.
                    if i >= len(node.output):
                        break
                    if "_InternalRetainedState" in out_name:
                        transformed = True
                        old_name = node.output[i]
                        # Replace only the marker itself. A blanket
                        # .replace("Internal", "") would also clobber any
                        # other occurrence of "Internal" in the name.
                        new_name = out_name.replace("_InternalRetainedState", "_RetainedState")
                        node.output[i] = new_name

                        # Keep the top-level graph output in sync if the
                        # node's old output name was exported there.
                        if old_name in model_graph_outputs:
                            model.graph.output[model_graph_outputs.index(old_name)].name = new_name

        return model, transformed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
CustomOpTransform,
3232
FP16ClipTransform,
3333
OnnxSlimTransform,
34+
RenameFunctionOutputsTransform,
3435
SplitTensorsTransform,
3536
)
3637
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
@@ -2116,7 +2117,13 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
21162117
SplitGateUpWeightsTransform,
21172118
KVCacheExternalModuleMapperTransform,
21182119
]
2119-
_onnx_transforms = [FP16ClipTransform, CustomOpTransform, OnnxSlimTransform, SplitTensorsTransform]
2120+
_onnx_transforms = [
2121+
FP16ClipTransform,
2122+
CustomOpTransform,
2123+
RenameFunctionOutputsTransform,
2124+
OnnxSlimTransform,
2125+
SplitTensorsTransform,
2126+
]
21202127

21212128
def __init__(
21222129
self,
@@ -2364,7 +2371,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
23642371
for kv in ["key", "value"]:
23652372
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
23662373
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
2367-
output_names.append(f"past_{kv}.{i}_RetainedState")
2374+
output_names.append(f"past_{kv}.{i}_InternalRetainedState")
23682375

23692376
else:
23702377
# HACK: create common function for this including above if condition code
@@ -2381,8 +2388,8 @@ def export(self, export_dir: Optional[str] = None) -> str:
23812388
pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size"
23822389
for kv in ["key", "value"]:
23832390
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
2384-
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
2385-
output_names.append(f"past_{kv}.{i}_RetainedState")
2391+
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
2392+
output_names.append(f"past_{kv}.{i}_InternalRetainedState")
23862393

23872394
if self.continuous_batching:
23882395
example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)

0 commit comments

Comments
 (0)