Skip to content

Commit 0f7c870

Browse files
committed
oops
1 parent a2819fb commit 0f7c870

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/levanter/models/gpt2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ def init(Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2Embeddings":
347347

348348
@named_call
349349
def embed(self, input_ids, *, key, pos_ids: NamedArray):
350-
jax.debug.print("input_ids: has_nan={nan}", nan=input_ids.has_nan)
351-
jax.debug.print("token_embeddings.weight: has_nan={nan}", nan=self.token_embeddings.weight.array.has_nan)
350+
jax.debug.print("input_ids: has_nan={nan}", nan=jnp.any(jnp.isnan(input_ids.array)))
351+
jax.debug.print("token_embeddings.weight: has_nan={nan}", nan=jnp.any(jnp.isnan(self.token_embeddings.weight.array)))
352352

353353
input_embeds = self.token_embeddings(input_ids)
354354
input_embeds_norm = jnp.linalg.norm(input_embeds.array)

0 commit comments

Comments
 (0)