diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py
index 6cf648bc70..da0f80df8f 100644
--- a/tests/buffer/experience_storage_test.py
+++ b/tests/buffer/experience_storage_test.py
@@ -31,7 +31,7 @@ async def test_sql_storage(self, schema_type):
         config = ExperienceBufferConfig(
             name="test_storage",
             schema_type=schema_type,
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
             batch_size=self.train_batch_size,
@@ -91,7 +91,7 @@ async def test_sql_experience_buffer(self):
         config = ExperienceBufferConfig(
             name="test_storage",
             schema_type="experience",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
             batch_size=self.train_batch_size,
diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index d4d3fe1c04..b73ee27579 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -106,7 +106,7 @@ def setUp(self):
         dataset_config = get_unittest_dataset_config("countdown", "train")
         self.config.buffer.explorer_input.taskset = dataset_config
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
-            name="test_buffer", storage_type=StorageType.FILE
+            name="test_buffer", storage_type=StorageType.FILE.value
         )
         self.config.check_and_update()
         ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()})
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 9c8e0ca1f0..537514222f 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -34,7 +34,7 @@ async def test_queue_buffer(self, name, use_priority_queue):
         config = ExperienceBufferConfig(
             name=name,
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             path=BUFFER_FILE_PATH,
             batch_size=self.train_batch_size,
@@ -100,7 +100,7 @@ async def test_priority_queue_capacity(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=1,
             capacity=8,
             path=BUFFER_FILE_PATH,
@@ -160,7 +160,7 @@ async def test_queue_buffer_capacity(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,
             path=BUFFER_FILE_PATH,
@@ -191,7 +191,7 @@ async def test_priority_queue_buffer_reuse(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,  # max total number of items; each item is List[Experience]
             path=BUFFER_FILE_PATH,
@@ -320,7 +320,7 @@ async def test_priority_queue_reuse_count_control(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,  # max total number of items; each item is List[Experience]
             path=BUFFER_FILE_PATH,
diff --git a/tests/buffer/reader_test.py b/tests/buffer/reader_test.py
new file mode 100644
index 0000000000..c1291720b0
--- /dev/null
+++ b/tests/buffer/reader_test.py
@@ -0,0 +1,28 @@
+from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
+from trinity.buffer.buffer import get_buffer_reader
+from trinity.buffer.reader import READER
+from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
+
+
+@READER.register_module("custom")
+class CustomReader(TaskFileReader):
+    """A custom reader for testing."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+
+class TestBufferReader(RayUnittestBaseAysnc):
+    async def test_buffer_reader_registration(self) -> None:
+        config = get_unittest_dataset_config("countdown", "train")
+        config.batch_size = 2
+        config.storage_type = "custom"
+        reader = get_buffer_reader(config)
+        self.assertIsInstance(reader, CustomReader)
+        tasks = await reader.read_async()
+        self.assertEqual(len(tasks), 2)
+        config.storage_type = "file"
+        reader = get_buffer_reader(config)
+        self.assertIsInstance(reader, FileReader)
+        tasks = await reader.read_async()
+        self.assertEqual(len(tasks), 2)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index d3d7fd47ce..44b81e1495 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -4,9 +4,10 @@
 import torch
 
 from tests.tools import RayUnittestBaseAysnc
+from trinity.buffer import get_buffer_reader
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, TasksetConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
 
@@ -14,7 +15,7 @@
 
 
 class TestSQLBuffer(RayUnittestBaseAysnc):
-    async def test_sql_buffer_read_write(self) -> None:
+    async def test_sql_exp_buffer_read_write(self) -> None:
        total_num = 8
         put_batch_size = 2
         read_batch_size = 4
@@ -22,7 +23,7 @@ async def test_sql_buffer_read_write(self) -> None:
             name="test_buffer",
             schema_type="experience",
             path=f"sqlite:///{db_path}",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             batch_size=read_batch_size,
         )
         sql_writer = SQLWriter(config.to_storage_config())
@@ -62,6 +63,41 @@ async def test_sql_buffer_read_write(self) -> None:
         self.assertEqual(await sql_writer.release(), 0)
         self.assertRaises(StopIteration, sql_reader.read)
 
+    async def test_sql_task_buffer_read_write(self) -> None:
+        total_samples = 8
+        batch_size = 4
+        config = TasksetConfig(
+            name="test_task_buffer",
+            path=f"sqlite:///{db_path}",
+            storage_type=StorageType.SQL.value,
+            batch_size=batch_size,
+            default_workflow_type="math_workflow",
+        )
+        sql_writer = SQLWriter(config.to_storage_config())
+        tasks = [
+            {"question": f"question_{i}", "answer": f"answer_{i}"} for i in range(total_samples)
+        ]
+        self.assertEqual(await sql_writer.acquire(), 1)
+        sql_writer.write(tasks)
+        sql_reader = get_buffer_reader(config.to_storage_config())
+        read_tasks = []
+        try:
+            while True:
+                cur_tasks = sql_reader.read()
+                read_tasks.extend(cur_tasks)
+        except StopIteration:
+            pass
+        self.assertEqual(len(read_tasks), total_samples)
+        self.assertIn("question", read_tasks[0].raw_task)
+        self.assertIn("answer", read_tasks[0].raw_task)
+        db_wrapper = ray.get_actor("sql-test_task_buffer")
+        self.assertIsNotNone(db_wrapper)
+        self.assertEqual(await sql_writer.release(), 0)
+
     def setUp(self) -> None:
         if os.path.exists(db_path):
             os.remove(db_path)
+
+    def tearDown(self) -> None:
+        if os.path.exists(db_path):
+            os.remove(db_path)
diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py
index bff905a3e1..6b5785a48e 100644
--- a/tests/buffer/task_scheduler_test.py
+++ b/tests/buffer/task_scheduler_test.py
@@ -5,12 +5,22 @@
 
 from parameterized import parameterized
 
-from tests.tools import get_template_config
-from trinity.buffer.task_scheduler import TasksetScheduler
+from tests.tools import get_template_config, get_unittest_dataset_config
+from trinity.buffer.reader import READER
+from trinity.buffer.reader.file_reader import TaskFileReader
+from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
 from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
 from trinity.common.workflows.workflow import Task
 
 
+@READER.register_module("custom_reader")
+class CustomReader(TaskFileReader):
+    """A custom reader for testing."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+
 class TestTaskScheduler(unittest.IsolatedAsyncioTestCase):
     temp_output_path = "tmp/test_task_scheduler/"
 
@@ -313,3 +323,36 @@ async def test_task_scheduler(
 
         with self.assertRaises(StopAsyncIteration):
             batch_tasks = await task_scheduler.read_async()
+
+    async def test_task_scheduler_simple(self):
+        config = get_template_config()
+        config.mode = "explore"
+        config.buffer.batch_size = 4
+        config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown", "train")
+        config.buffer.explorer_input.taskset.storage_type = "custom_reader"
+        config.check_and_update()
+
+        task_scheduler = get_taskset_scheduler({}, config)
+
+        batch_tasks = await task_scheduler.read_async()
+        self.assertEqual(len(batch_tasks), 4)
+        task_scheduler_state = task_scheduler.state_dict()
+        self.assertEqual(len(task_scheduler_state), 1)
+        self.assertEqual(task_scheduler_state[0]["current_index"], 4)
+        # no effect
+        task_scheduler.update({"metric1": 0.5})
+
+        task_scheduler = get_taskset_scheduler(
+            {
+                "latest_iteration": 1,
+                "taskset_states": [
+                    {"current_index": 8},
+                ],
+            },
+            config,
+        )
+        batch_tasks = await task_scheduler.read_async()
+        self.assertEqual(len(batch_tasks), 4)
+        task_scheduler_state = task_scheduler.state_dict()
+        self.assertEqual(len(task_scheduler_state), 1)
+        self.assertEqual(task_scheduler_state[0]["current_index"], 12)
diff --git a/tests/buffer/task_storage_test.py b/tests/buffer/task_storage_test.py
index bcb1767ef2..a36fb7f45e 100644
--- a/tests/buffer/task_storage_test.py
+++ b/tests/buffer/task_storage_test.py
@@ -18,12 +18,12 @@ class TaskStorageTest(RayUnittestBase):
 
     @parameterized.expand(
         [
-            (StorageType.FILE, True, 2),
-            (StorageType.SQL, True, 2),
-            (StorageType.FILE, False, 0),
-            (StorageType.SQL, False, 0),
-            (StorageType.FILE, False, 2),
-            (StorageType.SQL, False, 2),
+            (StorageType.FILE.value, True, 2),
+            (StorageType.SQL.value, True, 2),
+            (StorageType.FILE.value, False, 0),
+            (StorageType.SQL.value, False, 0),
+            (StorageType.FILE.value, False, 2),
+            (StorageType.SQL.value, False, 2),
         ]
     )
     def test_read_task(self, storage_type, is_eval, offset):
@@ -37,7 +37,7 @@ def test_read_task(self, storage_type, is_eval, offset):
         config.buffer.explorer_input.taskset.is_eval = is_eval
         config.buffer.explorer_input.taskset.index = offset
         config.buffer.explorer_input.taskset.batch_size = batch_size
-        if storage_type == StorageType.SQL:
+        if storage_type == StorageType.SQL.value:
             dataset = datasets.load_dataset(
                 config.buffer.explorer_input.taskset.path, split="train"
             )
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 91222e9884..3470ed3456 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -184,7 +184,7 @@ def setUp(self):
         self.config.explorer.service_status_check_interval = 30
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="experience_buffer",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
         )
         self.config.check_and_update()
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index 306cc8e447..620f775e9a 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -233,7 +233,7 @@ def setUp(self):
             self.config.buffer.trainer_input.experience_buffer
         ) = ExperienceBufferConfig(
             name="test",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             schema_type="experience",
             path="",
         )
diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py
index b8f9a9bdb8..596bf5b839 100644
--- a/tests/manager/synchronizer_test.py
+++ b/tests/manager/synchronizer_test.py
@@ -132,7 +132,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
@@ -151,7 +151,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer1_config.check_and_update()
 
@@ -255,7 +255,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = self.sync_method
         config.synchronizer.sync_style = self.sync_style
@@ -275,7 +275,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer2_config = deepcopy(explorer1_config)
         explorer2_config.explorer.name = "explorer2"
@@ -356,7 +356,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = SyncMethod.NCCL
         config.synchronizer.sync_style = self.sync_style
diff --git a/tests/tools.py b/tests/tools.py
index 62a5b4d92e..7c8713319e 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -18,6 +18,7 @@
     CHECKPOINT_ROOT_DIR_ENV_VAR,
     MODEL_PATH_ENV_VAR,
     PromptType,
+    StorageType,
 )
 
 API_MODEL_PATH_ENV_VAR = "TRINITY_API_MODEL_PATH"
@@ -133,6 +134,7 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
             split="train",
+            storage_type=StorageType.FILE.value,
             schema_type="sft",
             format=FormatConfig(
                 prompt_type=PromptType.PLAINTEXT,
@@ -146,6 +148,7 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"),
             split="train",
+            storage_type=StorageType.FILE.value,
             format=FormatConfig(
                 prompt_type=PromptType.MESSAGES,
                 messages_key="messages",
@@ -159,6 +162,8 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "human_like"),
             split="train",
+            storage_type=StorageType.FILE.value,
+            schema_type="dpo",
             format=FormatConfig(
                 prompt_type=PromptType.PLAINTEXT,
                 prompt_key="prompt",
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 07f35b6219..ba12d6a1bf 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -305,7 +305,7 @@ def test_trainer(self, mock_load):
             experience_buffer=ExperienceBufferConfig(
                 name="test_queue_storage",
                 max_read_timeout=20,
-                storage_type=StorageType.QUEUE,
+                storage_type=StorageType.QUEUE.value,
                 max_retry_times=10,
             )
         ),
@@ -512,7 +512,7 @@ def test_fully_async_mode(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.buffer.trainer_input.experience_buffer.replay_buffer.enable = self.use_priority_queue
         config.synchronizer.sync_method = SyncMethod.CHECKPOINT
@@ -541,7 +541,7 @@ def test_fully_async_mode(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer2_config = deepcopy(explorer1_config)
         explorer2_config.trainer = deepcopy(trainer_config.trainer)
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index 46929f06be..2eab5791b2 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -4,6 +4,7 @@
 
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.buffer_writer import BufferWriter
+from trinity.buffer.reader import READER
 from trinity.common.config import ExperienceBufferConfig, StorageConfig, TasksetConfig
 from trinity.common.constants import StorageType
 
@@ -16,28 +17,10 @@ def get_buffer_reader(config: BufferStorageConfig) -> BufferReader:
         storage_config: StorageConfig = config.to_storage_config()
     else:
         storage_config = config
-    if storage_config.storage_type == StorageType.SQL:
-        from trinity.buffer.reader.sql_reader import SQLReader
-
-        return SQLReader(storage_config)
-    elif storage_config.storage_type == StorageType.QUEUE:
-        from trinity.buffer.reader.queue_reader import QueueReader
-
-        return QueueReader(storage_config)
-    elif storage_config.storage_type == StorageType.FILE:
-        from trinity.buffer.reader.file_reader import (
-            ExperienceFileReader,
-            TaskFileReader,
-        )
-
-        schema_type = storage_config.schema_type
-        if schema_type:
-            # only trainer input has schema type
-            return ExperienceFileReader(storage_config)
-        else:
-            return TaskFileReader(storage_config)
-    else:
+    reader_cls = READER.get(storage_config.storage_type)
+    if reader_cls is None:
         raise ValueError(f"{storage_config.storage_type} not supported.")
+    return reader_cls(storage_config)
 
 
 def get_buffer_writer(config: BufferStorageConfig) -> BufferWriter:
@@ -46,15 +29,15 @@ def get_buffer_writer(config: BufferStorageConfig) -> BufferWriter:
         storage_config: StorageConfig = config.to_storage_config()
     else:
         storage_config = config
-    if storage_config.storage_type == StorageType.SQL:
+    if storage_config.storage_type == StorageType.SQL.value:
         from trinity.buffer.writer.sql_writer import SQLWriter
 
         return SQLWriter(storage_config)
-    elif storage_config.storage_type == StorageType.QUEUE:
+    elif storage_config.storage_type == StorageType.QUEUE.value:
         from trinity.buffer.writer.queue_writer import QueueWriter
 
         return QueueWriter(storage_config)
-    elif storage_config.storage_type == StorageType.FILE:
+    elif storage_config.storage_type == StorageType.FILE.value:
         from trinity.buffer.writer.file_writer import JSONWriter
 
         return JSONWriter(storage_config)
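Note on the `get_buffer_reader` refactor above: dispatch now goes through the `READER` registry rather than a hard-coded if/elif chain, so new readers plug in without touching this function. A minimal sketch of the mechanism, using only the `register_module`/`get` API visible in this patch (the `"my_reader"` key and `MyReader` class are hypothetical):

    from trinity.buffer.buffer_reader import BufferReader
    from trinity.buffer.reader import READER

    @READER.register_module("my_reader")  # hypothetical storage_type key
    class MyReader(BufferReader):
        def __init__(self, config):
            self.config = config

    assert READER.get("my_reader") is MyReader  # lookup by storage_type string
    assert READER.get("unknown") is None        # get_buffer_reader then raises ValueError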
diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py
index 5315bf7ecf..d47d80ace1 100644
--- a/trinity/buffer/buffer_reader.py
+++ b/trinity/buffer/buffer_reader.py
@@ -19,7 +19,12 @@ def __len__(self) -> int:
         raise NotImplementedError
 
     def state_dict(self) -> Dict:
-        return {}
+        """Return the state of the reader as a dict.
+        Returns:
+            A dict containing the reader state. At minimum, it should contain
+            the `current_index` field.
+        """
+        raise NotImplementedError
 
     def load_state_dict(self, state_dict: Dict) -> None:
-        pass
+        raise NotImplementedError
diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py
index fc587fbb9b..e978ac9ace 100644
--- a/trinity/buffer/pipelines/experience_pipeline.py
+++ b/trinity/buffer/pipelines/experience_pipeline.py
@@ -58,7 +58,7 @@ def _init_input_storage(
         elif is_json_file(pipeline_config.input_save_path):
             return get_buffer_writer(
                 StorageConfig(
-                    storage_type=StorageType.FILE,
+                    storage_type=StorageType.FILE.value,
                     path=pipeline_config.input_save_path,
                     wrap_in_ray=False,
                 ),
@@ -66,7 +66,7 @@ def _init_input_storage(
         elif is_database_url(pipeline_config.input_save_path):
             return get_buffer_writer(
                 StorageConfig(
-                    storage_type=StorageType.SQL,
+                    storage_type=StorageType.SQL.value,
                     path=pipeline_config.input_save_path,
                     wrap_in_ray=False,
                 ),
diff --git a/trinity/buffer/reader/__init__.py b/trinity/buffer/reader/__init__.py
index e69de29bb2..49d3bbe3d2 100644
--- a/trinity/buffer/reader/__init__.py
+++ b/trinity/buffer/reader/__init__.py
@@ -0,0 +1,6 @@
+from trinity.buffer.reader.file_reader import FileReader
+from trinity.buffer.reader.queue_reader import QueueReader
+from trinity.buffer.reader.reader import READER
+from trinity.buffer.reader.sql_reader import SQLReader
+
+__all__ = ["READER", "FileReader", "QueueReader", "SQLReader"]
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index b6f39979f4..c61e25e831 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -6,6 +6,7 @@
 from datasets import Dataset, load_dataset
 
 from trinity.buffer.buffer_reader import BufferReader
+from trinity.buffer.reader.reader import READER
 from trinity.buffer.schema.formatter import FORMATTER
 from trinity.common.config import StorageConfig
 
@@ -85,9 +86,6 @@ def select_batch(self, indices: List[int]) -> List:
 
 
 class BaseFileReader(BufferReader):
-    def __len__(self):
-        return self.dataset.dataset_size
-
     async def read_async(self, batch_size: Optional[int] = None):
         try:
             return self.read(batch_size)
@@ -95,6 +93,37 @@ async def read_async(self, batch_size: Optional[int] = None):
             raise StopAsyncIteration from e
 
 
+@READER.register_module("file")
+class FileReader(BaseFileReader):
+    """Provide a unified interface for Experience and Task file readers."""
+
+    def __init__(self, config: StorageConfig):
+        if config.schema_type and config.schema_type != "task":
+            self.reader = ExperienceFileReader(config)
+        else:
+            self.reader = TaskFileReader(config)
+
+    def read(self, batch_size: Optional[int] = None) -> List:
+        return self.reader.read(batch_size)
+
+    def read_with_indices(self, indices: List[int]) -> List:
+        """Read tasks with indices."""
+        return self.reader.read_with_indices(indices)
+
+    async def read_with_indices_async(self, indices: List[int]) -> List:
+        """Read tasks with indices asynchronously."""
+        return await self.reader.read_with_indices_async(indices)
+
+    def state_dict(self):
+        return self.reader.state_dict()
+
+    def load_state_dict(self, state_dict):
+        return self.reader.load_state_dict(state_dict)
+
+    def __len__(self):
+        return self.reader.__len__()
+
+
 class ExperienceFileReader(BaseFileReader):
     """Reader for SFT / DPO file data."""
 
@@ -121,6 +150,15 @@ def read(self, batch_size: Optional[int] = None) -> List:
                 exp_list.append(experience)
         return exp_list
 
+    def state_dict(self):
+        return {"current_index": self.dataset.current_offset}
+
+    def load_state_dict(self, state_dict):
+        self.dataset.current_offset = state_dict["current_index"]
+
+    def __len__(self):
+        return self.dataset.dataset_size
+
 
 class TaskFileReader(BaseFileReader):
     """A Reader for task file data."""
@@ -164,3 +202,12 @@ def read_with_indices(self, indices: List[int]) -> List:
     async def read_with_indices_async(self, indices: List[int]) -> List:
         """Read tasks with indices asynchronously."""
         return self.read_with_indices(indices)
+
+    def state_dict(self):
+        return {"current_index": self.dataset.current_offset}
+
+    def load_state_dict(self, state_dict):
+        self.dataset.current_offset = state_dict["current_index"]
+
+    def __len__(self):
+        return self.dataset.dataset_size
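Note: `FileReader` is a thin dispatcher; the concrete reader is chosen once, at construction time, from `schema_type`, and every other method delegates to `self.reader`. A sketch of the rule as implemented above (`storage_config` is assumed to point at a file-backed taskset):

    # schema_type truthy and != "task" (e.g. "sft", "dpo")  -> ExperienceFileReader
    # schema_type None, "", or "task"                       -> TaskFileReader
    reader = FileReader(storage_config)
    assert isinstance(reader.reader, (ExperienceFileReader, TaskFileReader))
    reader.state_dict()  # {"current_index": <dataset cursor>}, from the wrapped reader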
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index e36dfc0ce9..31ab44d43a 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -1,20 +1,22 @@
 """Reader of the Queue buffer."""
 
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import ray
 
 from trinity.buffer.buffer_reader import BufferReader
+from trinity.buffer.reader.reader import READER
 from trinity.buffer.storage.queue import QueueStorage
 from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType
 
 
+@READER.register_module("queue")
 class QueueReader(BufferReader):
     """Reader of the Queue buffer."""
 
     def __init__(self, config: StorageConfig):
-        assert config.storage_type == StorageType.QUEUE
+        assert config.storage_type == StorageType.QUEUE.value
         self.timeout = config.max_read_timeout
         self.read_batch_size = config.batch_size
         self.queue = QueueStorage.get_wrapper(config)
@@ -39,3 +41,11 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
                 f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow."
             )
         return exps
+
+    def state_dict(self) -> Dict:
+        # Queue storage does not support state dict yet
+        return {"current_index": 0}
+
+    def load_state_dict(self, state_dict):
+        # Queue storage does not support state dict yet
+        return None
diff --git a/trinity/buffer/reader/reader.py b/trinity/buffer/reader/reader.py
new file mode 100644
index 0000000000..63da7b48e3
--- /dev/null
+++ b/trinity/buffer/reader/reader.py
@@ -0,0 +1,3 @@
+from trinity.utils.registry import Registry
+
+READER = Registry("reader")
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index d44d2f244f..d13feeed7f 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -1,20 +1,22 @@
 """Reader of the SQL buffer."""
 
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import ray
 
 from trinity.buffer.buffer_reader import BufferReader
+from trinity.buffer.reader.reader import READER
 from trinity.buffer.storage.sql import SQLStorage
 from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType
 
 
+@READER.register_module("sql")
 class SQLReader(BufferReader):
     """Reader of the SQL buffer."""
 
     def __init__(self, config: StorageConfig) -> None:
-        assert config.storage_type == StorageType.SQL
+        assert config.storage_type == StorageType.SQL.value
         self.wrap_in_ray = config.wrap_in_ray
         self.storage = SQLStorage.get_wrapper(config)
 
@@ -32,3 +34,11 @@ async def read_async(self, batch_size: Optional[int] = None) -> List:
             raise StopAsyncIteration
         else:
             return self.storage.read(batch_size)
+
+    def state_dict(self) -> Dict:
+        # SQL storage does not support state dict yet
+        return {"current_index": 0}
+
+    def load_state_dict(self, state_dict):
+        # SQL storage does not support state dict yet
+        return None
diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py
index 2d924d362e..523cde5b18 100644
--- a/trinity/buffer/storage/queue.py
+++ b/trinity/buffer/storage/queue.py
@@ -309,12 +309,12 @@ def __init__(self, config: StorageConfig) -> None:
             if is_database_url(st_config.path):
                 from trinity.buffer.writer.sql_writer import SQLWriter
 
-                st_config.storage_type = StorageType.SQL
+                st_config.storage_type = StorageType.SQL.value
                 self.writer = SQLWriter(st_config)
             elif is_json_file(st_config.path):
                 from trinity.buffer.writer.file_writer import JSONWriter
 
-                st_config.storage_type = StorageType.FILE
+                st_config.storage_type = StorageType.FILE.value
                 self.writer = JSONWriter(st_config)
             else:
                 self.logger.warning("Unknown supported storage path: %s", st_config.path)
@@ -322,7 +322,7 @@ def __init__(self, config: StorageConfig) -> None:
         else:
             from trinity.buffer.writer.file_writer import JSONWriter
 
-            st_config.storage_type = StorageType.FILE
+            st_config.storage_type = StorageType.FILE.value
             self.writer = JSONWriter(st_config)
         self.logger.warning(f"Save experiences in {st_config.path}.")
         self.ref_count = 0
diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py
index 01b6fa1a47..713f836e90 100644
--- a/trinity/buffer/task_scheduler.py
+++ b/trinity/buffer/task_scheduler.py
@@ -13,8 +13,77 @@
 from trinity.utils.annotations import Experimental
 
 
+def get_taskset_scheduler(explorer_state: Dict, config: Config) -> "TasksetSchedulerBase":
+    """Get a taskset scheduler according to the config.
+
+    Args:
+        explorer_state (Dict): Restoration state from checkpoint (may include progress info)
+        config (Config): Full system configuration containing buffer and taskset settings
+
+    Returns:
+        TasksetSchedulerBase: The taskset scheduler instance
+    """
+    taskset_configs = config.buffer.explorer_input.tasksets
+    if len(taskset_configs) == 1 and taskset_configs[0].task_selector.selector_type == "sequential":
+        return SimpleTasksetScheduler(explorer_state, config)
+    else:
+        return TasksetScheduler(explorer_state, config)
+
+
+class TasksetSchedulerBase:
+    def __init__(self, explorer_state: Dict, config: Config):
+        self.config = config
+        self.explorer_state = explorer_state
+
+    async def read_async(self) -> List:
+        """Asynchronously reads a batch of tasks according to the current schedule."""
+        raise NotImplementedError
+
+    def state_dict(self) -> List[Dict]:
+        """Return persistent state for checkpointing.
+
+        Returns:
+            List[Dict]: State dicts for all selectors (one per taskset)
+        """
+        raise NotImplementedError
+
+    def update(self, pipeline_metrics: Dict) -> None:
+        """Update selectors using feedback from the training pipeline."""
+        raise NotImplementedError
+
+
+class SimpleTasksetScheduler(TasksetSchedulerBase):
+    """
+    A simple taskset scheduler that only reads from one taskset without task selection strategies.
+    """
+
+    def __init__(self, explorer_state: Dict, config: Config):
+        super().__init__(explorer_state, config)
+        if "latest_task_index" in self.explorer_state:
+            self.explorer_state["taskset_states"] = [
+                {
+                    "current_index": explorer_state["latest_task_index"],
+                }
+            ]
+        index = self.explorer_state.get("taskset_states", [{"current_index": 0}])[0].get(
+            "current_index", 0
+        )
+        self.config.buffer.explorer_input.tasksets[0].index = index
+        self.reader = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
+
+    async def read_async(self) -> List:
+        return await self.reader.read_async()
+
+    def state_dict(self) -> List[Dict]:
+        return [self.reader.state_dict()]
+
+    def update(self, pipeline_metrics: Dict) -> None:
+        # do nothing here
+        return
+
+
 @Experimental
-class TasksetScheduler:
+class TasksetScheduler(TasksetSchedulerBase):
     """
     Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset.
 
@@ -38,7 +107,7 @@ def __init__(self, explorer_state: Dict, config: Config):
             explorer_state (Dict): Restoration state from checkpoint (may include progress info)
             config (Config): Full system configuration containing buffer and taskset settings
         """
-        self.config = config
+        super().__init__(explorer_state, config)
 
         # Backward compatibility: old format stored 'latest_task_index' directly
         if "latest_task_index" in explorer_state:
@@ -52,7 +121,7 @@ def __init__(self, explorer_state: Dict, config: Config):
         self.read_batch_size = config.buffer.batch_size
         taskset_configs = config.buffer.explorer_input.tasksets
 
-        from trinity.buffer.reader.file_reader import TaskFileReader
+        from trinity.buffer.reader.file_reader import FileReader
 
         taskset_states = explorer_state.get(
             "taskset_states", [{"current_index": 0}] * len(taskset_configs)
@@ -62,15 +131,15 @@ def __init__(self, explorer_state: Dict, config: Config):
         for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
             assert not taskset_config.is_eval  # assume drop last
             taskset = get_buffer_reader(taskset_config)
-            if not isinstance(taskset, TaskFileReader):
+            if not isinstance(taskset, FileReader):
                 raise TypeError(
                     f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'."
-                    f"Currently, only 'TaskFileReader' is supported by TasksetScheduler."
+                    f"Currently, only 'FileReader' is supported by TasksetScheduler."
                 )
 
             # Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
             selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
-                taskset.dataset, taskset_config.task_selector
+                taskset.reader.dataset, taskset_config.task_selector
             )
             selector.load_state_dict(taskset_state)  # Restore any prior state
 
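Note: the checkpoint state consumed by `get_taskset_scheduler` holds one entry per taskset under `"taskset_states"`; the legacy flat `"latest_task_index"` key is still accepted and converted. A resume sketch mirroring tests/buffer/task_scheduler_test.py, assuming `config` has already been prepared with a single sequential taskset and `batch_size=4`:

    scheduler = get_taskset_scheduler(
        {"latest_iteration": 1, "taskset_states": [{"current_index": 8}]},
        config,
    )
    batch = await scheduler.read_async()  # resumes reading at index 8
    assert scheduler.state_dict() == [{"current_index": 12}]  # advanced by batch_size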
""" name: str = "" - storage_type: StorageType = StorageType.FILE + storage_type: str = StorageType.FILE.value path: Optional[str] = None repeat_times: Optional[int] = None @@ -210,7 +210,7 @@ class StorageConfig: @dataclass class TasksetConfig: name: str = "" - storage_type: StorageType = StorageType.FILE + storage_type: str = StorageType.FILE.value path: Optional[str] = None default_workflow_type: Optional[str] = None @@ -276,7 +276,7 @@ class ExperienceBufferConfig: """Storage Config for trainer input experience buffer.""" name: str = "" - storage_type: StorageType = StorageType.FILE + storage_type: str = StorageType.QUEUE.value path: Optional[str] = None # used for StorageType.QUEUE @@ -923,14 +923,14 @@ def _check_trainer_input(self) -> None: if experience_buffer is None: experience_buffer = trainer_input.experience_buffer = ExperienceBufferConfig( name="experience_buffer", - storage_type=StorageType.QUEUE, + storage_type=StorageType.QUEUE.value, ) logger.info(f"Auto set `buffer.trainer_input.experience_buffer` to {experience_buffer}") - elif experience_buffer.storage_type is StorageType.FILE and self.mode == "both": + elif experience_buffer.storage_type == StorageType.FILE.value and self.mode == "both": logger.warning( "`FILE` storage is not supported to use as experience_buffer in `both` mode, use `QUEUE` instead." ) - experience_buffer.storage_type = StorageType.QUEUE + experience_buffer.storage_type = StorageType.QUEUE.value if not experience_buffer.name: experience_buffer.name = "experience_buffer" @@ -967,8 +967,8 @@ def _check_trainer_input(self) -> None: experience_buffer.total_epochs = self.buffer.total_epochs experience_buffer.total_steps = self.buffer.total_steps - def _default_storage_path(self, storage_type: StorageType, name: str) -> str: - if storage_type == StorageType.SQL: + def _default_storage_path(self, storage_type: str, name: str) -> str: + if storage_type == StorageType.SQL.value: return "sqlite:///" + os.path.join(self.buffer.cache_dir, f"{name}.db") # type: ignore[arg-type] else: return os.path.join(self.buffer.cache_dir, f"{name}.jsonl") # type: ignore[arg-type] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 6166e685fa..7fbd2eb72d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -15,7 +15,7 @@ from trinity.buffer.buffer import get_buffer_reader from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline -from trinity.buffer.task_scheduler import TasksetScheduler +from trinity.buffer.task_scheduler import get_taskset_scheduler from trinity.common.config import Config from trinity.common.constants import ( ROLLOUT_WEIGHT_SYNC_GROUP_NAME, @@ -52,7 +52,7 @@ def __init__(self, config: Config): self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( - TasksetScheduler(explorer_state, config) + get_taskset_scheduler(explorer_state=explorer_state, config=config) if self.config.mode not in {"bench", "serve"} else None ) @@ -280,8 +280,8 @@ async def eval(self): f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started." 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 6166e685fa..7fbd2eb72d 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -15,7 +15,7 @@
 
 from trinity.buffer.buffer import get_buffer_reader
 from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
-from trinity.buffer.task_scheduler import TasksetScheduler
+from trinity.buffer.task_scheduler import get_taskset_scheduler
 from trinity.common.config import Config
 from trinity.common.constants import (
     ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
@@ -52,7 +52,7 @@ def __init__(self, config: Config):
         self.models, self.auxiliary_models = create_inference_models(config)
         self.experience_pipeline = self._init_experience_pipeline()
         self.taskset = (
-            TasksetScheduler(explorer_state, config)
+            get_taskset_scheduler(explorer_state=explorer_state, config=config)
             if self.config.mode not in {"bench", "serve"}
             else None
         )
@@ -280,8 +280,8 @@ async def eval(self):
                 f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started."
             )
             eval_taskset = get_buffer_reader(eval_taskset_config)
-            eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}"
-            self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name))
+            eval_batch_id = f"{self.explore_step_num}/{eval_taskset_config.name}"
+            self.pending_eval_tasks.append((self.explore_step_num, eval_taskset_config.name))
             while True:
                 try:
                     data = await eval_taskset.read_async()
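Taken together, these changes make the reader side of the buffer pluggable end to end. A closing sketch of registering a third-party reader, mirroring tests/buffer/reader_test.py (the `"my_custom"` key is illustrative, and the snippet is assumed to run in an async context):

    from tests.tools import get_unittest_dataset_config
    from trinity.buffer.buffer import get_buffer_reader
    from trinity.buffer.reader import READER
    from trinity.buffer.reader.file_reader import TaskFileReader

    @READER.register_module("my_custom")  # illustrative key
    class MyCustomReader(TaskFileReader):
        pass

    config = get_unittest_dataset_config("countdown", "train")
    config.storage_type = "my_custom"
    reader = get_buffer_reader(config)  # READER.get("my_custom") -> MyCustomReader
    tasks = await reader.read_async()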