From 5f582de2a92197a2dac17d968bf8f92daf62709c Mon Sep 17 00:00:00 2001
From: Luca Carminati
Date: Tue, 25 Nov 2025 16:48:07 +0100
Subject: [PATCH 1/5] Fix ordering of sampled data in MultiSyncDataCollector

---
 torchrl/collectors/collectors.py | 44 ++++++++++----------------------
 1 file changed, 14 insertions(+), 30 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index b7be73d243f..710982da8c4 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -3761,7 +3761,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
         if cat_results is None:
             cat_results = "stack"
 
-        self.buffers = {}
+        self.buffers = [None for _ in range(self.num_workers)]
         dones = [False for _ in range(self.num_workers)]
         workers_frames = [0 for _ in range(self.num_workers)]
         same_device = None
@@ -3844,8 +3844,8 @@ def iterator(self) -> Iterator[TensorDictBase]:
             if preempt:
                 # mask buffers if cat, and create a mask if stack
                 if cat_results != "stack":
-                    buffers = {}
-                    for worker_idx, buffer in self.buffers.items():
+                    buffers = [None] * self.num_workers
+                    for worker_idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
                         valid = buffer.get(("collector", "traj_ids")) != -1
                         if valid.ndim > 2:
                             valid = valid.flatten(0, -2)
@@ -3853,7 +3853,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
                             valid = valid.any(0)
                         buffers[worker_idx] = buffer[..., valid]
                 else:
-                    for buffer in self.buffers.values():
+                    for buffer in filter(None.__ne__, self.buffers):
                         with buffer.unlock_():
                             buffer.set(
                                 ("collector", "mask"),
@@ -3886,7 +3886,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             # we have to correct the traj_ids to make sure that they don't overlap
             # We can count the number of frames collected for free in this loop
             n_collected = 0
-            for idx in buffers.keys():
+            for idx,buffer in enumerate(filter(None.__ne__, buffers)):
                 buffer = buffers[idx]
                 traj_ids = buffer.get(("collector", "traj_ids"))
                 if preempt:
@@ -3901,7 +3901,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 if same_device is None:
                     prev_device = None
                     same_device = True
-                    for item in self.buffers.values():
+                    for item in filter(None.__ne__, self.buffers):
                         if prev_device is None:
                             prev_device = item.device
                         else:
@@ -3912,33 +3912,21 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
                 )
                 if same_device:
-                    self.out_buffer = stack(list(buffers.values()), 0)
+                    self.out_buffer = stack([item for item in buffers if item is not None], 0)
                 else:
-                    self.out_buffer = stack(
-                        [item.cpu() for item in buffers.values()], 0
-                    )
+                    self.out_buffer = stack([item.cpu() for item in buffers if item is not None], 0)
             else:
                 if self._use_buffers is None:
-                    torchrl_logger.warning(
-                        "use_buffer not specified and not yet inferred from data, assuming `True`."
-                    )
+                    torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
                 elif not self._use_buffers:
-                    raise RuntimeError(
-                        "Cannot concatenate results with use_buffers=False"
-                    )
+                    raise RuntimeError("Cannot concatenate results with use_buffers=False")
                 try:
                     if same_device:
-                        self.out_buffer = torch.cat(list(buffers.values()), cat_results)
+                        self.out_buffer = torch.cat([item for item in buffers if item is not None], cat_results)
                     else:
-                        self.out_buffer = torch.cat(
-                            [item.cpu() for item in buffers.values()], cat_results
-                        )
+                        self.out_buffer = torch.cat([item.cpu() for item in buffers if item is not None], cat_results)
                 except RuntimeError as err:
-                    if (
-                        preempt
-                        and cat_results != -1
-                        and "Sizes of tensors must match" in str(err)
-                    ):
+                    if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
                         raise RuntimeError(
                             "The value provided to cat_results isn't compatible with the collectors outputs. "
                             "Consider using `cat_results=-1`."
@@ -3956,11 +3944,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             self._frames += n_collected
 
             if self.postprocs:
-                self.postprocs = (
-                    self.postprocs.to(out.device)
-                    if hasattr(self.postprocs, "to")
-                    else self.postprocs
-                )
+                self.postprocs = self.postprocs.to(out.device) if hasattr(self.postprocs, "to") else self.postprocs
                 out = self.postprocs(out)
             if self._exclude_private_keys:
                 excluded_keys = [key for key in out.keys() if key.startswith("_")]
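Note for reviewers: a minimal standalone sketch (plain Python, not TorchRL code) of the ordering bug this patch fixes, under the reading suggested by the diff above: the old code accumulated worker outputs in a dict that was filled in message-arrival order, so stacking `buffers.values()` followed arrival order rather than worker order. A fixed-size list keeps slot i bound to worker i regardless of which worker finishes first.

import random

num_workers = 4
arrival_order = list(range(num_workers))
random.shuffle(arrival_order)  # workers finish in a nondeterministic order

# Before the patch: a dict keyed by worker index but iterated in insertion
# (i.e. arrival) order, so the stacked output is arrival-ordered.
buffers_dict = {}
for rank in arrival_order:
    buffers_dict[rank] = f"data-from-worker-{rank}"
old_stack_order = list(buffers_dict.values())

# After the patch: a preallocated list where slot i always belongs to worker i,
# so the stacked output is worker-ordered no matter the arrival order.
buffers_list = [None] * num_workers
for rank in arrival_order:
    buffers_list[rank] = f"data-from-worker-{rank}"
new_stack_order = [b for b in buffers_list if b is not None]

assert new_stack_order == [f"data-from-worker-{i}" for i in range(num_workers)]
print(old_stack_order)  # may be scrambled, e.g. worker 2, 0, 3, 1
print(new_stack_order)  # always worker 0, 1, 2, 3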
From 73eed6c0060befe693df06933fd4832671377c10 Mon Sep 17 00:00:00 2001
From: Luca Carminati
Date: Tue, 25 Nov 2025 16:48:07 +0100
Subject: [PATCH 2/5] Fix ordering of sampled data in MultiSyncDataCollector

---
 torchrl/collectors/collectors.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 710982da8c4..d0953e46f87 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -3845,7 +3845,9 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 # mask buffers if cat, and create a mask if stack
                 if cat_results != "stack":
                     buffers = [None] * self.num_workers
-                    for worker_idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
+                    for worker_idx, buffer in enumerate(
+                        filter(None.__ne__, self.buffers)
+                    ):
                         valid = buffer.get(("collector", "traj_ids")) != -1
                         if valid.ndim > 2:
                             valid = valid.flatten(0, -2)
@@ -3886,7 +3888,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             # we have to correct the traj_ids to make sure that they don't overlap
             # We can count the number of frames collected for free in this loop
             n_collected = 0
-            for idx,buffer in enumerate(filter(None.__ne__, buffers)):
+            for idx, buffer in enumerate(filter(None.__ne__, buffers)):
                 buffer = buffers[idx]
                 traj_ids = buffer.get(("collector", "traj_ids"))
                 if preempt:
@@ -3912,9 +3914,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
                 )
                 if same_device:
-                    self.out_buffer = stack([item for item in buffers if item is not None], 0)
+                    self.out_buffer = stack(
+                        [item for item in buffers if item is not None], 0
+                    )
                 else:
-                    self.out_buffer = stack([item.cpu() for item in buffers if item is not None], 0)
+                    self.out_buffer = stack(
+                        [item.cpu() for item in buffers if item is not None], 0
+                    )
             else:
                 if self._use_buffers is None:
                     torchrl_logger.warning("use_buffer not specified and not yet inferred from data, assuming `True`.")
                 elif not self._use_buffers:
                     raise RuntimeError("Cannot concatenate results with use_buffers=False")
                 try:
                     if same_device:
-                        self.out_buffer = torch.cat([item for item in buffers if item is not None], cat_results)
+                        self.out_buffer = torch.cat(
+                            [item for item in buffers if item is not None], cat_results
+                        )
                     else:
-                        self.out_buffer = torch.cat([item.cpu() for item in buffers if item is not None], cat_results)
+                        self.out_buffer = torch.cat(
+                            [item.cpu() for item in buffers if item is not None],
+                            cat_results,
+                        )
                 except RuntimeError as err:
                     if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err):
                         raise RuntimeError(
                             "The value provided to cat_results isn't compatible with the collectors outputs. "
                             "Consider using `cat_results=-1`."
                         )
RuntimeError("Cannot concatenate results with use_buffers=False") try: if same_device: - self.out_buffer = torch.cat([item for item in buffers if item is not None], cat_results) + self.out_buffer = torch.cat( + [item for item in buffers if item is not None], cat_results + ) else: - self.out_buffer = torch.cat([item.cpu() for item in buffers if item is not None], cat_results) + self.out_buffer = torch.cat( + [item.cpu() for item in buffers if item is not None], + cat_results, + ) except RuntimeError as err: if preempt and cat_results != -1 and "Sizes of tensors must match" in str(err): raise RuntimeError( From bb5fad6bff36dc6d09b541e9c329e0550263bb11 Mon Sep 17 00:00:00 2001 From: Luca Carminati Date: Fri, 28 Nov 2025 10:20:47 +0100 Subject: [PATCH 3/5] Fix tests --- test/test_collector.py | 133 +++++++++++++++++++++++++++++++ torchrl/collectors/collectors.py | 21 ++--- 2 files changed, 138 insertions(+), 16 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..c0058068ac2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1717,6 +1717,139 @@ def env_fn(): total_frames=frames_per_batch * 100, ) + class FixedIDEnv(EnvBase): + """ + A simple mock environment that returns a fixed ID as its sole observation. + + This environment is designed to test MultiSyncDataCollector ordering. + Each environment instance is initialized with a unique env_id, which it + returns as the observation at every step. + """ + + def __init__(self, env_id: int, max_steps: int = 10, **kwargs): + """ + Args: + env_id: The ID to return as observation. This will be returned as a tensor. + max_steps: Maximum number of steps before the environment terminates. + """ + super().__init__(device="cpu", batch_size=torch.Size([])) + self.env_id = env_id + self.max_steps = max_steps + self._step_count = 0 + + # Define specs + self.observation_spec = Composite( + observation=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.action_spec = Composite( + action=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.reward_spec = Composite( + reward=Unbounded(shape=(1,), dtype=torch.float32) + ) + self.done_spec = Composite( + done=Unbounded(shape=(1,), dtype=torch.bool), + terminated=Unbounded(shape=(1,), dtype=torch.bool), + truncated=Unbounded(shape=(1,), dtype=torch.bool), + ) + + def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict: + """Reset the environment and return initial observation.""" + # Add random sleep to simulate real-world timing variations + # This helps test that the collector properly handles different reset times + time.sleep(torch.rand(1).item() * 0.01) # Random sleep up to 10ms + + self._step_count = 0 + return TensorDict( + { + "observation": torch.tensor( + [float(self.env_id)], dtype=torch.float32 + ), + "done": torch.tensor([False], dtype=torch.bool), + "terminated": torch.tensor([False], dtype=torch.bool), + "truncated": torch.tensor([False], dtype=torch.bool), + }, + batch_size=self.batch_size, + ) + + def _step(self, tensordict: TensorDict) -> TensorDict: + """Execute one step and return the env_id as observation.""" + self._step_count += 1 + done = self._step_count >= self.max_steps + + return TensorDict( + { + "observation": torch.tensor( + [float(self.env_id)], dtype=torch.float32 + ), + "reward": torch.tensor([1.0], dtype=torch.float32), + "done": torch.tensor([done], dtype=torch.bool), + "terminated": torch.tensor([done], dtype=torch.bool), + "truncated": torch.tensor([False], dtype=torch.bool), + }, + 
+                batch_size=self.batch_size,
+            )
+
+        def _set_seed(self, seed: int | None) -> int | None:
+            """Set the seed for reproducibility."""
+            if seed is not None:
+                torch.manual_seed(seed)
+            return seed
+
+    @pytest.mark.parametrize("num_envs", [8])
+    def test_multi_sync_data_collector_ordering(self, num_envs: int):
+        """
+        Test that MultiSyncDataCollector returns data in the correct order.
+
+        We create num_envs environments, each returning its env_id as the observation.
+        After collection, we verify that the observations correspond to the correct env_ids in order.
+        """
+        frames_per_batch = num_envs * 5  # Collect 5 steps per environment
+
+        # Create environment factories using partial - one for each env_id
+        # This pattern mirrors CrossPlayEvaluator._rollout usage
+        env_factories = [
+            functools.partial(self.FixedIDEnv, env_id=i, max_steps=10)
+            for i in range(num_envs)
+        ]
+
+        # Create the policy
+        policy = ParametricPolicy()
+
+        # Initialize MultiSyncDataCollector
+        collector = MultiSyncDataCollector(
+            create_env_fn=env_factories,
+            policy=policy,
+            frames_per_batch=frames_per_batch,
+            total_frames=frames_per_batch,
+            device="cpu",
+        )
+
+        # Collect one batch
+        for batch in collector:
+            # Verify that each environment's observations match its env_id
+            # batch has shape [num_envs, frames_per_env]
+            for env_idx in range(num_envs):
+                env_data = batch[env_idx]
+                observations = env_data["observation"]
+
+                # All observations from this environment should equal its env_id
+                expected_id = float(env_idx)
+                actual_ids = observations.flatten().unique()
+
+                assert len(actual_ids) == 1, (
+                    f"Env {env_idx} should only produce observations with value {expected_id}, "
+                    f"but got {actual_ids.tolist()}"
+                )
+                assert (
+                    actual_ids[0].item() == expected_id
+                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"
+
+            # Only process the first batch
+            break
+
+        collector.shutdown()
+
 
 class TestCollectorDevices:
     class DeviceLessEnv(EnvBase):
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index d0953e46f87..738588eac48 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -3760,7 +3760,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
         cat_results = self.cat_results
         if cat_results is None:
             cat_results = "stack"
-
         self.buffers = [None for _ in range(self.num_workers)]
         dones = [False for _ in range(self.num_workers)]
         workers_frames = [0 for _ in range(self.num_workers)]
@@ -3781,7 +3780,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     msg = "continue_random"
                 else:
                     msg = "continue"
-                # Debug: sending 'continue'
                 self.pipes[idx].send((None, msg))
 
             self._iter += 1
@@ -3845,15 +3843,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 # mask buffers if cat, and create a mask if stack
                 if cat_results != "stack":
                     buffers = [None] * self.num_workers
-                    for worker_idx, buffer in enumerate(
-                        filter(None.__ne__, self.buffers)
-                    ):
+                    for idx, buffer in enumerate(filter(None.__ne__, self.buffers)):
                         valid = buffer.get(("collector", "traj_ids")) != -1
                         if valid.ndim > 2:
                             valid = valid.flatten(0, -2)
                         if valid.ndim == 2:
                             valid = valid.any(0)
-                        buffers[worker_idx] = buffer[..., valid]
+                        buffers[idx] = buffer[..., valid]
                 else:
                     for buffer in filter(None.__ne__, self.buffers):
                         with buffer.unlock_():
@@ -3865,11 +3861,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
             else:
                 buffers = self.buffers
 
-            # Skip frame counting if this worker didn't send data this iteration
-            # (happens when reusing buffers or on first iteration with some workers)
-            if idx not in buffers:
-                continue
-
             workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
 
             if workers_frames[idx] >= self.total_frames:
@@ -3878,17 +3869,15 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 if self.replay_buffer is not None:
                     yield
                     self._frames += sum(
-                        [
-                            self.frames_per_batch_worker(worker_idx)
-                            for worker_idx in range(self.num_workers)
-                        ]
+                        self.frames_per_batch_worker(worker_idx)
+                        for worker_idx in range(self.num_workers)
                     )
                     continue
 
             # we have to correct the traj_ids to make sure that they don't overlap
             # We can count the number of frames collected for free in this loop
             n_collected = 0
-            for idx, buffer in enumerate(filter(None.__ne__, buffers)):
+            for idx in range(self.num_workers):
                 buffer = buffers[idx]
                 traj_ids = buffer.get(("collector", "traj_ids"))
                 if preempt:
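Note for reviewers: the `n_collected` loop this patch touches renumbers per-worker trajectory ids so that, per the source comment, ids from different workers never overlap. The following simplified sketch uses hypothetical data and an offset-based renumbering to illustrate the idea only; the real loop also counts collected frames as it goes and handles preempted (masked) steps.

import torch

# Hypothetical per-worker trajectory ids, each worker counting from 0.
per_worker_traj_ids = [
    torch.tensor([0, 0, 1, 1]),  # worker 0 collected two trajectories
    torch.tensor([0, 1, 1, 2]),  # worker 1 collected three trajectories
]

offset = 0
global_traj_ids = []
for traj_ids in per_worker_traj_ids:
    shifted = traj_ids + offset
    offset = int(shifted.max()) + 1  # next worker starts past the last id used
    global_traj_ids.append(shifted)

print(global_traj_ids)  # [tensor([0, 0, 1, 1]), tensor([2, 3, 3, 4])]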
From c428fd461202c66212ebb20b9fe2e3ea36544b93 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 1 Dec 2025 21:42:58 +0000
Subject: [PATCH 4/5] empty

From 0e410ae97fc42b4c7337e66cc9ac524ebe7ebde8 Mon Sep 17 00:00:00 2001
From: Luca Carminati
Date: Mon, 15 Dec 2025 17:37:41 +0100
Subject: [PATCH 5/5] Add preemption test

---
 test/test_collector.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/test/test_collector.py b/test/test_collector.py
index 51f5ceab483..817d26b4637 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -1721,7 +1721,7 @@ class FixedIDEnv(EnvBase):
         returns as the observation at every step.
         """
 
-        def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
+        def __init__(self, env_id: int, max_steps: int = 10, sleep_odd_only: bool = False, **kwargs):
            """
            Args:
                env_id: The ID to return as observation. This will be returned as a tensor.
@@ -1730,6 +1730,7 @@ def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
             super().__init__(device="cpu", batch_size=torch.Size([]))
             self.env_id = env_id
             self.max_steps = max_steps
+            self.sleep_odd_only = sleep_odd_only
             self._step_count = 0
 
             # Define specs
@@ -1750,9 +1751,13 @@ def __init__(self, env_id: int, max_steps: int = 10, **kwargs):
 
         def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
             """Reset the environment and return initial observation."""
-            # Add random sleep to simulate real-world timing variations
+            # Add sleep to simulate real-world timing variations
             # This helps test that the collector properly handles different reset times
-            time.sleep(torch.rand(1).item() * 0.01)  # Random sleep up to 10ms
+            if not self.sleep_odd_only:
+                # Random sleep up to 10ms
+                time.sleep(torch.rand(1).item() * 0.01)
+            elif self.env_id % 2 == 1:
+                time.sleep(0.01 + torch.rand(1).item() * 0.001)
 
             self._step_count = 0
             return TensorDict(
@@ -1792,19 +1797,23 @@ def _set_seed(self, seed: int | None) -> int | None:
             return seed
 
     @pytest.mark.parametrize("num_envs", [8])
-    def test_multi_sync_data_collector_ordering(self, num_envs: int):
+    @pytest.mark.parametrize("with_preempt", [False, True])
+    def test_multi_sync_data_collector_ordering(self, num_envs: int, with_preempt: bool):
         """
         Test that MultiSyncDataCollector returns data in the correct order.
 
         We create num_envs environments, each returning its env_id as the observation.
         After collection, we verify that the observations correspond to the correct env_ids in order.
         """
+        if with_preempt and IS_OSX:
+            pytest.skip("Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform.")
+
         frames_per_batch = num_envs * 5  # Collect 5 steps per environment
 
         # Create environment factories using partial - one for each env_id
         # This pattern mirrors CrossPlayEvaluator._rollout usage
         env_factories = [
-            functools.partial(self.FixedIDEnv, env_id=i, max_steps=10)
+            functools.partial(self.FixedIDEnv, env_id=i, max_steps=10, sleep_odd_only=with_preempt)
             for i in range(num_envs)
         ]
 
@@ -1818,14 +1827,16 @@ def test_multi_sync_data_collector_ordering(self, num_envs: int):
             frames_per_batch=frames_per_batch,
             total_frames=frames_per_batch,
             device="cpu",
+            preemptive_threshold=0.5 if with_preempt else None,
         )
 
         # Collect one batch
         for batch in collector:
             # Verify that each environment's observations match its env_id
             # batch has shape [num_envs, frames_per_env]
-            for env_idx in range(num_envs):
-                env_data = batch[env_idx]
+            # In the pre-emption case, we have slow odd envs. These should be skipped by pre-emption
+            for i, env_idx in enumerate(range(0, num_envs, 2 if with_preempt else 1)):
+                env_data = batch[i]
                 observations = env_data["observation"]
 
                 # All observations from this environment should equal its env_id
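Note for reviewers: the preemption branch of the test assumes that workers preempted before delivering anything leave their buffer slot as None, so the stacked batch keeps only the fast workers, in worker order. A small sketch of that indexing assumption follows (the editor's reading of the test, not a statement about TorchRL internals).

num_workers = 8
# Hypothetical outcome: odd-ranked workers were too slow and delivered nothing.
buffers = [f"w{i}" if i % 2 == 0 else None for i in range(num_workers)]

surviving_ranks = [i for i, b in enumerate(buffers) if b is not None]
batch_rows = [b for b in buffers if b is not None]

# Mirrors the loop in the test: batch row i should hold data from worker 2 * i.
for i, env_idx in enumerate(range(0, num_workers, 2)):
    assert batch_rows[i] == f"w{env_idx}"
print(surviving_ranks)  # [0, 2, 4, 6]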