Changes from all commits

49 commits
61f7b30
WIP
jyc Sep 12, 2025
3904d53
provide missing KPos argument
jyc Oct 20, 2025
4fb7ed2
wip config
jyc Oct 20, 2025
1d044d9
consistent w/ text
jyc Oct 21, 2025
5d7e141
fix creating 0-length segments (perm < threshold, not perm_masked < t…
jyc Oct 23, 2025
7f51693
fix random roll; was creating all false noise masks
jyc Oct 24, 2025
f70b68d
disable random roll for now; not in UL2R or UL2 papers
jyc Oct 28, 2025
f495121
fix inputs_len not accounting for task token
jyc Oct 28, 2025
6f61075
set segment_ids to -1 for empty positions, not -1
jyc Oct 29, 2025
ddfe6b8
random_spans_noise_mask w/o random_roll is deterministic for small in…
jyc Oct 29, 2025
2260a39
make random_roll work by truncating & handling all non-noise case
jyc Oct 30, 2025
6c8fa0b
add disable_jit option for debugging
jyc Oct 30, 2025
cc4a385
put task token / sentinel IDs in config; using out-of-bounds token ID…
jyc Oct 31, 2025
1f97c6a
correctly handle noise_masks that start with 0 for the target; should…
jyc Oct 31, 2025
40bf05a
fix task_kind
jyc Oct 31, 2025
afb8837
don't conflict w/ gpt2 bos/eos
jyc Oct 31, 2025
346234e
back to lightning... add tokens to gpt2 tokenizer
jyc Nov 3, 2025
add044f
use new tokens
jyc Nov 4, 2025
4d38c8a
keep trying to debug NaN loss...
jyc Nov 4, 2025
ef9bd1b
fix test, force_initial_sentinel
jyc Nov 4, 2025
6b973f5
more test fixes
jyc Nov 4, 2025
3967cbb
:'( where are the nans coming from
jyc Nov 5, 2025
564d56b
Ugh
jyc Nov 5, 2025
83cb5b9
vocab size
jyc Nov 5, 2025
78b9ddd
wat
jyc Nov 5, 2025
a77fec6
:(
jyc Nov 5, 2025
a2819fb
???
jyc Nov 5, 2025
0f7c870
oops
jyc Nov 5, 2025
6b1366a
wat
jyc Nov 5, 2025
5183479
hmm is it because we stopping too early? but not sure where 50399 com…
jyc Nov 5, 2025
45c2c6e
length is upper bound, not count
jyc Nov 13, 2025
f92e701
reserve space for random_roll creating extra noise span
jyc Nov 13, 2025
ceb7656
store random_roll in task params & actually use
jyc Nov 13, 2025
8d985ad
ok can actually do run now
jyc Nov 13, 2025
d71d5d9
disable prints because no more nans
jyc Nov 13, 2025
4b76785
disable moar prints
jyc Nov 13, 2025
c2eed36
moar
jyc Nov 13, 2025
5ff6e3c
increase num train steps
jyc Nov 13, 2025
84fb835
try do larger training run
jyc Nov 15, 2025
09b7035
doesn't like no options
jyc Nov 15, 2025
b908d31
trying to get inference to work
jyc Nov 15, 2025
012cc17
oops
jyc Nov 15, 2025
78c4c6a
llama
jyc Nov 15, 2025
ad73363
oops
jyc Nov 15, 2025
233f0ab
agh
jyc Nov 15, 2025
7ca8748
pile database has loading script?
jyc Nov 15, 2025
7537706
got my own gpu 8-)
jyc Nov 15, 2025
ead9c35
relative path cache bug
jyc Nov 15, 2025
79a5fa2
paths
jyc Nov 15, 2025
70 changes: 70 additions & 0 deletions config/gpt2_nano_ul2r.yaml
@@ -0,0 +1,70 @@
#data:
#  id: dlwh/wikitext_103_detokenized

#disable_jit: true
#data_only: true

model:
  type: gpt2
#  hidden_dim: 32
#  num_heads: 4
#  num_layers: 2

initialize_from_hf: openai-community/gpt2
use_hf_model_config: true

trainer:
#  checkpointer:
#    base_path: "checkpoints"
#    keep:
#      - every: 50
#    save_interval: 5m

#  per_device_parallelism: -1
  batch_axis: "batch"
  fsdp_axis: "embed"
  mp: f32
  num_train_steps: 100000
  tensor_parallel_axes: ["mlp", "heads"]
  train_batch_size: 32 # what's good?

data:
  id: dlwh/wikitext_103_detokenized
#  id: openai/gsm8k # This is super tiny. It's just for smoke tests.
#  name: socratic
  cache_dir: /tmp/marin_cache
  tokenizer: ./gpt2_ul2r_DEV
  format:
    type: ul2r
    #text_key: text # only for gsm8k
    #text_key: question
    task_configs:
      r:
        type: rx
        mask_prob: 0.15
        mean_span_length: 3.0
        random_roll: true
        task_token_id: 50257
      x1:
        type: rx
        mask_prob: 0.15
        mean_span_length: 32.0
        random_roll: true
        task_token_id: 50258
      x2:
        type: rx
        mask_prob: 0.5
        mean_span_length: 3.0
        random_roll: true
        task_token_id: 50258
      s:
        type: s
        task_token_id: 50259
    task_probs:
      r: 0.5
      x1: 0.125
      x2: 0.125
      s: 0.25
    rng_seed: 42
    sentinel_token_id_start: 50260
    sentinel_token_id_count: 100
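
A minimal sketch (not the PR's implementation) of what the `ul2r` format options above imply for one example: a denoiser is sampled according to `task_probs`, the chosen task's `task_token_id` prefixes the example, and each masked span is replaced by a sentinel drawn from `sentinel_token_id_start`. `apply_rx_denoiser` and the Bernoulli noise mask below are hypothetical simplifications; the real pipeline uses a T5-style `random_spans_noise_mask` (see the commit messages above).

import numpy as np

TASK_TOKEN = {"r": 50257, "x1": 50258, "x2": 50258, "s": 50259}  # task_token_id values above
SENTINEL_START, SENTINEL_COUNT = 50260, 100                      # sentinel_token_id_* values above


def apply_rx_denoiser(tokens, noise_mask, task_token):
    """Build (task token + corrupted inputs, sentinel-delimited targets), T5-style:
    every contiguous run of noise tokens is replaced in the inputs by one sentinel,
    and the targets list that sentinel followed by the original tokens it hid."""
    inputs, targets = [task_token], []
    sentinel, prev_noise = SENTINEL_START, False
    for tok, is_noise in zip(tokens.tolist(), noise_mask.tolist()):
        if is_noise:
            if not prev_noise:  # first token of a new noise span -> new sentinel
                assert sentinel < SENTINEL_START + SENTINEL_COUNT
                inputs.append(sentinel)
                targets.append(sentinel)
                sentinel += 1
            targets.append(tok)
        else:
            inputs.append(tok)
        prev_noise = is_noise
    return np.array(inputs), np.array(targets)


rng = np.random.default_rng(42)  # rng_seed above
task = rng.choice(["r", "x1", "x2", "s"], p=[0.5, 0.125, 0.125, 0.25])  # task_probs above
tokens = np.arange(1000, 1100)               # stand-in for 100 real token IDs
noise_mask = rng.random(tokens.size) < 0.15  # stand-in for random_spans_noise_mask(mask_prob=0.15)
inputs, targets = apply_rx_denoiser(tokens, noise_mask, TASK_TOKEN["r"])  # "s" (prefix LM) is handled differently
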
70 changes: 70 additions & 0 deletions config/llama2_7b_ul2r.yaml
@@ -0,0 +1,70 @@
data:
  tokenizer: ./llama_2_7b_hf_ul2r/
  configs:
    pile:
      id: dlwh/wikitext_103_detokenized
      cache_dir: /tmp/marin_cache_pile
    ul2r:
      id: dlwh/wikitext_103_detokenized
      cache_dir: /tmp/marin_cache_ul2r
      format:
        type: ul2r
        task_configs:
          r:
            type: rx
            mask_prob: 0.15
            mean_span_length: 3.0
            random_roll: true
            task_token_id: 32000
          x1:
            type: rx
            mask_prob: 0.15
            mean_span_length: 32.0
            random_roll: true
            task_token_id: 32001
          x2:
            type: rx
            mask_prob: 0.5
            mean_span_length: 3.0
            random_roll: true
            task_token_id: 32001
          s:
            type: s
            task_token_id: 32002
        task_probs:
          r: 0.5
          x1: 0.125
          x2: 0.125
          s: 0.25
        rng_seed: 42
        sentinel_token_id_start: 32003
        sentinel_token_id_count: 100
  train_weights:
    pile: 0 # only use Pile for evaluation; UL2R says loss should decrease
    ul2r: 1

model:
  type: llama
initialize_from_hf: true
use_hf_model_config: true

# From config/llama2_7b_continued.yaml

trainer:
  tracker:
    type: wandb
    project: "levanter"
    tags: ["ul2r", "llama2"]

  mp: p=f32,c=bfloat16

  model_axis_size: 1
  per_device_eval_parallelism: 4

  train_batch_size: 1024
  num_train_steps: 10000
  steps_per_eval: 500

optimizer:
  learning_rate: 1.2e-4
  weight_decay: 0.0
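
The `./llama_2_7b_hf_ul2r/` tokenizer directory referenced above is not part of this diff. One way it could have been produced (an assumption, not necessarily how the author built it) is by appending the task and sentinel tokens to the stock Llama-2 tokenizer: its base vocabulary ends at 31999, so sequentially added tokens land on the IDs used in this config (task tokens 32000-32002, sentinels from 32003).

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# add_tokens assigns IDs sequentially from the current vocab size, so this ordering
# is chosen to match task_token_id r=32000, x1/x2=32001, s=32002 in the config above.
tok.add_tokens(["<|r|>", "<|x|>", "<|s|>"] + [f"<|mask_{i}|>" for i in range(100)])
tok.save_pretrained("./llama_2_7b_hf_ul2r/")
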
105 changes: 105 additions & 0 deletions gpt2_ul2r_DEV/added_tokens.json
@@ -0,0 +1,105 @@
{
"<|mask_0|>": 50260,
"<|mask_10|>": 50270,
"<|mask_11|>": 50271,
"<|mask_12|>": 50272,
"<|mask_13|>": 50273,
"<|mask_14|>": 50274,
"<|mask_15|>": 50275,
"<|mask_16|>": 50276,
"<|mask_17|>": 50277,
"<|mask_18|>": 50278,
"<|mask_19|>": 50279,
"<|mask_1|>": 50261,
"<|mask_20|>": 50280,
"<|mask_21|>": 50281,
"<|mask_22|>": 50282,
"<|mask_23|>": 50283,
"<|mask_24|>": 50284,
"<|mask_25|>": 50285,
"<|mask_26|>": 50286,
"<|mask_27|>": 50287,
"<|mask_28|>": 50288,
"<|mask_29|>": 50289,
"<|mask_2|>": 50262,
"<|mask_30|>": 50290,
"<|mask_31|>": 50291,
"<|mask_32|>": 50292,
"<|mask_33|>": 50293,
"<|mask_34|>": 50294,
"<|mask_35|>": 50295,
"<|mask_36|>": 50296,
"<|mask_37|>": 50297,
"<|mask_38|>": 50298,
"<|mask_39|>": 50299,
"<|mask_3|>": 50263,
"<|mask_40|>": 50300,
"<|mask_41|>": 50301,
"<|mask_42|>": 50302,
"<|mask_43|>": 50303,
"<|mask_44|>": 50304,
"<|mask_45|>": 50305,
"<|mask_46|>": 50306,
"<|mask_47|>": 50307,
"<|mask_48|>": 50308,
"<|mask_49|>": 50309,
"<|mask_4|>": 50264,
"<|mask_50|>": 50310,
"<|mask_51|>": 50311,
"<|mask_52|>": 50312,
"<|mask_53|>": 50313,
"<|mask_54|>": 50314,
"<|mask_55|>": 50315,
"<|mask_56|>": 50316,
"<|mask_57|>": 50317,
"<|mask_58|>": 50318,
"<|mask_59|>": 50319,
"<|mask_5|>": 50265,
"<|mask_60|>": 50320,
"<|mask_61|>": 50321,
"<|mask_62|>": 50322,
"<|mask_63|>": 50323,
"<|mask_64|>": 50324,
"<|mask_65|>": 50325,
"<|mask_66|>": 50326,
"<|mask_67|>": 50327,
"<|mask_68|>": 50328,
"<|mask_69|>": 50329,
"<|mask_6|>": 50266,
"<|mask_70|>": 50330,
"<|mask_71|>": 50331,
"<|mask_72|>": 50332,
"<|mask_73|>": 50333,
"<|mask_74|>": 50334,
"<|mask_75|>": 50335,
"<|mask_76|>": 50336,
"<|mask_77|>": 50337,
"<|mask_78|>": 50338,
"<|mask_79|>": 50339,
"<|mask_7|>": 50267,
"<|mask_80|>": 50340,
"<|mask_81|>": 50341,
"<|mask_82|>": 50342,
"<|mask_83|>": 50343,
"<|mask_84|>": 50344,
"<|mask_85|>": 50345,
"<|mask_86|>": 50346,
"<|mask_87|>": 50347,
"<|mask_88|>": 50348,
"<|mask_89|>": 50349,
"<|mask_8|>": 50268,
"<|mask_90|>": 50350,
"<|mask_91|>": 50351,
"<|mask_92|>": 50352,
"<|mask_93|>": 50353,
"<|mask_94|>": 50354,
"<|mask_95|>": 50355,
"<|mask_96|>": 50356,
"<|mask_97|>": 50357,
"<|mask_98|>": 50358,
"<|mask_99|>": 50359,
"<|mask_9|>": 50269,
"<|r|>": 50257,
"<|s|>": 50258,
"<|x|>": 50259
}
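
These added tokens extend GPT-2's 50,257-token base vocabulary, so the IDs line up with config/gpt2_nano_ul2r.yaml: the three task tokens take 50257-50259 and the 100 sentinels take 50260-50359. A quick, illustrative way to sanity-check the checked-in ./gpt2_ul2r_DEV tokenizer against those config values (not part of this PR):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./gpt2_ul2r_DEV")
assert tok.convert_tokens_to_ids("<|r|>") == 50257       # r task_token_id
assert tok.convert_tokens_to_ids("<|mask_0|>") == 50260  # sentinel_token_id_start
assert tok.convert_tokens_to_ids("<|mask_99|>") == 50359 # sentinel_token_id_count = 100
assert len(tok) == 50257 + 3 + 100                       # 50360 tokens total
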