4 changes: 2 additions & 2 deletions tests/buffer/experience_storage_test.py
@@ -31,7 +31,7 @@ async def test_sql_storage(self, schema_type):
         config = ExperienceBufferConfig(
             name="test_storage",
             schema_type=schema_type,
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
             batch_size=self.train_batch_size,
@@ -91,7 +91,7 @@ async def test_sql_experience_buffer(self):
         config = ExperienceBufferConfig(
             name="test_storage",
             schema_type="experience",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
             batch_size=self.train_batch_size,
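For context on the recurring `StorageType.X` → `StorageType.X.value` change: `StorageType` is a Python `Enum`, so `StorageType.SQL` is the enum member itself, while `StorageType.SQL.value` is the plain value it wraps, which is what the string-typed `storage_type` fields now expect. A minimal sketch of the distinction, assuming string-valued members (the real definition lives in trinity.common.constants and may differ):

from enum import Enum


class StorageType(Enum):
    # Hypothetical values; the actual members/values are defined in trinity.common.constants.
    FILE = "file"
    SQL = "sql"
    QUEUE = "queue"


print(StorageType.SQL)        # StorageType.SQL  (the Enum member)
print(StorageType.SQL.value)  # sql              (the plain string now passed as storage_type)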
2 changes: 1 addition & 1 deletion tests/buffer/file_test.py
@@ -106,7 +106,7 @@ def setUp(self):
         dataset_config = get_unittest_dataset_config("countdown", "train")
         self.config.buffer.explorer_input.taskset = dataset_config
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
-            name="test_buffer", storage_type=StorageType.FILE
+            name="test_buffer", storage_type=StorageType.FILE.value
         )
         self.config.check_and_update()
         ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()})
10 changes: 5 additions & 5 deletions tests/buffer/queue_test.py
@@ -34,7 +34,7 @@ async def test_queue_buffer(self, name, use_priority_queue):
         config = ExperienceBufferConfig(
             name=name,
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             path=BUFFER_FILE_PATH,
             batch_size=self.train_batch_size,
@@ -100,7 +100,7 @@ async def test_priority_queue_capacity(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=1,
             capacity=8,
             path=BUFFER_FILE_PATH,
@@ -160,7 +160,7 @@ async def test_queue_buffer_capacity(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,
             path=BUFFER_FILE_PATH,
@@ -191,7 +191,7 @@ async def test_priority_queue_buffer_reuse(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,  # max total number of items; each item is List[Experience]
             path=BUFFER_FILE_PATH,
@@ -320,7 +320,7 @@ async def test_priority_queue_reuse_count_control(self):
         config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             max_read_timeout=3,
             capacity=4,  # max total number of items; each item is List[Experience]
             path=BUFFER_FILE_PATH,
28 changes: 28 additions & 0 deletions tests/buffer/reader_test.py
@@ -0,0 +1,28 @@
+from tests.tools import RayUnittestBaseAysnc, get_unittest_dataset_config
+from trinity.buffer.buffer import get_buffer_reader
+from trinity.buffer.reader import READER
+from trinity.buffer.reader.file_reader import FileReader, TaskFileReader
+
+
+@READER.register_module("custom")
+class CustomReader(TaskFileReader):
+    """A custom reader for testing."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+
+class TestBufferReader(RayUnittestBaseAysnc):
+    async def test_buffer_reader_registration(self) -> None:
+        config = get_unittest_dataset_config("countdown", "train")
+        config.batch_size = 2
+        config.storage_type = "custom"
+        reader = get_buffer_reader(config)
+        self.assertIsInstance(reader, CustomReader)
+        tasks = await reader.read_async()
+        self.assertEqual(len(tasks), 2)
+        config.storage_type = "file"
+        reader = get_buffer_reader(config)
+        self.assertIsInstance(reader, FileReader)
+        tasks = await reader.read_async()
+        self.assertEqual(len(tasks), 2)
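The new test above exercises the `READER` registry: `@READER.register_module("custom")` makes a reader class resolvable by the `storage_type` string, and `get_buffer_reader` instantiates whichever class is registered under that name. A minimal sketch of such a decorator-based registry, assuming a simple dict-backed implementation (not the actual trinity.buffer.reader code):

class Registry:
    """Dict-backed registry; classes are registered under a string key via a decorator."""

    def __init__(self):
        self._modules = {}

    def register_module(self, name):
        def decorator(cls):
            self._modules[name] = cls
            return cls
        return decorator

    def get(self, name):
        return self._modules[name]


READER = Registry()


@READER.register_module("custom")
class CustomReader:
    pass


# A get_buffer_reader-style lookup would resolve the class from config.storage_type:
assert READER.get("custom") is CustomReader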
42 changes: 39 additions & 3 deletions tests/buffer/sql_test.py
@@ -4,25 +4,26 @@
 import torch

 from tests.tools import RayUnittestBaseAysnc
+from trinity.buffer import get_buffer_reader
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, TasksetConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience

 db_path = os.path.join(os.path.dirname(__file__), "test.db")


 class TestSQLBuffer(RayUnittestBaseAysnc):
-    async def test_sql_buffer_read_write(self) -> None:
+    async def test_sql_exp_buffer_read_write(self) -> None:
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
         config = ExperienceBufferConfig(
             name="test_buffer",
             schema_type="experience",
             path=f"sqlite:///{db_path}",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
             batch_size=read_batch_size,
         )
         sql_writer = SQLWriter(config.to_storage_config())
@@ -62,6 +63,41 @@ async def test_sql_buffer_read_write(self) -> None:
         self.assertEqual(await sql_writer.release(), 0)
         self.assertRaises(StopIteration, sql_reader.read)

+    async def test_sql_task_buffer_read_write(self) -> None:
+        total_samples = 8
+        batch_size = 4
+        config = TasksetConfig(
+            name="test_task_buffer",
+            path=f"sqlite:///{db_path}",
+            storage_type=StorageType.SQL.value,
+            batch_size=batch_size,
+            default_workflow_type="math_workflow",
+        )
+        sql_writer = SQLWriter(config.to_storage_config())
+        tasks = [
+            {"question": f"question_{i}", "answer": f"answer_{i}"} for i in range(total_samples)
+        ]
+        self.assertEqual(await sql_writer.acquire(), 1)
+        sql_writer.write(tasks)
+        sql_reader = get_buffer_reader(config.to_storage_config())
+        read_tasks = []
+        try:
+            while True:
+                cur_tasks = sql_reader.read()
+                read_tasks.extend(cur_tasks)
+        except StopIteration:
+            pass
+        self.assertEqual(len(read_tasks), total_samples)
+        self.assertIn("question", read_tasks[0].raw_task)
+        self.assertIn("answer", read_tasks[0].raw_task)
+        db_wrapper = ray.get_actor("sql-test_task_buffer")
+        self.assertIsNotNone(db_wrapper)
+        self.assertEqual(await sql_writer.release(), 0)
+
+    def setUp(self) -> None:
+        if os.path.exists(db_path):
+            os.remove(db_path)
+
     def tearDown(self) -> None:
         if os.path.exists(db_path):
             os.remove(db_path)
47 changes: 45 additions & 2 deletions tests/buffer/task_scheduler_test.py
@@ -5,12 +5,22 @@

 from parameterized import parameterized

-from tests.tools import get_template_config
-from trinity.buffer.task_scheduler import TasksetScheduler
+from tests.tools import get_template_config, get_unittest_dataset_config
+from trinity.buffer.reader import READER
+from trinity.buffer.reader.file_reader import TaskFileReader
+from trinity.buffer.task_scheduler import TasksetScheduler, get_taskset_scheduler
 from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig
 from trinity.common.workflows.workflow import Task


+@READER.register_module("custom_reader")
+class CustomReader(TaskFileReader):
+    """A custom reader for testing."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+
 class TestTaskScheduler(unittest.IsolatedAsyncioTestCase):
     temp_output_path = "tmp/test_task_scheduler/"

@@ -313,3 +323,36 @@ async def test_task_scheduler(

         with self.assertRaises(StopAsyncIteration):
             batch_tasks = await task_scheduler.read_async()
+
+    async def test_task_scheduler_simple(self):
+        config = get_template_config()
+        config.mode = "explore"
+        config.buffer.batch_size = 4
+        config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown", "train")
+        config.buffer.explorer_input.taskset.storage_type = "custom_reader"
+        config.check_and_update()
+
+        task_scheduler = get_taskset_scheduler({}, config)
+
+        batch_tasks = await task_scheduler.read_async()
+        self.assertEqual(len(batch_tasks), 4)
+        task_scheduler_state = task_scheduler.state_dict()
+        self.assertEqual(len(task_scheduler_state), 1)
+        self.assertEqual(task_scheduler_state[0]["current_index"], 4)
+        # no effect
+        task_scheduler.update({"metric1": 0.5})
+
+        task_scheduler = get_taskset_scheduler(
+            {
+                "latest_iteration": 1,
+                "taskset_states": [
+                    {"current_index": 8},
+                ],
+            },
+            config,
+        )
+        batch_tasks = await task_scheduler.read_async()
+        self.assertEqual(len(batch_tasks), 4)
+        task_scheduler_state = task_scheduler.state_dict()
+        self.assertEqual(len(task_scheduler_state), 1)
+        self.assertEqual(task_scheduler_state[0]["current_index"], 12)
14 changes: 7 additions & 7 deletions tests/buffer/task_storage_test.py
@@ -18,12 +18,12 @@
 class TaskStorageTest(RayUnittestBase):
     @parameterized.expand(
         [
-            (StorageType.FILE, True, 2),
-            (StorageType.SQL, True, 2),
-            (StorageType.FILE, False, 0),
-            (StorageType.SQL, False, 0),
-            (StorageType.FILE, False, 2),
-            (StorageType.SQL, False, 2),
+            (StorageType.FILE.value, True, 2),
+            (StorageType.SQL.value, True, 2),
+            (StorageType.FILE.value, False, 0),
+            (StorageType.SQL.value, False, 0),
+            (StorageType.FILE.value, False, 2),
+            (StorageType.SQL.value, False, 2),
         ]
     )
     def test_read_task(self, storage_type, is_eval, offset):
@@ -37,7 +37,7 @@ def test_read_task(self, storage_type, is_eval, offset):
         config.buffer.explorer_input.taskset.is_eval = is_eval
         config.buffer.explorer_input.taskset.index = offset
         config.buffer.explorer_input.taskset.batch_size = batch_size
-        if storage_type == StorageType.SQL:
+        if storage_type == StorageType.SQL.value:
             dataset = datasets.load_dataset(
                 config.buffer.explorer_input.taskset.path, split="train"
             )
2 changes: 1 addition & 1 deletion tests/explorer/explorer_test.py
@@ -184,7 +184,7 @@ def setUp(self):
         self.config.explorer.service_status_check_interval = 30
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="experience_buffer",
-            storage_type=StorageType.SQL,
+            storage_type=StorageType.SQL.value,
         )
         self.config.check_and_update()
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
2 changes: 1 addition & 1 deletion tests/explorer/scheduler_test.py
@@ -233,7 +233,7 @@ def setUp(self):
             self.config.buffer.trainer_input.experience_buffer
         ) = ExperienceBufferConfig(
             name="test",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
             schema_type="experience",
             path="",
         )
10 changes: 5 additions & 5 deletions tests/manager/synchronizer_test.py
@@ -132,7 +132,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
@@ -151,7 +151,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer1_config.check_and_update()

@@ -255,7 +255,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = self.sync_method
         config.synchronizer.sync_style = self.sync_style
@@ -275,7 +275,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer2_config = deepcopy(explorer1_config)
         explorer2_config.explorer.name = "explorer2"
@@ -356,7 +356,7 @@ def test_synchronizer(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.synchronizer.sync_method = SyncMethod.NCCL
         config.synchronizer.sync_style = self.sync_style
5 changes: 5 additions & 0 deletions tests/tools.py
@@ -18,6 +18,7 @@
     CHECKPOINT_ROOT_DIR_ENV_VAR,
     MODEL_PATH_ENV_VAR,
     PromptType,
+    StorageType,
 )

 API_MODEL_PATH_ENV_VAR = "TRINITY_API_MODEL_PATH"
@@ -133,6 +134,7 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
             split="train",
+            storage_type=StorageType.FILE.value,
             schema_type="sft",
             format=FormatConfig(
                 prompt_type=PromptType.PLAINTEXT,
@@ -146,6 +148,7 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"),
             split="train",
+            storage_type=StorageType.FILE.value,
             format=FormatConfig(
                 prompt_type=PromptType.MESSAGES,
                 messages_key="messages",
@@ -159,6 +162,8 @@ def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "t
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "human_like"),
             split="train",
+            storage_type=StorageType.FILE.value,
+            schema_type="dpo",
             format=FormatConfig(
                 prompt_type=PromptType.PLAINTEXT,
                 prompt_key="prompt",
6 changes: 3 additions & 3 deletions tests/trainer/trainer_test.py
@@ -305,7 +305,7 @@ def test_trainer(self, mock_load):
             experience_buffer=ExperienceBufferConfig(
                 name="test_queue_storage",
                 max_read_timeout=20,
-                storage_type=StorageType.QUEUE,
+                storage_type=StorageType.QUEUE.value,
                 max_retry_times=10,
             )
         ),
@@ -512,7 +512,7 @@ def test_fully_async_mode(self):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         config.buffer.trainer_input.experience_buffer.replay_buffer.enable = self.use_priority_queue
         config.synchronizer.sync_method = SyncMethod.CHECKPOINT
@@ -541,7 +541,7 @@ def test_fully_async_mode(self):
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
         explorer1_config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
-            storage_type=StorageType.QUEUE,
+            storage_type=StorageType.QUEUE.value,
         )
         explorer2_config = deepcopy(explorer1_config)
         explorer2_config.trainer = deepcopy(trainer_config.trainer)