@@ -314,17 +314,18 @@ def _block_cross_entropy_forward_kernel(
     BatchFull: hax.Axis,
     Embed: hax.Axis,
     Label: hax.Axis,
+    dtype: jnp.dtype,
     logit_soft_cap: Optional[float] = None,
 ):
-    # Get program IDs for all dimensions
-    pid_batch = pl.program_id(0)
-    pid_seq = pl.program_id(1)
-    pid_vocab = pl.program_id(2)
+    pid_vocab = pl.program_id(0)
+    pid_batch = pl.program_id(1)
+    pid_seq = pl.program_id(2)
 
     vocab_start = pid_vocab * Vocab.size
 
     batch_mask = _make_tile_mask(Batch, BatchFull, pid_batch)
     pos_mask = _make_tile_mask(Pos, PosFull, pid_seq)
+
     vocab_mask = _make_tile_mask(Vocab, Label, pid_vocab)
     batch_pos_mask = batch_mask.broadcast_axis((Batch, Pos)) * pos_mask.broadcast_axis((Batch, Pos))
 
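For readers following the kernel: a minimal sketch of what a helper like _make_tile_mask presumably computes (a hypothetical reconstruction for illustration, not the code in this repo). It builds a boolean mask over one tile of the grid that is False wherever the tile hangs past the end of the full axis, so edge tiles of a non-divisible axis load as zeros:

import jax.numpy as jnp
import haliax as hax

def make_tile_mask_sketch(Tile: hax.Axis, Full: hax.Axis, pid) -> hax.NamedArray:
    start = pid * Tile.size                     # first index covered by this tile
    idx = start + jnp.arange(Tile.size)         # absolute indices within the tile
    return hax.named(idx < Full.size, (Tile,))  # False past the axis boundary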
@@ -337,15 +338,17 @@ def _block_cross_entropy_forward_kernel(
         ),
         axes=(Batch, Pos, Embed),
     )
+
     lm_head = hax.NamedArray(
         array=pl.load(
             lm_head_ref,
             ...,
-            mask=vocab_mask.array,
+            mask=vocab_mask.array[..., None],
             other=0,
         ),
-        axes=(Embed, Vocab),
+        axes=(Vocab, Embed),
     )
+
     labels = hax.NamedArray(
         array=pl.load(
             labels_ref,
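With lm_head now laid out as (Vocab, Embed), the per-row vocab mask needs a trailing singleton axis to broadcast across the Embed columns, which is what vocab_mask.array[..., None] provides. In plain NumPy terms:

import numpy as np

vocab_mask = np.array([True, True, False])          # one flag per Vocab row
tile = np.ones((3, 4), dtype=np.float32)            # a (Vocab, Embed) tile
masked = np.where(vocab_mask[:, None], tile, 0.0)   # rows past Label.size zeroed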
@@ -363,7 +366,7 @@ def _block_cross_entropy_forward_kernel(
     # Compute max only over valid vocab columns
     masked_for_max = hax.NamedArray(array=jnp.where(vocab_mask.array, logits.array, -jnp.inf), axes=logits.axes)
     max_logit = hax.max(masked_for_max, axis=Vocab)
-    targets = _block_to_one_hot(labels, Vocab, vocab_start, logits.dtype) * pos_mask * batch_mask
+    targets = _block_to_one_hot(labels, Vocab, vocab_start, dtype) * pos_mask * batch_mask
 
     # Mask out logits which aren't in the block. Must happen after max_logit but before dot.
     logits = logits * vocab_mask * pos_mask * batch_mask
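The dtype change threads the caller's accumulation dtype through instead of reusing logits.dtype. For intuition, here is a hypothetical sketch of a block-local one-hot in the spirit of _block_to_one_hot (names and signature are illustrative, not the repo's): labels that fall outside [vocab_start, vocab_start + block) produce an all-zero row, so each block only claims its own slice of the target:

import jax.numpy as jnp

def block_one_hot_sketch(labels, vocab_block_size, vocab_start, dtype):
    local = labels - vocab_start                     # shift into block-local ids
    cols = jnp.arange(vocab_block_size)
    return (local[..., None] == cols).astype(dtype)  # zero row when out of range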
@@ -422,7 +425,6 @@ def _block_cross_entropy_forward(
     num_vocab_blocks = math.ceil(Label.size / vocab_block_size)
 
     pred_embeddings, lm_head = pred
-    lm_head = hax.rearrange(lm_head, (Contract, Label))
     Batch = pred_embeddings.axes[0]
 
     if batch_block_size is None:
@@ -450,33 +452,34 @@ def _block_cross_entropy_forward(
             Vocab=VocabSlice,
             Embed=Contract,
             Label=Label,
+            dtype=dtype,
         ),
         out_shape=[
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # dot
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # max_logit
             jax.ShapeDtypeStruct((Batch.size, Pos.size, VocabBlock.size), dtype=dtype),  # logsumexp
         ],
-        grid=(num_batch_blocks, num_seq_blocks, num_vocab_blocks),
+        grid=(num_vocab_blocks, num_batch_blocks, num_seq_blocks),
         in_specs=[
-            pl.BlockSpec([Contract.size, VocabSlice.size], index_map=lambda b, s, v: (0, v)),  # lm_head
+            pl.BlockSpec([VocabSlice.size, Contract.size], index_map=lambda v, b, s: (v, 0)),  # lm_head
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size],
-                index_map=lambda b, s, v: (b, s, 0),
+                index_map=lambda v, b, s: (b, s, 0),
             ),  # embeddings
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # labels
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # labels
         ],
         out_specs=[
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # dot
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # max_logit
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, 1],
-                index_map=lambda b, s, v: (b, s, v),
+                index_map=lambda v, b, s: (b, s, v),
             ),  # logsumexp
         ],
         interpret=use_interpret,
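The grid reorder moves vocab to position 0. Assuming the usual Pallas convention that the trailing grid axis varies fastest, vocab becomes the outermost loop, so the lm_head block selected by index_map=lambda v, b, s: (v, 0) presumably stays resident while every (batch, seq) tile is swept. A pure-Python sketch of the resulting visit order:

from itertools import product

grid = (2, 2, 2)  # (num_vocab_blocks, num_batch_blocks, num_seq_blocks)
for v, b, s in product(*(range(n) for n in grid)):  # s fastest, v slowest
    lm_head_block = (v, 0)   # same lm_head tile across the inner (b, s) sweep
    dot_block = (b, s, v)    # per-(batch, seq, vocab) output tile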
@@ -490,6 +493,7 @@ def _block_cross_entropy_forward(
     logsumexp = max_logit + hax.log(hax.sum(hax.exp(block_logsumexps + block_max_logits - max_logit), axis=VocabBlock))
     dot = hax.sum(block_dots, axis=VocabBlock)
     loss = logsumexp - dot
+
     return (loss, logsumexp), (logsumexp,)
 
 
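The reduction above is the standard two-level logsumexp merge: each vocab block reports its own max and a logsumexp relative to that max, and re-centering on the global max recovers the exact value. A quick check in plain JAX (assuming, as the recombination formula suggests, that block_logsumexps are stored relative to block_max_logits):

import jax
import jax.numpy as jnp

x = jnp.array([0.3, 2.0, -1.5, 4.0, 0.1, 3.2])
blocks = x.reshape(2, 3)                                  # two vocab blocks
block_max = blocks.max(axis=1)
block_lse = jnp.log(jnp.exp(blocks - block_max[:, None]).sum(axis=1))
g = block_max.max()                                       # global max_logit
merged = g + jnp.log(jnp.exp(block_lse + block_max - g).sum())
assert jnp.allclose(merged, jax.nn.logsumexp(x))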
@@ -511,13 +515,14 @@ def _block_cross_entropy_backward_kernel(
     Vocab: hax.Axis,
     Embed: hax.Axis,
     Label: hax.Axis,
+    dtype: jnp.dtype,
 ):
     """
     Pallas kernel for computing gradients in block-wise cross-entropy loss.
     """
-    pid_batch = pl.program_id(0)
-    pid_seq = pl.program_id(1)
-    pid_vocab = pl.program_id(2)
+    pid_vocab = pl.program_id(0)
+    pid_batch = pl.program_id(1)
+    pid_seq = pl.program_id(2)
     vocab_start = pid_vocab * Vocab.size
 
     batch_mask = _make_tile_mask(Batch, BatchFull, pid_batch)
@@ -526,8 +531,8 @@ def _block_cross_entropy_backward_kernel(
     batch_pos_mask = batch_mask.broadcast_axis((Batch, Pos)) * pos_mask.broadcast_axis((Batch, Pos))
 
     lm_head_block = hax.NamedArray(
-        array=pl.load(lm_head_ref, ..., mask=vocab_mask.array, other=0),
-        axes=(Embed, Vocab),
+        array=pl.load(lm_head_ref, ..., mask=vocab_mask.array[..., None], other=0),
+        axes=(Vocab, Embed),
     )
     embeddings = hax.NamedArray(
         array=pl.load(pred_embeddings_ref, ..., mask=batch_pos_mask.array[..., None], other=0),
@@ -556,7 +561,7 @@ def _block_cross_entropy_backward_kernel(
 
     probs = hax.exp(logits - log_z) * vocab_mask
 
-    targets = _block_to_one_hot(labels, Vocab, vocab_start, logits.dtype) * pos_mask * batch_mask
+    targets = _block_to_one_hot(labels, Vocab, vocab_start, dtype) * pos_mask * batch_mask
 
     grad_logits = grad_loss * (probs - targets) + grad_log_z * probs  # [Batch, Pos, Vocab]
     grad_logits = grad_logits * vocab_mask
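The grad_logits formula is the usual softmax cross-entropy VJP with an extra cotangent for the log_z output: d(loss)/d(logits) = softmax - one_hot and d(log_z)/d(logits) = softmax. A self-contained check against JAX autodiff (fwd is a toy stand-in for the forward pass, not this file's code):

import jax
import jax.numpy as jnp

logits = jnp.array([0.5, -1.0, 2.0])
target = 2
g_loss, g_log_z = jnp.float32(0.7), jnp.float32(0.3)  # arbitrary cotangents

def fwd(z):
    log_z = jax.nn.logsumexp(z)
    return log_z - z[target], log_z                   # (loss, log_z)

probs = jax.nn.softmax(logits)
one_hot = jax.nn.one_hot(target, logits.shape[0])
analytic = g_loss * (probs - one_hot) + g_log_z * probs

_, vjp = jax.vjp(fwd, logits)
(autodiff,) = vjp((g_loss, g_log_z))
assert jnp.allclose(analytic, autodiff, atol=1e-6)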
@@ -567,7 +572,7 @@ def _block_cross_entropy_backward_kernel(
     grad_logits = grad_logits * pos_mask * batch_mask
 
     grad_embeddings_block = hax.dot(grad_logits, lm_head_block, axis=Vocab)  # [Batch, Pos, Embed]
-    grad_lm_head_block = hax.sum(hax.dot(embeddings, grad_logits, axis=Pos), axis=Batch)  # [Embed, Vocab]
+    grad_lm_head_block = hax.sum(hax.dot(grad_logits, embeddings, axis=Pos), axis=Batch)  # [Vocab, Embed]
 
     pl.store(grad_embeddings_ref, ..., grad_embeddings_block.array[..., None])  # last dim is Block=1 slice
     pl.store(grad_lm_head_ref, ..., grad_lm_head_block.array[None, None, ...])
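The corrected contraction swaps the operand order so the partial gradient comes out directly in the new (Vocab, Embed) layout, with no rearrange needed afterwards. In plain-array terms (a sketch assuming grad_logits is (B, P, V) and embeddings is (B, P, E) within a tile):

import jax.numpy as jnp

B, P, V, E = 2, 3, 5, 4
grad_logits = jnp.ones((B, P, V))
embeddings = jnp.ones((B, P, E))
grad_lm_head = jnp.einsum("bpv,bpe->ve", grad_logits, embeddings)  # (Vocab, Embed)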
@@ -616,10 +621,7 @@ def _block_cross_entropy_backward(
     vocab_block_size = Label.size
 
     num_vocab_blocks = math.ceil(Label.size / vocab_block_size)
-    pred_embeddings, lm_head_orig = pred
-
-    lm_head_orig_axes = lm_head_orig.axes
-    lm_head = hax.rearrange(lm_head_orig, (Contract, Label))
+    pred_embeddings, lm_head = pred
 
     VocabSlice = Label.resize(vocab_block_size)
     VocabBlock = Label.resize(num_vocab_blocks)
@@ -645,7 +647,7 @@ def _block_cross_entropy_backward(
     grad_log_z = hax.zeros((Batch, Pos), dtype=pred_embeddings.dtype)
 
     grad_embedding_out_shape = (Batch, Pos, Contract, VocabBlock)
-    grad_lm_head_out_shape = (BatchBlock, PosBlock, Contract, Label)
+    grad_lm_head_out_shape = (BatchBlock, PosBlock, Label, Contract)
 
     grad_embeddings_blocks, grad_lm_head_blocks = pl.pallas_call(
         functools.partial(
@@ -658,33 +660,34 @@ def _block_cross_entropy_backward(
             Vocab=VocabSlice,
             Embed=Contract,
             Label=Label,
+            dtype=dtype,
         ),
         out_shape=[
             # grad_embeddings - aggregated over vocab
             jax.ShapeDtypeStruct([ax.size for ax in grad_embedding_out_shape], dtype=pred_embeddings.dtype),
             # grad_lm_head - aggregated over batch and pos
             jax.ShapeDtypeStruct([ax.size for ax in grad_lm_head_out_shape], dtype=lm_head.dtype),
         ],
-        grid=(num_batch_blocks, num_pos_blocks, num_vocab_blocks),
+        grid=(num_vocab_blocks, num_batch_blocks, num_pos_blocks),
        in_specs=[
-            pl.BlockSpec([Contract.size, VocabSlice.size], index_map=lambda b, s, v: (0, v)),  # lm_head
+            pl.BlockSpec([VocabSlice.size, Contract.size], index_map=lambda v, b, s: (v, 0)),  # lm_head
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size],
-                index_map=lambda b, s, v: (b, s, 0),
+                index_map=lambda v, b, s: (b, s, 0),
             ),  # embeddings
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # labels
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # log_z
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # grad_loss
-            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda b, s, v: (b, s)),  # grad_log_z
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # labels
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # log_z
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # grad_loss
+            pl.BlockSpec([BatchSlice.size, PosSlice.size], index_map=lambda v, b, s: (b, s)),  # grad_log_z
         ],
         out_specs=[
             pl.BlockSpec(
                 [BatchSlice.size, PosSlice.size, Contract.size, 1],
-                index_map=lambda b, s, v: (b, s, 0, v),
+                index_map=lambda v, b, s: (b, s, 0, v),
             ),  # grad_embeddings - aggregated over vocab
             pl.BlockSpec(
-                [1, 1, Contract.size, VocabSlice.size],
-                index_map=lambda b, s, v: (b, s, 0, v),
+                [1, 1, VocabSlice.size, Contract.size],
+                index_map=lambda v, b, s: (b, s, v, 0),
             ),  # grad_lm_head - aggregated over batch and pos
         ],
         interpret=use_interpret,
@@ -703,7 +706,6 @@ def _block_cross_entropy_backward(
     grad_lm_head = hax.NamedArray(array=grad_lm_head_blocks, axes=grad_lm_head_out_shape)
     grad_lm_head = hax.sum(grad_lm_head, axis=(BatchBlock, PosBlock))
 
-    grad_lm_head = hax.rearrange(grad_lm_head, lm_head_orig_axes)
     return (grad_embeddings, grad_lm_head)
 
 