Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions trinity/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.reader import READER
from trinity.common.config import ExperienceBufferConfig, StorageConfig, TasksetConfig
from trinity.common.constants import StorageType

Expand All @@ -16,28 +17,10 @@ def get_buffer_reader(config: BufferStorageConfig) -> BufferReader:
storage_config: StorageConfig = config.to_storage_config()
else:
storage_config = config
if storage_config.storage_type == StorageType.SQL:
from trinity.buffer.reader.sql_reader import SQLReader

return SQLReader(storage_config)
elif storage_config.storage_type == StorageType.QUEUE:
from trinity.buffer.reader.queue_reader import QueueReader

return QueueReader(storage_config)
elif storage_config.storage_type == StorageType.FILE:
from trinity.buffer.reader.file_reader import (
ExperienceFileReader,
TaskFileReader,
)

schema_type = storage_config.schema_type
if schema_type:
# only trainer input has schema type
return ExperienceFileReader(storage_config)
else:
return TaskFileReader(storage_config)
else:
reader_cls = READER.get(storage_config.storage_type.value)
if reader_cls is None:
raise ValueError(f"{storage_config.storage_type} not supported.")
return reader_cls(storage_config)


def get_buffer_writer(config: BufferStorageConfig) -> BufferWriter:
Expand Down
6 changes: 6 additions & 0 deletions trinity/buffer/reader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from trinity.buffer.reader.file_reader import FileReader
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.reader.reader import READER
from trinity.buffer.reader.sql_reader import SQLReader

# Names exported by `from trinity.buffer.reader import *`: the reader registry
# and the concrete reader classes imported above.
__all__ = ["READER", "FileReader", "QueueReader", "SQLReader"]
23 changes: 23 additions & 0 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datasets import Dataset, load_dataset

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.reader.reader import READER
from trinity.buffer.schema.formatter import FORMATTER
from trinity.common.config import StorageConfig

Expand Down Expand Up @@ -95,6 +96,28 @@ async def read_async(self, batch_size: Optional[int] = None):
raise StopAsyncIteration from e


@READER.register_module("file")
class FileReader(BaseFileReader):
    """Single entry point for file-backed storage.

    Wraps either an ``ExperienceFileReader`` or a ``TaskFileReader`` and
    forwards every read call to it, so callers do not need to choose between
    the two themselves.
    """

    def __init__(self, config: StorageConfig):
        """Create the wrapped reader for ``config``.

        Args:
            config (StorageConfig): Storage configuration. A truthy
                ``schema_type`` selects ``ExperienceFileReader``; anything
                else selects ``TaskFileReader``.
        """
        delegate_cls = ExperienceFileReader if config.schema_type else TaskFileReader
        self.reader = delegate_cls(config)

    def read(self, batch_size: Optional[int] = None) -> List:
        """Read a batch through the wrapped reader."""
        delegate = self.reader
        return delegate.read(batch_size)

    def read_with_indices(self, indices: List[int]) -> List:
        """Read the items at ``indices`` through the wrapped reader."""
        delegate = self.reader
        return delegate.read_with_indices(indices)

    async def read_with_indices_async(self, indices: List[int]) -> List:
        """Asynchronously read the items at ``indices`` through the wrapped reader."""
        items = await self.reader.read_with_indices_async(indices)
        return items


class ExperienceFileReader(BaseFileReader):
"""Reader for SFT / DPO file data."""

Expand Down
2 changes: 2 additions & 0 deletions trinity/buffer/reader/queue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import ray

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.reader.reader import READER
from trinity.buffer.storage.queue import QueueStorage
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType


@READER.register_module("queue")
class QueueReader(BufferReader):
"""Reader of the Queue buffer."""

Expand Down
3 changes: 3 additions & 0 deletions trinity/buffer/reader/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from trinity.utils.registry import Registry

# Registry of buffer reader classes. Reader implementations add themselves with
# `@READER.register_module("<name>")` and are looked up by name via `READER.get(...)`.
READER = Registry("reader")
2 changes: 2 additions & 0 deletions trinity/buffer/reader/sql_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import ray

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.reader.reader import READER
from trinity.buffer.storage.sql import SQLStorage
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType


@READER.register_module("sql")
class SQLReader(BufferReader):
"""Reader of the SQL buffer."""

Expand Down
73 changes: 71 additions & 2 deletions trinity/buffer/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,77 @@
from trinity.utils.annotations import Experimental


def get_taskset_scheduler(explorer_state: Dict, config: Config) -> "TasksetSchedulerBase":
    """Build the taskset scheduler that fits the configured tasksets.

    ``SimpleTasksetScheduler`` is chosen when exactly one taskset is configured
    and its task selector type is ``"sequential"``; every other configuration
    gets the full ``TasksetScheduler``.

    Args:
        explorer_state (Dict): Restoration state from checkpoint (may include progress info)
        config (Config): Full system configuration containing buffer and taskset settings

    Returns:
        TasksetSchedulerBase: The taskset scheduler instance
    """
    tasksets = config.buffer.explorer_input.tasksets
    # Short-circuit keeps the selector lookup from running unless there is
    # exactly one taskset.
    single_sequential = (
        len(tasksets) == 1 and tasksets[0].task_selector.selector_type == "sequential"
    )
    scheduler_cls = SimpleTasksetScheduler if single_sequential else TasksetScheduler
    return scheduler_cls(explorer_state, config)


class TasksetSchedulerBase:
    """Common interface for taskset schedulers.

    Stores the explorer state and configuration; subclasses must override
    ``read_async``, ``state_dict`` and ``update``.
    """

    def __init__(self, explorer_state: Dict, config: Config):
        """Keep references to the restoration state and the configuration.

        Args:
            explorer_state (Dict): Restoration state from checkpoint.
            config (Config): Full system configuration.
        """
        self.explorer_state = explorer_state
        self.config = config

    async def read_async(self) -> List:
        """Asynchronously read the next batch of tasks for the current schedule.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

    def state_dict(self) -> List[Dict]:
        """Return the persistent state used for checkpointing.

        Returns:
            List[Dict]: State dicts for all selectors (one per taskset)

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

    def update(self, pipeline_metrics: Dict) -> None:
        """Feed training-pipeline metrics back into the selectors.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError


class SimpleTasksetScheduler(TasksetSchedulerBase):
    """
    A simple taskset scheduler that only reads from one taskset without task selection strategies.
    """

    def __init__(self, explorer_state: Dict, config: Config):
        """Restore the start index and open a reader on the first taskset.

        Args:
            explorer_state (Dict): Restoration state from checkpoint (may include progress info)
            config (Config): Full system configuration containing buffer and taskset settings
        """
        super().__init__(explorer_state, config)

        # Backward compatibility: old format stored 'latest_task_index' directly.
        if "latest_task_index" in self.explorer_state:
            self.explorer_state["taskset_states"] = [
                {"current_index": self.explorer_state["latest_task_index"]},
            ]

        # A missing or empty 'taskset_states' means "start from the beginning";
        # falling back to [{}] avoids an IndexError on an empty list.
        taskset_states = self.explorer_state.get("taskset_states") or [{}]
        index = taskset_states[0].get("current_index", 0)

        taskset_config = self.config.buffer.explorer_input.tasksets[0]
        taskset_config.index = index
        self.reader = get_buffer_reader(taskset_config)

    async def read_async(self) -> List:
        """Asynchronously read the next batch of tasks from the single taskset reader."""
        # The reader is stored in `self.reader`; its coroutine must be called and
        # awaited so that callers receive the task list, not a bound method.
        return await self.reader.read_async()

    def state_dict(self) -> List[Dict]:
        """Return the persistent state for checkpointing (one entry for the single taskset).

        Returns:
            List[Dict]: A one-element list holding the taskset's ``current_index``.
        """
        # NOTE(review): this always reports index 0, so read progress is not
        # captured here. TODO confirm whether the reader exposes its current
        # position and, if so, persist it instead.
        return [{"current_index": 0}]

    def update(self, pipeline_metrics: Dict) -> None:
        """Accept pipeline metrics; there are no selectors to update, so this is a no-op."""
        return


@Experimental
class TasksetScheduler:
class TasksetScheduler(TasksetSchedulerBase):
"""
Coordinates multiple datasets (tasksets) with customizable task selection strategies per taskset.

Expand All @@ -38,7 +107,7 @@ def __init__(self, explorer_state: Dict, config: Config):
explorer_state (Dict): Restoration state from checkpoint (may include progress info)
config (Config): Full system configuration containing buffer and taskset settings
"""
self.config = config
super().__init__(explorer_state, config)

# Backward compatibility: old format stored 'latest_task_index' directly
if "latest_task_index" in explorer_state:
Expand Down