Skip to content

Commit 980cffc

Browse files
authored
Make buffer reader registerable (#395)
1 parent 90088fb commit 980cffc

27 files changed

+334
-89
lines changed

tests/buffer/experience_storage_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_sql_storage(self, schema_type):
3131
config = ExperienceBufferConfig(
3232
name="test_storage",
3333
schema_type=schema_type,
34-
storage_type=StorageType.SQL,
34+
storage_type=StorageType.SQL.value,
3535
max_read_timeout=3,
3636
path=f"sqlite:///{DB_PATH}",
3737
batch_size=self.train_batch_size,
@@ -91,7 +91,7 @@ async def test_sql_experience_buffer(self):
9191
config = ExperienceBufferConfig(
9292
name="test_storage",
9393
schema_type="experience",
94-
storage_type=StorageType.SQL,
94+
storage_type=StorageType.SQL.value,
9595
max_read_timeout=3,
9696
path=f"sqlite:///{DB_PATH}",
9797
batch_size=self.train_batch_size,

tests/buffer/file_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def setUp(self):
106106
dataset_config = get_unittest_dataset_config("countdown", "train")
107107
self.config.buffer.explorer_input.taskset = dataset_config
108108
self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
109-
name="test_buffer", storage_type=StorageType.FILE
109+
name="test_buffer", storage_type=StorageType.FILE.value
110110
)
111111
self.config.check_and_update()
112112
ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()})

tests/buffer/queue_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def test_queue_buffer(self, name, use_priority_queue):
3434
config = ExperienceBufferConfig(
3535
name=name,
3636
schema_type="experience",
37-
storage_type=StorageType.QUEUE,
37+
storage_type=StorageType.QUEUE.value,
3838
max_read_timeout=3,
3939
path=BUFFER_FILE_PATH,
4040
batch_size=self.train_batch_size,
@@ -100,7 +100,7 @@ async def test_priority_queue_capacity(self):
100100
config = ExperienceBufferConfig(
101101
name="test_buffer_small",
102102
schema_type="experience",
103-
storage_type=StorageType.QUEUE,
103+
storage_type=StorageType.QUEUE.value,
104104
max_read_timeout=1,
105105
capacity=8,
106106
path=BUFFER_FILE_PATH,
@@ -160,7 +160,7 @@ async def test_queue_buffer_capacity(self):
160160
config = ExperienceBufferConfig(
161161
name="test_buffer_small",
162162
schema_type="experience",
163-
storage_type=StorageType.QUEUE,
163+
storage_type=StorageType.QUEUE.value,
164164
max_read_timeout=3,
165165
capacity=4,
166166
path=BUFFER_FILE_PATH,
@@ -191,7 +191,7 @@ async def test_priority_queue_buffer_reuse(self):
191191
config = ExperienceBufferConfig(
192192
name="test_buffer_small",
193193
schema_type="experience",
194-
storage_type=StorageType.QUEUE,
194+
storage_type=StorageType.QUEUE.value,
195195
max_read_timeout=3,
196196
capacity=4, # max total number of items; each item is List[Experience]
197197
path=BUFFER_FILE_PATH,
@@ -320,7 +320,7 @@ async def test_priority_queue_reuse_count_control(self):
320320
config = ExperienceBufferConfig(
321321
name="test_buffer_small",
322322
schema_type="experience",
323-
storage_type=StorageType.QUEUE,
323+
storage_type=StorageType.QUEUE.value,
324324
max_read_timeout=3,
325325
capacity=4, # max total number of items; each item is List[Experience]
326326
path=BUFFER_FILE_PATH,

tests/buffer/reader_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
2+
from trinity.buffer.buffer import get_buffer_reader
3+
from trinity.buffer.reader import READER
4+
from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
5+
6+
7+
@READER.register_module("custom")
8+
class CustomReader(TaskFileReader):
9+
"""A custom reader for testing."""
10+
11+
def __init__(self, config):
12+
super().__init__(config)
13+
14+
15+
class TestBufferReader(RayUnittestBaseAysnc):
16+
async def test_buffer_reader_registration(self) -> None:
17+
config = get_unittest_dataset_config("countdown", "train")
18+
config.batch_size = 2
19+
config.storage_type = "custom"
20+
reader = get_buffer_reader(config)
21+
self.assertIsInstance(reader, CustomReader)
22+
tasks = await reader.read_async()
23+
self.assertEqual(len(tasks), 2)
24+
config.storage_type = "file"
25+
reader = get_buffer_reader(config)
26+
self.assertIsInstance(reader, FileReader)
27+
tasks = await reader.read_async()
28+
self.assertEqual(len(tasks), 2)

tests/buffer/sql_test.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,26 @@
44
import torch
55

66
from tests.tools import RayUnittestBaseAysnc
7+
from trinity.buffer import get_buffer_reader
78
from trinity.buffer.reader.sql_reader import SQLReader
89
from trinity.buffer.writer.sql_writer import SQLWriter
9-
from trinity.common.config import ExperienceBufferConfig
10+
from trinity.common.config import ExperienceBufferConfig, TasksetConfig
1011
from trinity.common.constants import StorageType
1112
from trinity.common.experience import Experience
1213

1314
db_path = os.path.join(os.path.dirname(__file__), "test.db")
1415

1516

1617
class TestSQLBuffer(RayUnittestBaseAysnc):
17-
async def test_sql_buffer_read_write(self) -> None:
18+
async def test_sql_exp_buffer_read_write(self) -> None:
1819
total_num = 8
1920
put_batch_size = 2
2021
read_batch_size = 4
2122
config = ExperienceBufferConfig(
2223
name="test_buffer",
2324
schema_type="experience",
2425
path=f"sqlite:///{db_path}",
25-
storage_type=StorageType.SQL,
26+
storage_type=StorageType.SQL.value,
2627
batch_size=read_batch_size,
2728
)
2829
sql_writer = SQLWriter(config.to_storage_config())
@@ -62,6 +63,41 @@ async def test_sql_buffer_read_write(self) -> None:
6263
self.assertEqual(await sql_writer.release(), 0)
6364
self.assertRaises(StopIteration, sql_reader.read)
6465

66+
async def test_sql_task_buffer_read_write(self) -> None:
67+
total_samples = 8
68+
batch_size = 4
69+
config = TasksetConfig(
70+
name="test_task_buffer",
71+
path=f"sqlite:///{db_path}",
72+
storage_type=StorageType.SQL.value,
73+
batch_size=batch_size,
74+
default_workflow_type="math_workflow",
75+
)
76+
sql_writer = SQLWriter(config.to_storage_config())
77+
tasks = [
78+
{"question": f"question_{i}", "answer": f"answer_{i}"} for i in range(total_samples)
79+
]
80+
self.assertEqual(await sql_writer.acquire(), 1)
81+
sql_writer.write(tasks)
82+
sql_reader = get_buffer_reader(config.to_storage_config())
83+
read_tasks = []
84+
try:
85+
while True:
86+
cur_tasks = sql_reader.read()
87+
read_tasks.extend(cur_tasks)
88+
except StopIteration:
89+
pass
90+
self.assertEqual(len(read_tasks), total_samples)
91+
self.assertIn("question", read_tasks[0].raw_task)
92+
self.assertIn("answer", read_tasks[0].raw_task)
93+
db_wrapper = ray.get_actor("sql-test_task_buffer")
94+
self.assertIsNotNone(db_wrapper)
95+
self.assertEqual(await sql_writer.release(), 0)
96+
6597
def setUp(self) -> None:
6698
if os.path.exists(db_path):
6799
os.remove(db_path)
100+
101+
def tearDown(self) -> None:
102+
if os.path.exists(db_path):
103+
os.remove(db_path)

tests/buffer/task_scheduler_test.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@
55

66
from parameterized import parameterized
77

8-
from tests.tools import get_template_config
9-
from trinity.buffer.task_scheduler import TasksetScheduler
8+
from tests.tools import get_template_config, get_unittest_dataset_config
9+
from trinity.buffer.reader import READER
10+
from trinity.buffer.reader.file_reader import TaskFileReader
11+
from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
1012
from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
1113
from trinity.common.workflows.workflow import Task
1214

1315

16+
@READER.register_module("custom_reader")
17+
class CustomReader(TaskFileReader):
18+
"""A custom reader for testing."""
19+
20+
def __init__(self, config):
21+
super().__init__(config)
22+
23+
1424
class TestTaskScheduler(unittest.IsolatedAsyncioTestCase):
1525
temp_output_path = "tmp/test_task_scheduler/"
1626

@@ -313,3 +323,36 @@ async def test_task_scheduler(
313323

314324
with self.assertRaises(StopAsyncIteration):
315325
batch_tasks = await task_scheduler.read_async()
326+
327+
async def test_task_scheduler_simple(self):
328+
config = get_template_config()
329+
config.mode = "explore"
330+
config.buffer.batch_size = 4
331+
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown", "train")
332+
config.buffer.explorer_input.taskset.storage_type = "custom_reader"
333+
config.check_and_update()
334+
335+
task_scheduler = get_taskset_scheduler({}, config)
336+
337+
batch_tasks = await task_scheduler.read_async()
338+
self.assertEqual(len(batch_tasks), 4)
339+
task_scheduler_state = task_scheduler.state_dict()
340+
self.assertEqual(len(task_scheduler_state), 1)
341+
self.assertEqual(task_scheduler_state[0]["current_index"], 4)
342+
# no effect
343+
task_scheduler.update({"metric1": 0.5})
344+
345+
task_scheduler = get_taskset_scheduler(
346+
{
347+
"latest_iteration": 1,
348+
"taskset_states": [
349+
{"current_index": 8},
350+
],
351+
},
352+
config,
353+
)
354+
batch_tasks = await task_scheduler.read_async()
355+
self.assertEqual(len(batch_tasks), 4)
356+
task_scheduler_state = task_scheduler.state_dict()
357+
self.assertEqual(len(task_scheduler_state), 1)
358+
self.assertEqual(task_scheduler_state[0]["current_index"], 12)

tests/buffer/task_storage_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
class TaskStorageTest(RayUnittestBase):
1919
@parameterized.expand(
2020
[
21-
(StorageType.FILE, True, 2),
22-
(StorageType.SQL, True, 2),
23-
(StorageType.FILE, False, 0),
24-
(StorageType.SQL, False, 0),
25-
(StorageType.FILE, False, 2),
26-
(StorageType.SQL, False, 2),
21+
(StorageType.FILE.value, True, 2),
22+
(StorageType.SQL.value, True, 2),
23+
(StorageType.FILE.value, False, 0),
24+
(StorageType.SQL.value, False, 0),
25+
(StorageType.FILE.value, False, 2),
26+
(StorageType.SQL.value, False, 2),
2727
]
2828
)
2929
def test_read_task(self, storage_type, is_eval, offset):
@@ -37,7 +37,7 @@ def test_read_task(self, storage_type, is_eval, offset):
3737
config.buffer.explorer_input.taskset.is_eval = is_eval
3838
config.buffer.explorer_input.taskset.index = offset
3939
config.buffer.explorer_input.taskset.batch_size = batch_size
40-
if storage_type == StorageType.SQL:
40+
if storage_type == StorageType.SQL.value:
4141
dataset = datasets.load_dataset(
4242
config.buffer.explorer_input.taskset.path, split="train"
4343
)

tests/explorer/explorer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def setUp(self):
184184
self.config.explorer.service_status_check_interval = 30
185185
self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
186186
name="experience_buffer",
187-
storage_type=StorageType.SQL,
187+
storage_type=StorageType.SQL.value,
188188
)
189189
self.config.check_and_update()
190190
if multiprocessing.get_start_method(allow_none=True) != "spawn":

tests/explorer/scheduler_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def setUp(self):
253253
self.config.buffer.trainer_input.experience_buffer
254254
) = ExperienceBufferConfig(
255255
name="test",
256-
storage_type=StorageType.QUEUE,
256+
storage_type=StorageType.QUEUE.value,
257257
schema_type="experience",
258258
path="",
259259
)

tests/manager/synchronizer_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_synchronizer(self):
132132
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
133133
config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
134134
name="exp_buffer",
135-
storage_type=StorageType.QUEUE,
135+
storage_type=StorageType.QUEUE.value,
136136
)
137137
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
138138
config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
@@ -151,7 +151,7 @@ def test_synchronizer(self):
151151
explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
152152
explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
153153
name="exp_buffer",
154-
storage_type=StorageType.QUEUE,
154+
storage_type=StorageType.QUEUE.value,
155155
)
156156
explorer1_config.check_and_update()
157157

@@ -255,7 +255,7 @@ def test_synchronizer(self):
255255
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
256256
config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
257257
name="exp_buffer",
258-
storage_type=StorageType.QUEUE,
258+
storage_type=StorageType.QUEUE.value,
259259
)
260260
config.synchronizer.sync_method = self.sync_method
261261
config.synchronizer.sync_style = self.sync_style
@@ -275,7 +275,7 @@ def test_synchronizer(self):
275275
explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
276276
explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
277277
name="exp_buffer",
278-
storage_type=StorageType.QUEUE,
278+
storage_type=StorageType.QUEUE.value,
279279
)
280280
explorer2_config = deepcopy(explorer1_config)
281281
explorer2_config.explorer.name = "explorer2"
@@ -356,7 +356,7 @@ def test_synchronizer(self):
356356
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
357357
config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
358358
name="exp_buffer",
359-
storage_type=StorageType.QUEUE,
359+
storage_type=StorageType.QUEUE.value,
360360
)
361361
config.synchronizer.sync_method = SyncMethod.NCCL
362362
config.synchronizer.sync_style = self.sync_style

0 commit comments

Comments
 (0)