Commit 232b49b

lora finetune convbert (#997)
1 parent 2658460 commit 232b49b

File tree

2 files changed: +141, -0 lines changed

llm/peft/train_convbert/squad_dataset.py

Lines changed: 61 additions & 0 deletions
from mindnlp.dataset import load_dataset


def get_squad_dataset(tokenizer, batch_size):
    # Process the SQuAD data: tokenize (question, context) pairs and map the
    # character-level answer span onto start/end token positions.
    def preprocess_function(id, title, context, question, answer):
        inputs = tokenizer(
            question,
            context,
            max_length=384,
            truncation="only_second",
            return_offsets_mapping=True,
            padding="max_length",
        )
        offset_mapping = inputs.pop("offset_mapping")
        start_positions = 0
        end_positions = 0

        answer_start = answer["answer_start"][0]
        answer_text = answer["text"][0]
        start_char = answer_start
        end_char = answer_start + len(answer_text)
        sequence_ids = inputs.sequence_ids(0)

        # Locate the span of context tokens (sequence id 1) in the input.
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset_mapping[context_start][0] > end_char or offset_mapping[context_end][1] < start_char:
            start_positions = 0
            end_positions = 0
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset_mapping[idx][0] <= start_char:
                idx += 1
            start_positions = idx - 1

            idx = context_end
            while idx >= context_start and offset_mapping[idx][1] >= end_char:
                idx -= 1
            end_positions = idx + 1

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"],
                inputs["start_positions"], inputs["end_positions"])

    # Load a small slice of SQuAD and apply the preprocessing column-wise.
    squad = load_dataset("squad", split="train[:5]")
    squad = squad.map(preprocess_function,
                      input_columns=['id', 'title',
                                     'context', 'question', 'answers'],
                      output_columns=['input_ids', 'token_type_ids',
                                      'attention_mask', 'start_positions', 'end_positions'],
                      num_parallel_workers=8)
    squad = squad.batch(batch_size)
    return squad
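
For orientation, a minimal usage sketch of get_squad_dataset (illustrative only, not part of the commit; it assumes mindnlp is installed and reuses the YituTech/conv-bert-base checkpoint that train.py below passes in):

# Illustrative sketch: build the 5-example pipeline defined above and inspect
# the first batch. The checkpoint name is taken from train.py's default.
from mindnlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("YituTech/conv-bert-base")
ds = get_squad_dataset(tokenizer, batch_size=4)
for input_ids, token_type_ids, attention_mask, start_positions, end_positions in ds:
    # Inputs are padded to max_length=384, so the first batch is shaped (4, 384).
    print(input_ids.shape, start_positions, end_positions)
    break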

llm/peft/train_convbert/train.py

Lines changed: 80 additions & 0 deletions
import argparse

import mindspore
from mindspore.nn import AdamWeightDecay

from squad_dataset import get_squad_dataset
from mindnlp.peft import LoraConfig, get_peft_model
from mindnlp.transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
)

mindspore.set_context(device_target="CPU")


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForQuestionAnswering.from_pretrained(
        args.model_name_or_path)

    ds = get_squad_dataset(tokenizer, args.batch_size)

    # Wrap the base ConvBERT model with LoRA adapters on the listed target modules.
    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias='none',
        task_type="QUESTION_ANSWER",
        # Strip whitespace so names such as " key" still match module names.
        target_modules=[m.strip() for m in args.lora_target_modules.split(",")],
    )
    model = get_peft_model(model=model, peft_config=peft_config)
    # model.print_trainable_parameters()

    # Only the LoRA parameters remain trainable after get_peft_model.
    optimizer = AdamWeightDecay(
        params=model.trainable_params(), learning_rate=args.lr)

    def forward_fn(input_ids, token_type_ids, attention_mask, start_positions, end_positions):
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            start_positions=start_positions,
            end_positions=end_positions
        )
        return output.loss

    grad_fn = mindspore.value_and_grad(
        forward_fn, None, optimizer.parameters, has_aux=False
    )

    # Single pass over the dataset; --num_epochs and --max_seq_len are parsed
    # below but not used here.
    total_loss, total_step = 0, 0
    for _, (input_ids, token_type_ids, attention_mask, start_positions, end_positions) in enumerate(ds):
        loss, grad = grad_fn(input_ids, token_type_ids,
                             attention_mask, start_positions, end_positions)
        optimizer(grad)
        total_loss += loss.asnumpy()
        total_step += 1
        curr_loss = total_loss / total_step
        print({"train-loss": f"{curr_loss:.2f}"})

    model.save_pretrained(save_directory=args.model_save_dir)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--model_name_or_path", default="YituTech/conv-bert-base",
                        type=str, help="Model name or path, e.g. YituTech/conv-bert-base.")
    parser.add_argument("--num_epochs", default=5, type=int)
    parser.add_argument("--lr", default=1e-4, type=float,
                        help="Set 2e-5 for full finetuning.")
    parser.add_argument("--max_seq_len", default=256, type=int)
    parser.add_argument("--lora_r", type=int, default=32)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0)
    parser.add_argument("--lora_target_modules", type=str,
                        default="query, key, value,conv_out_layer, conv_kernel_layer, dense")
    parser.add_argument("--model_save_dir", type=str,
                        default="convbert_lora_peft")
    args = parser.parse_args()
    main(args)
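
A sample invocation, using only the flags defined above (the values shown are simply the parser defaults written out explicitly):

python llm/peft/train_convbert/train.py \
    --model_name_or_path YituTech/conv-bert-base \
    --batch_size 4 \
    --lr 1e-4 \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0 \
    --model_save_dir convbert_lora_peft

Assuming mindnlp.peft follows the Hugging Face PEFT convention, model.save_pretrained writes only the LoRA adapter weights and config to --model_save_dir rather than a full copy of the ConvBERT base model.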
