Commit 8291ba5

fused ce kernels: b,s,v -> v,b,s for tpu compat, pass through dtype
1 parent 5342233 commit 8291ba5
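
The core of the change in miniature: the Pallas grid becomes (vocab, batch, seq) instead of (batch, seq, vocab), the head weights are laid out (Vocab, Embed), and every index_map takes the grid indices as (v, b, s). The sketch below is not Levanter's kernel: the toy sizes, block shapes, and the per-block max-logit computation are illustrative assumptions, and it runs in interpreter mode so it can be tried without a TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

B, S, E, V = 8, 16, 32, 64        # toy batch, seq, embed, vocab sizes (assumed)
B_BLK, S_BLK, V_BLK = 4, 8, 16    # block sizes; assumed to divide the axes evenly

def block_max_logit_kernel(w_ref, x_ref, o_ref):
    # w_ref: (V_BLK, E) tile of the head, x_ref: (B_BLK, S_BLK, E) tile of embeddings
    logits = jnp.einsum("bse,ve->bsv", x_ref[...], w_ref[...])
    o_ref[...] = jnp.max(logits, axis=-1, keepdims=True)  # (B_BLK, S_BLK, 1)

def block_max_logits(w, x):
    return pl.pallas_call(
        block_max_logit_kernel,
        grid=(V // V_BLK, B // B_BLK, S // S_BLK),  # vocab is now the leading grid axis
        in_specs=[
            pl.BlockSpec((V_BLK, E), index_map=lambda v, b, s: (v, 0)),            # head, laid out (Vocab, Embed)
            pl.BlockSpec((B_BLK, S_BLK, E), index_map=lambda v, b, s: (b, s, 0)),  # embeddings ignore the vocab index
        ],
        out_specs=pl.BlockSpec((B_BLK, S_BLK, 1), index_map=lambda v, b, s: (b, s, v)),
        out_shape=jax.ShapeDtypeStruct((B, S, V // V_BLK), x.dtype),
        interpret=True,  # interpreter mode so the sketch runs off-TPU
    )(w, x)

w = jnp.ones((V, E), dtype=jnp.float32)
x = jnp.ones((B, S, E), dtype=jnp.float32)
print(block_max_logits(w, x).shape)  # (8, 16, 4)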

src/levanter/models/loss.py

Lines changed: 40 additions & 38 deletions
@@ -314,17 +314,18 @@ def _block_cross_entropy_forward_kernel(
     BatchFull: hax.Axis,
     Embed: hax.Axis,
     Label: hax.Axis,
+    dtype: jnp.dtype,
     logit_soft_cap: Optional[float] = None,
 ):
-    # Get program IDs for all dimensions
-    pid_batch = pl.program_id(0)
-    pid_seq = pl.program_id(1)
-    pid_vocab = pl.program_id(2)
+    pid_vocab = pl.program_id(0)
+    pid_batch = pl.program_id(1)
+    pid_seq = pl.program_id(2)
 
     vocab_start = pid_vocab * Vocab.size
 
     batch_mask = _make_tile_mask(Batch, BatchFull, pid_batch)
     pos_mask = _make_tile_mask(Pos, PosFull, pid_seq)
+
     vocab_mask = _make_tile_mask(Vocab, Label, pid_vocab)
     batch_pos_mask = batch_mask.broadcast_axis((Batch, Pos)) * pos_mask.broadcast_axis((Batch, Pos))
 
@@ -337,15 +338,17 @@ def _block_cross_entropy_forward_kernel(
         ),
         axes=(Batch, Pos, Embed),
     )
+
     lm_head = hax.NamedArray(
         array=pl.load(
             lm_head_ref,
             ...,
-            mask=vocab_mask.array,
+            mask=vocab_mask.array[..., None],
             other=0,
         ),
-        axes=(Embed, Vocab),
+        axes=(Vocab, Embed),
     )
+
     labels = hax.NamedArray(
         array=pl.load(
             labels_ref,
@@ -363,7 +366,7 @@ def _block_cross_entropy_forward_kernel(
     # Compute max only over valid vocab columns
     masked_for_max = hax.NamedArray(array=jnp.where(vocab_mask.array, logits.array, -jnp.inf), axes=logits.axes)
     max_logit = hax.max(masked_for_max, axis=Vocab)
-    targets = _block_to_one_hot(labels, Vocab, vocab_start, logits.dtype) * pos_mask * batch_mask
+    targets = _block_to_one_hot(labels, Vocab, vocab_start, dtype) * pos_mask * batch_mask
 
     # Mask out logits which aren't in the block. Must happen after max_logit but before dot.
     logits = logits * vocab_mask * pos_mask * batch_mask
@@ -422,7 +425,6 @@ def _block_cross_entropy_forward(
     num_vocab_blocks = math.ceil(Label.size / vocab_block_size)
 
     pred_embeddings, lm_head = pred
-    lm_head = hax.rearrange(lm_head, (Contract, Label))
     Batch = pred_embeddings.axes[0]
 
     if batch_block_size is None:
@@ -450,33 +452,34 @@ def _block_cross_entropy_forward(
             Vocab=VocabSlice,
             Embed=Contract,
             Label=Label,
+            dtype=dtype,
         ),
         out_shape=[
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # dot
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # max_logit
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # logsumexp
         ],
-        grid=(num_batch_blocks, num_seq_blocks, num_vocab_blocks),
+        grid=(num_vocab_blocks, num_batch_blocks, num_seq_blocks),
         in_specs=[
-            pl.BlockSpec([Contract.size, VocabSlice.size], index_map=lambda b, s, v: (0, v)),  # lm_head
+            pl.BlockSpec([VocabSlice.size, Contract.size], index_map=lambda v, b, s: (v, 0)),  # lm_head
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size],
-                index_map=lambda b, s, v: (b, s, 0),
+                index_map=lambda v, b, s: (b, s, 0),
             ),  # embeddings
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # labels
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # labels
         ],
         out_specs=[
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # dot
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # max_logit
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # logsumexp
         ],
         interpret=use_interpret,
@@ -490,6 +493,7 @@ def _block_cross_entropy_forward(
     logsumexp = max_logit + hax.log(hax.sum(hax.exp(block_logsumexps + block_max_logits - max_logit), axis=VocabBlock))
     dot = hax.sum(block_dots, axis=VocabBlock)
     loss = logsumexp - dot
+
     return (loss, logsumexp), (logsumexp,)
 
 
@@ -511,13 +515,14 @@ def _block_cross_entropy_backward_kernel(
     Vocab: hax.Axis,
     Embed: hax.Axis,
     Label: hax.Axis,
+    dtype: jnp.dtype,
 ):
     """
     Pallas kernel for computing gradients in block-wise cross-entropy loss.
     """
-    pid_batch = pl.program_id(0)
-    pid_seq = pl.program_id(1)
-    pid_vocab = pl.program_id(2)
+    pid_vocab = pl.program_id(0)
+    pid_batch = pl.program_id(1)
+    pid_seq = pl.program_id(2)
     vocab_start = pid_vocab * Vocab.size
 
     batch_mask = _make_tile_mask(Batch, BatchFull, pid_batch)
@@ -526,8 +531,8 @@ def _block_cross_entropy_backward_kernel(
     batch_pos_mask = batch_mask.broadcast_axis((Batch, Pos)) * pos_mask.broadcast_axis((Batch, Pos))
 
     lm_head_block = hax.NamedArray(
-        array=pl.load(lm_head_ref, ..., mask=vocab_mask.array, other=0),
-        axes=(Embed, Vocab),
+        array=pl.load(lm_head_ref, ..., mask=vocab_mask.array[..., None], other=0),
+        axes=(Vocab, Embed),
     )
     embeddings = hax.NamedArray(
         array=pl.load(pred_embeddings_ref, ..., mask=batch_pos_mask.array[..., None], other=0),
@@ -556,7 +561,7 @@ def _block_cross_entropy_backward_kernel(
 
     probs = hax.exp(logits - log_z) * vocab_mask
 
-    targets = _block_to_one_hot(labels, Vocab, vocab_start, logits.dtype) * pos_mask * batch_mask
+    targets = _block_to_one_hot(labels, Vocab, vocab_start, dtype) * pos_mask * batch_mask
 
     grad_logits = grad_loss * (probs - targets) + grad_log_z * probs  # [Batch, Pos, Vocab]
     grad_logits = grad_logits * vocab_mask
@@ -567,7 +572,7 @@ def _block_cross_entropy_backward_kernel(
     grad_logits = grad_logits * pos_mask * batch_mask
 
     grad_embeddings_block = hax.dot(grad_logits, lm_head_block, axis=Vocab)  # [Batch, Pos, Embed]
-    grad_lm_head_block = hax.sum(hax.dot(embeddings, grad_logits, axis=Pos), axis=Batch)  # [Embed, Vocab]
+    grad_lm_head_block = hax.sum(hax.dot(grad_logits, embeddings, axis=Pos), axis=Batch)  # [Vocab, Embed]
 
     pl.store(grad_embeddings_ref, ..., grad_embeddings_block.array[..., None])  # last dim is Block=1 slice
     pl.store(grad_lm_head_ref, ..., grad_lm_head_block.array[None, None, ...])
@@ -616,10 +621,7 @@ def _block_cross_entropy_backward(
         vocab_block_size = Label.size
 
     num_vocab_blocks = math.ceil(Label.size / vocab_block_size)
-    pred_embeddings, lm_head_orig = pred
-
-    lm_head_orig_axes = lm_head_orig.axes
-    lm_head = hax.rearrange(lm_head_orig, (Contract, Label))
+    pred_embeddings, lm_head = pred
 
     VocabSlice = Label.resize(vocab_block_size)
     VocabBlock = Label.resize(num_vocab_blocks)
@@ -645,7 +647,7 @@ def _block_cross_entropy_backward(
     grad_log_z = hax.zeros((Batch, Pos), dtype=pred_embeddings.dtype)
 
     grad_embedding_out_shape = (Batch, Pos, Contract, VocabBlock)
-    grad_lm_head_out_shape = (BatchBlock, PosBlock, Contract, Label)
+    grad_lm_head_out_shape = (BatchBlock, PosBlock, Label, Contract)
 
     grad_embeddings_blocks, grad_lm_head_blocks = pl.pallas_call(
         functools.partial(
@@ -658,33 +660,34 @@ def _block_cross_entropy_backward(
             Vocab=VocabSlice,
             Embed=Contract,
             Label=Label,
+            dtype=dtype,
         ),
         out_shape=[
             # grad_embeddings - aggregated over vocab
             jax.ShapeDtypeStruct([ax.size for ax in grad_embedding_out_shape], dtype=pred_embeddings.dtype),
             # grad_lm_head - aggregated over batch and pos
             jax.ShapeDtypeStruct([ax.size for ax in grad_lm_head_out_shape], dtype=lm_head.dtype),
         ],
-        grid=(num_batch_blocks, num_pos_blocks, num_vocab_blocks),
+        grid=(num_vocab_blocks, num_batch_blocks, num_pos_blocks),
         in_specs=[
-            pl.BlockSpec([Contract.size, VocabSlice.size], index_map=lambda b, s, v: (0, v)),  # lm_head
+            pl.BlockSpec([VocabSlice.size, Contract.size], index_map=lambda v, b, s: (v, 0)),  # lm_head
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size],
-                index_map=lambda b, s, v: (b, s, 0),
+                index_map=lambda v, b, s: (b, s, 0),
             ),  # embeddings
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # labels
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # log_z
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # grad_loss
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # grad_log_z
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # labels
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # log_z
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # grad_loss
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # grad_log_z
         ],
         out_specs=[
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size, 1],
-                index_map=lambda b, s, v: (b, s, 0, v),
+                index_map=lambda v, b, s: (b, s, 0, v),
             ),  # grad_embeddings - aggregated over vocab
             pl.BlockSpec(
-                [1, 1, Contract.size, VocabSlice.size],
-                index_map=lambda b, s, v: (b, s, 0, v),
+                [1, 1, VocabSlice.size, Contract.size],
+                index_map=lambda v, b, s: (b, s, v, 0),
             ),  # grad_lm_head - aggregated over batch and pos
         ],
         interpret=use_interpret,
@@ -703,7 +706,6 @@ def _block_cross_entropy_backward(
     grad_lm_head = hax.NamedArray(array=grad_lm_head_blocks, axes=grad_lm_head_out_shape)
     grad_lm_head = hax.sum(grad_lm_head, axis=(BatchBlock, PosBlock))
 
-    grad_lm_head = hax.rearrange(grad_lm_head, lm_head_orig_axes)
     return (grad_embeddings, grad_lm_head)
 

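A note on the (Vocab, Embed) layout used above: the weight gradient computed by grad_lm_head_block is the same contraction over batch and positions as before, just produced transposed. A quick check with toy shapes (assumed purely for illustration, not taken from loss.py):

import jax.numpy as jnp

B, S, E, V = 2, 3, 4, 5  # toy batch, seq, embed, vocab sizes (assumed)
grad_logits = jnp.arange(B * S * V, dtype=jnp.float32).reshape(B, S, V)
embeddings = jnp.arange(B * S * E, dtype=jnp.float32).reshape(B, S, E)

# New layout: contract grad_logits with embeddings over batch and positions -> (Vocab, Embed)
grad_head_ve = jnp.einsum("bsv,bse->ve", grad_logits, embeddings)
# Old layout: same contraction, the other way around -> (Embed, Vocab)
grad_head_ev = jnp.einsum("bse,bsv->ev", embeddings, grad_logits)

assert jnp.allclose(grad_head_ve, grad_head_ev.T)  # identical gradient, transposed layout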