Conversation

@BrendanGraham14 BrendanGraham14 commented Sep 17, 2025

Description

This PR adds forward and backward pass Pallas kernels for fused cross-entropy loss, with blockwise parallelism over the batch, sequence, and vocab dimensions.

A few things to note:

  • I haven't plumbed the batch_block_size and seq_block_size params up to LmConfig.
  • I haven't run the kernels on TPU/GPU, only CPU (with interpret=True).
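
For reference, here is a minimal sketch of the blocking structure (illustrative only, not the actual kernel code in this PR): a forward-pass kernel blocked over batch and sequence, with the vocab axis kept whole and the backward pass omitted for brevity. All names and block sizes are placeholders, and it assumes the newer BlockSpec(block_shape, index_map) signature.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _xent_fwd_kernel(logits_ref, labels_ref, loss_ref):
    # One (batch_block, seq_block, vocab) tile; compute in float32.
    logits = logits_ref[...].astype(jnp.float32)
    labels = labels_ref[...]
    # Numerically stable log-sum-exp over the vocab axis.
    m = jnp.max(logits, axis=-1, keepdims=True)
    lse = jnp.log(jnp.sum(jnp.exp(logits - m), axis=-1)) + m[..., 0]
    target = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    loss_ref[...] = lse - target


def fused_xent(logits, labels, *, batch_block=1, seq_block=128):
    b, s, v = logits.shape  # assumes b and s divide evenly by the block sizes
    grid = (b // batch_block, s // seq_block)
    return pl.pallas_call(
        _xent_fwd_kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec((batch_block, seq_block, v), lambda i, j: (i, j, 0)),
            pl.BlockSpec((batch_block, seq_block), lambda i, j: (i, j)),
        ],
        out_specs=pl.BlockSpec((batch_block, seq_block), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct((b, s), jnp.float32),
        interpret=True,  # CPU-only so far, matching the note above
    )(logits, labels)
```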

Unit test coverage

Unit tests are in test_loss.py. I also added coverage for logit_soft_cap and for batch+seq+vocab parallelism.
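
As a rough illustration of what such a check looks like (names here are placeholders, not the actual contents of test_loss.py), the kernel output can be compared against a plain-JAX reference; the reference below includes the logit_soft_cap transform, which the illustrative fused_xent sketch above does not take as an argument:

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_xent(logits, labels, logit_soft_cap=None):
    logits = logits.astype(jnp.float32)
    if logit_soft_cap is not None:
        # Soft-cap the logits before the softmax.
        logits = logit_soft_cap * jnp.tanh(logits / logit_soft_cap)
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(logprobs, labels[..., None], axis=-1)[..., 0]


def test_fused_matches_reference():
    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    logits = jax.random.normal(k1, (2, 256, 512), dtype=jnp.float32)
    labels = jax.random.randint(k2, (2, 256), 0, 512)
    expected = reference_xent(logits, labels)
    actual = fused_xent(logits, labels)  # the illustrative sketch above
    np.testing.assert_allclose(np.asarray(actual), np.asarray(expected), rtol=1e-5, atol=1e-5)
```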

Member

@dlwh dlwh left a comment


Nice! I'll try to run it and see what the speed looks like for a Llama 3 model!

Member

dlwh commented Sep 21, 2025

OK, I tried it out. There are a bunch of small problems and a larger problem.

The smaller problems are:

  1. The dtypes of the intermediates were bfloat16, and Pallas won't let you store bfloat16 values in float32 refs without an explicit cast. Those computations should be done in float32 anyway, so I fixed that.
  2. TPU requires that the last two block dimensions be multiples of 8 and 128, but the last block was the vocab block, and for Llama 3 with block size 512 it came out to size 63. This can be fixed by moving the vocab block to the front (see the sketch after this list).
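
A hedged sketch of what both fixes amount to (shapes and names are illustrative, not the exact code in the PR, and it assumes the logits tile is laid out with the vocab axis leading):

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _partial_sum_kernel(logits_ref, partial_ref):
    # The incoming tile may be bfloat16; Pallas refuses to store bfloat16
    # values into a float32 ref without an explicit cast, and the reduction
    # should be accumulated in float32 anyway.
    block = logits_ref[...].astype(jnp.float32)
    partial_ref[...] = jnp.sum(jnp.exp(block), axis=0)


# With the vocab block moved to the front of the block shape, the trailing two
# block dims are batch and seq, which can be chosen as multiples of 8 and 128
# to satisfy the TPU tiling requirement on the last two dimensions.
logits_spec = pl.BlockSpec((512, 8, 128), lambda v, b, s: (v, b, s))
```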

The larger problem is that TPU is raising a NotImplementedError from something, possibly the masked loads. I haven't investigated yet.
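
One possible workaround worth checking (an untested guess, not a confirmed fix): avoid the masked load entirely by loading the full block and masking out-of-range vocab entries after the load, e.g.:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _masked_vocab_block(logits_ref, *, vocab_size, vocab_block):
    # Assumes vocab blocks run along the leading grid axis, as in the sketch above.
    v_idx = pl.program_id(0)
    block = logits_ref[...].astype(jnp.float32)
    # Global vocab index of each row of this tile; broadcasted_iota is used
    # because TPU kernels do not support 1D iota.
    row = v_idx * vocab_block + jax.lax.broadcasted_iota(
        jnp.int32, block.shape, dimension=0
    )
    # Rows past the real vocab size become -inf, so exp() turns them into
    # zeros in a subsequent reduction, with no masked load needed.
    return jnp.where(row < vocab_size, block, -jnp.inf)
```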

Member

@dlwh dlwh left a comment


(dismissing my approval)

@BrendanGraham14
Author

I see - thanks for trying it out. Will get my hands on a TPU to debug.
