Description
Please check that this issue hasn't been reported before.
- I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
I am trying to run GRPO training on a tool-calls dataset.
The configuration looks like:
```yaml
base_model: /opt/ml/model/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding

chat_template: tokenizer_default

datasets:
  - path: /opt/ml/model/grpo_rows.jsonl
    type: chat_template
    field_messages: prompt
    message_field_role: role
    message_field_content: content
    field_tools: tools

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 127.0.0.1
  vllm_server_port: 8000
  vllm_server_timeout: 300
  num_generations: 4
  max_completion_length: 6000
  rollout_func: ged.runtime.rollout.tool_rollout
  reward_funcs:
    - ged.runtime.rewards.r_step_keywords
    - ged.runtime.rewards.r_mock_success
    - ged.runtime.rewards.r_json
  reward_weights: [0.4, 0.4, 0.2]

dataset_prepared_path: last_run_prepared
val_set_size: 0.03
eval_steps: 10
```
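For context, each line of `grpo_rows.jsonl` is shaped to match the field mappings above. An illustrative row (values are made up, not my real data):

```json
{
  "prompt": [
    {"role": "system", "content": "You are a tool-using assistant."},
    {"role": "user", "content": "Check the weather in Paris."}
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_weather",
        "description": "Return the current weather for a city.",
        "parameters": {
          "type": "object",
          "properties": {"city": {"type": "string"}},
          "required": ["city"]
        }
      }
    }
  ]
}
```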
When running the training, I get the following error:
```
Mapping RL Dataset (num_proc=30):   0%|          | 0/30 [00:10<?, ? examples/s]
multiprocess.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/datasets/utils/py_utils.py", line 586, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/datasets/arrow_dataset.py", line 3664, in _map_single
    for i, example in iter_outputs(shard_iterable):
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/datasets/arrow_dataset.py", line 3638, in iter_outputs
    yield i, apply_function(example, i, offset=offset)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/datasets/arrow_dataset.py", line 3561, in apply_function
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/axolotl/src/axolotl/prompt_strategies/dpo/chat_template.py", line 63, in transform_fn
    chosen_raw = sample[field_chosen]
                 ~~~~~~^^^^^^^^^^^^^^
  File "/opt/conda/envs/training-env/lib/python3.12/site-packages/datasets/formatting/formatting.py", line 283, in __getitem__
    value = self.data[key]
            ~~~~~~~~~^^^^^
KeyError: 'chosen'
```
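The same `KeyError` reproduces with a plain `datasets.map` once a DPO-style transform is applied to a prompt-only row. A minimal sketch (the transform below is a stand-in mirroring the line the traceback points at, not axolotl's actual loader code):

```python
from datasets import Dataset

# Prompt-only GRPO-style row: no "chosen"/"rejected" keys present.
ds = Dataset.from_list(
    [{"prompt": [{"role": "user", "content": "Check the weather in Paris."}]}]
)

def dpo_style_transform(sample):
    # Stand-in for transform_fn in prompt_strategies/dpo/chat_template.py,
    # which reads sample[field_chosen] unconditionally.
    chosen_raw = sample["chosen"]  # KeyError: 'chosen'
    return {"chosen": chosen_raw}

ds.map(dpo_style_transform)  # fails exactly like the traceback above
```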
Should the dataset for GRPO include the following?

```yaml
field_chosen: chosen
field_rejected: rejected  # required for DPO-style datasets
```

As I understand it, chosen/rejected fields are relevant for DPO, not GRPO. Judging by the traceback, the RL dataset mapping is being routed through the DPO chat_template strategy (`transform_fn` in `src/axolotl/prompt_strategies/dpo/chat_template.py`), which unconditionally reads `sample[field_chosen]`.
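For what it's worth, TRL's own `GRPOTrainer` trains from a prompt-only dataset. A minimal sketch with a toy model and reward function (illustrative, not my actual setup):

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# GRPO consumes a "prompt" column only; there is no chosen/rejected pair.
train_ds = Dataset.from_list(
    [{"prompt": "Which tool should be called to look up the weather in Paris?"}] * 8
)

def toy_reward(completions, **kwargs):
    # Placeholder reward: favor shorter completions.
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # small model for illustration
    reward_funcs=toy_reward,
    args=GRPOConfig(
        output_dir="grpo-out",
        num_generations=2,
        per_device_train_batch_size=2,
    ),
    train_dataset=train_ds,
)
trainer.train()
```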
Thank you
Current behaviour
An error is raised when trying to run GRPO training on a tool-calls dataset.
Steps to reproduce
- Preparing a dataset with function calls
- Creating an Axolotl configuration with GRPO training, as specified above
- Launching training (command sketch below)
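Training is launched with the standard CLI, roughly as follows (the exact wrapper on my side may differ):

```bash
axolotl train config.yaml
```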
Config yaml
See the YAML under Expected Behavior above.
Possible solution
No response
Which Operating Systems are you using?
- Linux
- macOS
- Windows
Python Version
3.12
axolotl branch-commit
Acknowledgements
- My issue title is concise, descriptive, and in title casing.
- I have searched the existing issues to make sure this bug has not been reported yet.
- I am using the latest version of axolotl.
- I have provided enough information for the maintainers to reproduce and diagnose the issue.