Skip to content

Commit 0f4dbfc

Browse files
committed
wip: try to make the config run with python -m levanter.main.cache_dataset --config config/llama3_ul2r.yaml
1 parent a6cfe1c commit 0f4dbfc

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

config/llama3_ul2r.yaml

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Based on llama2_7b_continued.yaml.
2+
# TODO right now this is a `RayCachedLMDatasetConfig` for `cache_dataset.py`,
3+
# not `TrainLmConfig` for `train_lm.py`; change it to the latter.
4+
5+
# data:
6+
id: dlwh/wikitext_103_detokenized
7+
tokenizer: meta-llama/Llama-3.1-8B
8+
format:
9+
type: ul2r
10+
text_key: text
11+
task_configs:
12+
r:
13+
mask_prob: 0.15
14+
mean_span_length: 3.0
15+
random_roll: true
16+
task_token: "<|reserved_special_token_3|>"
17+
x1:
18+
mask_prob: 0.15
19+
mean_span_length: 32.0
20+
random_roll: true
21+
task_token: "<|reserved_special_token_4|>"
22+
x2:
23+
mask_prob: 0.5
24+
mean_span_length: 3.0
25+
random_roll: true
26+
task_token: "<|reserved_special_token_4|>"
27+
s:
28+
task_token: "<|reserved_special_token_5|>"
29+
task_probs:
30+
r: 0.5
31+
x1: 0.125
32+
x2: 0.125
33+
s: 0.25
34+
rng_seed: 42
35+
36+
# TODO haven't tested any of the model stuff yet
37+
38+
# model:
39+
# type: llama
40+
# initialize_from_hf: true
41+
# use_hf_model_config: true
42+
# model_name_or_path: meta-llama/Llama-3.1-8B
43+
44+
# trainer:
45+
# tracker:
46+
# type: wandb
47+
# project: "levanter"
48+
# tags: ["ul2r", "llama3", "wikitext"]
49+
50+
# mp: p=f32,c=bfloat16
51+
52+
# model_axis_size: 1
53+
# per_device_eval_parallelism: 4
54+
55+
# train_batch_size: 1024
56+
# num_train_steps: 10000
57+
# steps_per_eval: 500
58+
59+
# optimizer:
60+
# learning_rate: 1.2e-4
61+
# weight_decay: 0.0

src/levanter/data/text.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
UrlDataSource,
7575
WrappedHFDataSource,
7676
)
77+
from levanter.data.ul2r import DenoisingConfig, Ul2rDataset # noqa
7778
from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa
7879
from levanter.store.cache import build_or_load_cache # noqa
7980
from levanter.utils.jax_utils import key_iterator, use_cpu_device # noqa
@@ -438,6 +439,13 @@ class SupervisedLmDatasetFormat(LmDatasetFormatBase):
438439
pack: bool = True
439440
mask_inputs: bool = True
440441

442+
@LmDatasetFormatBase.register_subclass("ul2r")
443+
@dataclass(frozen=True)
444+
class Ul2rDatasetFormat(TextLmDatasetFormat):
445+
task_configs: Dict[str, DenoisingConfig] = field(default_factory=dict)
446+
task_probs: Dict[str, float] = field(default_factory=dict)
447+
rng_seed: int = 37
448+
441449

442450
@dataclass(frozen=True)
443451
class LmDatasetSourceConfigBase(abc.ABC):
@@ -606,7 +614,7 @@ def preprocessor_for_format(
606614
format: LmDatasetFormatBase, tokenizer: HfTokenizer, *, enforce_eos: bool = True, enforce_bos: bool = True
607615
) -> BatchProcessor[dict, dict]:
608616
match format:
609-
case TextLmDatasetFormat(text_key=key):
617+
case TextLmDatasetFormat(text_key=key) | Ul2rDatasetFormat(text_key=key):
610618
return BatchTokenizer(tokenizer, enforce_bos=enforce_bos, enforce_eos=enforce_eos, text_field=key)
611619
case ChatLmDatasetFormat(messages_field=m, single_turn=s_turn, chat_template=ct, mask_user_turns=mt):
612620
if s_turn:
@@ -640,6 +648,13 @@ def dataset_for_format(
640648
return MultiturnChatDataset(cache, Pos, max_segments_per_example=64 if pack else 1, mask_user_turns=mask_user_turns) # type: ignore
641649
case SupervisedLmDatasetFormat(pack=pack, mask_inputs=mask_inputs):
642650
return SupervisedDataset(cache, Pos, max_segments_per_example=64 if pack else 1, mask_inputs=mask_inputs) # type: ignore
651+
case Ul2rDatasetFormat(task_configs=task_configs, task_probs=task_probs, rng_seed=rng_seed):
652+
key = jax.random.PRNGKey(rng_seed)
653+
# TODO Get actual pad_token_id. Currently we only use this in ul2r_loss_mask.
654+
pad_token_id = 0
655+
max_segments_per_example = 64
656+
slice_strategy = "left"
657+
return Ul2rDataset(cache, Pos, task_configs, task_probs, key, pad_token_id, max_segments_per_example, slice_strategy)
643658
case _:
644659
raise ValueError(f"Unknown format {format}")
645660

src/levanter/data/ul2r.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,8 @@ def _compute_length(task_idx: jnp.ndarray, length: jnp.ndarray) -> int:
892892
max_segments_per_example=max_segments_per_example,
893893
slice_strategy=slice_strategy,
894894
packing_lengths=out_lengths,
895+
# Reserve space for UL2R; denoising examples increase in length.
896+
pad_with_zeroes=True,
895897
)
896898
self.Pos = Pos
897899
self.pad_token_id = pad_token_id

0 commit comments

Comments
 (0)