Skip to content

Commit 02eaaa8

Browse files
committed
nit: add rename_function fix
Signed-off-by: Vinayak Baddi <[email protected]>
1 parent c16a9eb commit 02eaaa8

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

QEfficient/base/onnx_transforms.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
242242
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
243243
transformed = False
244244
model_graph_outputs = [val.name for val in model.graph.output]
245-
245+
layer_index = 0
246246
for node in graph.node:
247247
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
248248
func = op_type_to_func_map.get(node.op_type)
@@ -253,11 +253,13 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
253253
if "_InternalRetainedState" in out_name:
254254
transformed = True
255255
tmp = node.output[i]
256-
new_name = func.output[i].replace("Internal", "")
256+
if "key" in out_name:
257+
new_name = f"past_key.{layer_index}_RetainedState"
258+
elif "value" in out_name:
259+
new_name = f"past_value.{layer_index}_RetainedState"
257260
node.output[i] = new_name
258-
259261
# Update graph output name if it exists
260262
if tmp in model_graph_outputs:
261263
model.graph.output[model_graph_outputs.index(tmp)].name = new_name
262-
264+
layer_index = layer_index + 1
263265
return model, transformed

0 commit comments

Comments
 (0)