75 changes: 74 additions & 1 deletion torchtitan/train.py
@@ -238,6 +238,78 @@ def __init__(self, job_config: JobConfig):
self.loss_fn, self.gradient_accumulation_steps
)

def llama3_autoparallel_init_fn(model):
# WHC - horrible hack to make auto-parallel work. Basically, create a bespoke init_fn for llama3 by copying
# code from the llama3 init_weights functions throughout the model components, and adjusting it to use
# the new FQN structures that autoparallel exposes.
# TODO: make it possible to more easily reuse the existing 'init_weights' functions on the auto_p module
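# The auto-parallel module exposes parameters under a flat "params." prefix with "/"-separated FQNs
# (e.g. "params.layers/0/attention/wq/weight" below), so route all lookups through this small helper.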
def param(name):
return model.get_parameter(f"params.{name}")

from torchtitan.models.llama3.model.model import precompute_freqs_cis

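# Rebuild the RoPE frequency table and copy it into the auto-parallel buffer, wrapping the local
# tensor as a DTensor on the destination buffer's device mesh so the copy_ lines up with it.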
model.buffers_.get_buffer("freqs_cis").copy_(
DTensor.from_local(
precompute_freqs_cis(
model_args.dim // model_args.n_heads,
model_args.max_seq_len,
model_args.rope_theta,
),
device_mesh=model.buffers_.get_buffer("freqs_cis").device_mesh,
)
)

torch.nn.init.normal_(param("tok_embeddings/weight"))

def init_layer(i):
for norm in ("attention_norm", "ffn_norm"):
torch.nn.init.ones_(param(f"layers/{i}/{norm}/weight"))

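# Depth-scaled std, applied to the wo / w2 / w3 projections below; the other projections keep std=0.02.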
if model_args.depth_init:
weight_init_std = 0.02 / (2 * (i + 1)) ** 0.5
else:
weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5

for linear in ("wq", "wk", "wv"):
torch.nn.init.trunc_normal_(
param(f"layers/{i}/attention/{linear}/weight"),
mean=0.0,
std=0.02,
)
torch.nn.init.trunc_normal_(
param(f"layers/{i}/attention/wo/weight"),
mean=0.0,
std=weight_init_std,
)

torch.nn.init.trunc_normal_(
param(f"layers/{i}/feed_forward/w1/weight"), mean=0.0, std=0.02
)
for linear in ("w2", "w3"):
torch.nn.init.trunc_normal_(
param(f"layers/{i}/feed_forward/{linear}/weight"),
mean=0.0,
std=weight_init_std,
)

for i in range(model_args.n_layers):
init_layer(i)

if param("norm/weight") is not None:
torch.nn.init.ones_(param("norm/weight"))

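# Final output projection: std = dim**-0.5, truncated at +/- 3 std.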
final_out_std = model_args.dim**-0.5
cutoff_factor = 3

if param("output/weight") is not None:
torch.nn.init.trunc_normal_(
param("output/weight"),
mean=0.0,
std=final_out_std,
a=-cutoff_factor * final_out_std,
b=cutoff_factor * final_out_std,
)

# apply parallelisms and initialization
if parallel_dims.pp_enabled:
if not self.train_spec.pipelining_fn:
@@ -282,7 +354,8 @@ def __init__(self, job_config: JobConfig):

model.to_empty(device=init_device)
with torch.no_grad():
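# init_weights targets the original llama3 FQNs, which no longer match the auto-parallel module's
# flattened structure, so use the bespoke init defined above instead.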
model.init_weights(buffer_device=buffer_device)
# model.init_weights(buffer_device=buffer_device)
llama3_autoparallel_init_fn(model)
model.train()

self.model_parts = [model]