Skip to content

Conversation

@jyc
Copy link
Contributor

@jyc jyc commented Nov 19, 2025

Description

While using train_lm I hit this error about the batch to be evaluated being on a different device than the _jit_loglikelihood function used to compute log-likelikelihoods (stack trace below). I think the fix is to call hax.shard with the same ResourceMapping that we use to create _jit_loglikelihood.

Traceback (most recent
 call last):                                                                                              | 0/360448 [00:00<?, ?tok/s]
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/main/train_lm.py", line 344, in <module>
    levanter.config.main(main)()
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/config.py", line 110, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/main/train_lm.py", line 331, in main
    last_info = trainer.train(state, train_loader)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/trainer.py", line 540, in train
    for info in self.training_steps(state, train_loader):
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/trainer.py", line 528, in training_steps
    info = self.train_step(state, example)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/trainer.py", line 505, in train_step
    self.run_hooks(info)
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/trainer.py", line 334, in run_hooks
    self.hooks.run_hooks(info, force=force)
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/trainer.py", line 127, in run_hooks
    hook.fn.on_step(info, force=force)
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/callbacks/_core.py", line 59, in on_step
    self.fn(info)
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/eval_harness.py", line 1495, in lm_eval_harness
    outputs = _actually_run_eval_harness(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/eval_harness.py", line 1263, in _actually_run_eval_harness
    outputs = evaluator.evaluate(
              ^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/.venv/lib/python3.12/site-packages/lm_eval/utils.py", line 456, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/.venv/lib/python3.12/site-packages/lm_eval/evaluator.py", line 592, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/eval_harness.py", line 569, in loglikelihood
    out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(batch)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/eval_harness.py", line 294, in dispatch_loglikelihood
    return self.process_loglikelihood(packed_request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/levanter/src/levanter/eval_harness.py", line 288, in process_loglikelihood
    out = self._jit_loglikelihood(self.model, packed_request)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/haliax/src/haliax/partitioning.py", line 388, in __call__
    return self._call(False, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/.venv/lib/python3.12/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jyc/projects/marin/lib/haliax/src/haliax/partitioning.py", line 464, in _call
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Received incompatible devices for jitted computation. Got argument dynamic_reserved[1][0][1].tokens[0] of _LmEvalHarnessWorker.__init__.<locals>._eval_loglikelihood
 with shape int32[32,1024] and device ids [0] on platform CPU and jit's context mesh with device ids [0] on platform GPU

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant