|
74 | 74 | UrlDataSource, |
75 | 75 | WrappedHFDataSource, |
76 | 76 | ) |
| 77 | +from levanter.data.ul2r import DenoisingConfig, Ul2rDataset # noqa |
77 | 78 | from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa |
78 | 79 | from levanter.store.cache import build_or_load_cache # noqa |
79 | 80 | from levanter.utils.jax_utils import key_iterator, use_cpu_device # noqa |
@@ -438,6 +439,13 @@ class SupervisedLmDatasetFormat(LmDatasetFormatBase): |
438 | 439 | pack: bool = True |
439 | 440 | mask_inputs: bool = True |
440 | 441 |
|
| 442 | +@LmDatasetFormatBase.register_subclass("ul2r") |
| 443 | +@dataclass(frozen=True) |
| 444 | +class Ul2rDatasetFormat(TextLmDatasetFormat): |
| 445 | + task_configs: Dict[str, DenoisingConfig] = field(default_factory=dict) |
| 446 | + task_probs: Dict[str, float] = field(default_factory=dict) |
| 447 | + rng_seed: int = 37 |
| 448 | + |
441 | 449 |
|
442 | 450 | @dataclass(frozen=True) |
443 | 451 | class LmDatasetSourceConfigBase(abc.ABC): |
@@ -606,7 +614,7 @@ def preprocessor_for_format( |
606 | 614 | format: LmDatasetFormatBase, tokenizer: HfTokenizer, *, enforce_eos: bool = True, enforce_bos: bool = True |
607 | 615 | ) -> BatchProcessor[dict, dict]: |
608 | 616 | match format: |
609 | | - case TextLmDatasetFormat(text_key=key): |
| 617 | + case TextLmDatasetFormat(text_key=key) | Ul2rDatasetFormat(text_key=key): |
610 | 618 | return BatchTokenizer(tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, text_field=key) |
611 | 619 | case ChatLmDatasetFormat(messages_field=m, single_turn=s_turn, chat_template=ct, mask_user_turns=mt): |
612 | 620 | if s_turn: |
@@ -640,6 +648,13 @@ def dataset_for_format( |
640 | 648 | return MultiturnChatDataset(cache, Pos, max_segments_per_example=64 if pack else 1, mask_user_turns=mask_user_turns) # type: ignore |
641 | 649 | case SupervisedLmDatasetFormat(pack=pack, mask_inputs=mask_inputs): |
642 | 650 | return SupervisedDataset(cache, Pos, max_segments_per_example=64 if pack else 1, mask_inputs=mask_inputs) # type: ignore |
| 651 | + case Ul2rDatasetFormat(task_configs=task_configs, task_probs=task_probs, rng_seed=rng_seed): |
| 652 | + key = jax.random.PRNGKey(rng_seed) |
| 653 | + # TODO Get actual pad_token_id. Currently we only use this in ul2r_loss_mask. |
| 654 | + pad_token_id = 0 |
| 655 | + max_segments_per_example = 64 |
| 656 | + slice_strategy = "left" |
| 657 | + return Ul2rDataset(cache, Pos, task_configs, task_probs, key, pad_token_id, max_segments_per_example, slice_strategy) |
643 | 658 | case _: |
644 | 659 | raise ValueError(f"Unknown format {format}") |
645 | 660 |
|
|
0 commit comments