Skip to content

Commit c24e655

Browse files
saugenstFlax Authors
authored andcommitted
Fix bug when raising ScopeParamNotFoundError.
The order of 'actual value shape' vs. 'expected value shape based on initializer' was mixed up... the actual value should come first and the initialized-based expectation second. This CL changes order to be correct. This removes a point of confusion when interpreting the error message when raised (expected vs. actual). PiperOrigin-RevId: 799698584
1 parent c45197e commit c24e655

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

flax/core/scope.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -948,17 +948,17 @@ def param(
948948
# catch it with an error message.
949949
# NOTE: We could consider moving this to `self.`
950950
abs_value = jax.eval_shape(
951-
lambda: init_fn(random.key(0), *init_args, **init_kwargs)
951+
lambda: init_fn(random.key(0), *init_args, **init_kwargs)
952952
)
953953
abs_value_flat = jax.tree_util.tree_leaves(abs_value)
954954
value_flat = jax.tree_util.tree_leaves(value)
955955
for val, abs_val in zip(value_flat, abs_value_flat):
956-
# NOTE: We could check dtype consistency here as well but it's
957-
# usefuleness is less obvious. We might intentionally change the dtype
958-
# for inference to a half float type for example.
956+
# NOTE: We could check dtype consistency here as well but its usefulness
957+
# is less obvious. We might intentionally change the dtype for inference
958+
# to a half float type for example.
959959
if np.shape(val) != np.shape(abs_val):
960960
raise errors.ScopeParamShapeError(
961-
name, self.path_text, np.shape(abs_val), np.shape(val)
961+
name, self.path_text, np.shape(val), np.shape(abs_val)
962962
)
963963
else:
964964
if not self.is_mutable_collection('params'):

flax/errors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,9 @@ def __call__(self, x):
283283

284284
def __init__(self, param_name, scope_path, value_shape, init_shape):
285285
super().__init__(
286-
f'Initializer expected to generate shape {init_shape} '
287-
f'but got shape {value_shape} instead for parameter '
288-
f'"{param_name}" in "{scope_path}".'
286+
f'For parameter "{param_name}" in "{scope_path}", the given '
287+
f'initializer is expected to generate shape {init_shape}, but the '
288+
f'existing parameter it received has shape {value_shape}.'
289289
)
290290

291291

tests/core/core_scope_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ def f(scope):
123123
scope.param('test', nn.initializers.ones_init(), (4,))
124124

125125
msg = (
126-
r'Initializer expected to generate shape \(2,\) but got shape \(4,\)'
127-
r' instead for parameter "test" in "/"'
126+
r'For parameter "test" in "/", the given initializer is expected to'
127+
r' generate shape \(4,\), but the existing parameter it received has'
128+
r' shape \(2,\).'
128129
)
129130
with self.assertRaisesRegex(errors.ScopeParamShapeError, msg):
130131
apply(f)(freeze({'params': {'test': np.ones((2,))}}))

0 commit comments

Comments
 (0)