Skip to content

Commit 7eba011

Browse files
Cristian GarciaFlax Authors
authored andcommitted
revert is_leaf logic in _check_carry_same_references
PiperOrigin-RevId: 799324513
1 parent a1451a0 commit 7eba011

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

flax/nnx/transforms/iteration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,11 @@ def check_carry_same_references(key_path, arg, out):
652652
)
653653

654654
jax.tree_util.tree_map_with_path(
655-
check_carry_same_references, carry_arg, carry_arg_out, is_leaf=graph.is_graph_node
655+
check_carry_same_references,
656+
carry_arg,
657+
carry_arg_out,
658+
is_leaf=lambda x: graph.is_graph_node(x)
659+
and not isinstance(x, variablelib.Variable),
656660
)
657661

658662
def _extract_graphdefs(
@@ -1308,9 +1312,6 @@ def scan_wrapper(*args, **kwargs):
13081312
return scan_wrapper # type: ignore
13091313

13101314

1311-
1312-
1313-
13141315
# -------------------------------
13151316
# while_loop
13161317
# -------------------------------

0 commit comments

Comments
 (0)