|
4 | 4 | import torch |
5 | 5 |
|
6 | 6 | from tests.tools import RayUnittestBaseAysnc |
| 7 | +from trinity.buffer import get_buffer_reader |
7 | 8 | from trinity.buffer.reader.sql_reader import SQLReader |
8 | 9 | from trinity.buffer.writer.sql_writer import SQLWriter |
9 | | -from trinity.common.config import ExperienceBufferConfig |
| 10 | +from trinity.common.config import ExperienceBufferConfig, TasksetConfig |
10 | 11 | from trinity.common.constants import StorageType |
11 | 12 | from trinity.common.experience import Experience |
12 | 13 |
|
13 | 14 | db_path = os.path.join(os.path.dirname(__file__), "test.db") |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class TestSQLBuffer(RayUnittestBaseAysnc): |
17 | | - async def test_sql_buffer_read_write(self) -> None: |
| 18 | + async def test_sql_exp_buffer_read_write(self) -> None: |
18 | 19 | total_num = 8 |
19 | 20 | put_batch_size = 2 |
20 | 21 | read_batch_size = 4 |
21 | 22 | config = ExperienceBufferConfig( |
22 | 23 | name="test_buffer", |
23 | 24 | schema_type="experience", |
24 | 25 | path=f"sqlite:///{db_path}", |
25 | | - storage_type=StorageType.SQL, |
| 26 | + storage_type=StorageType.SQL.value, |
26 | 27 | batch_size=read_batch_size, |
27 | 28 | ) |
28 | 29 | sql_writer = SQLWriter(config.to_storage_config()) |
@@ -62,6 +63,41 @@ async def test_sql_buffer_read_write(self) -> None: |
62 | 63 | self.assertEqual(await sql_writer.release(), 0) |
63 | 64 | self.assertRaises(StopIteration, sql_reader.read) |
64 | 65 |
|
| 66 | + async def test_sql_task_buffer_read_write(self) -> None: |
| 67 | + total_samples = 8 |
| 68 | + batch_size = 4 |
| 69 | + config = TasksetConfig( |
| 70 | + name="test_task_buffer", |
| 71 | + path=f"sqlite:///{db_path}", |
| 72 | + storage_type=StorageType.SQL.value, |
| 73 | + batch_size=batch_size, |
| 74 | + default_workflow_type="math_workflow", |
| 75 | + ) |
| 76 | + sql_writer = SQLWriter(config.to_storage_config()) |
| 77 | + tasks = [ |
| 78 | + {"question": f"question_{i}", "answer": f"answer_{i}"} for i in range(total_samples) |
| 79 | + ] |
| 80 | + self.assertEqual(await sql_writer.acquire(), 1) |
| 81 | + sql_writer.write(tasks) |
| 82 | + sql_reader = get_buffer_reader(config.to_storage_config()) |
| 83 | + read_tasks = [] |
| 84 | + try: |
| 85 | + while True: |
| 86 | + cur_tasks = sql_reader.read() |
| 87 | + read_tasks.extend(cur_tasks) |
| 88 | + except StopIteration: |
| 89 | + pass |
| 90 | + self.assertEqual(len(read_tasks), total_samples) |
| 91 | + self.assertIn("question", read_tasks[0].raw_task) |
| 92 | + self.assertIn("answer", read_tasks[0].raw_task) |
| 93 | + db_wrapper = ray.get_actor("sql-test_task_buffer") |
| 94 | + self.assertIsNotNone(db_wrapper) |
| 95 | + self.assertEqual(await sql_writer.release(), 0) |
| 96 | + |
65 | 97 | def setUp(self) -> None: |
66 | 98 | if os.path.exists(db_path): |
67 | 99 | os.remove(db_path) |
| 100 | + |
| 101 | + def tearDown(self) -> None: |
| 102 | + if os.path.exists(db_path): |
| 103 | + os.remove(db_path) |
0 commit comments