Skip to content

Commit 1e323dd

Browse files
authored
Merge branch 'develop' into fix/bytes-and-strings
2 parents 409e423 + 9c652b1 commit 1e323dd

File tree

2 files changed

+109
-13
lines changed

2 files changed

+109
-13
lines changed

CPAC/utils/monitoring/monitoring.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def isoformat(self) -> str:
8484

8585

8686
class DatetimeWithSafeNone(datetime, _NoTime):
87-
"""Time class that can be None or a time value.
87+
r"""Time class that can be None or a time value.
8888
8989
Examples
9090
--------
@@ -93,9 +93,9 @@ class DatetimeWithSafeNone(datetime, _NoTime):
9393
'2025-06-18T21:06:43.730004'
9494
>>> DatetimeWithSafeNone("2025-06-18T21:06:43.730004").isoformat()
9595
'2025-06-18T21:06:43.730004'
96-
>>> DatetimeWithSafeNone(b"\\x07\\xe9\\x06\\x12\\x10\\x18\\x1c\\x88\\x6d\\x01").isoformat()
96+
>>> DatetimeWithSafeNone(b"\x07\xe9\x06\x12\x10\x18\x1c\x88\x6d\x01").isoformat()
9797
'2025-06-18T16:24:28.028040+00:00'
98-
>>> DatetimeWithSafeNone(b'\\x07\\xe9\\x06\\x12\\x10\\x18\\x1c\\x88m\\x00').isoformat()
98+
>>> DatetimeWithSafeNone(b'\x07\xe9\x06\x12\x10\x18\x1c\x88m\x00').isoformat()
9999
'2025-06-18T16:24:28.028040'
100100
>>> DatetimeWithSafeNone(DatetimeWithSafeNone("2025-06-18")).isoformat()
101101
'2025-06-18T00:00:00'
@@ -148,6 +148,7 @@ def __new__(
148148
fold: Optional[int] = 0,
149149
) -> "DatetimeWithSafeNone | _NoTime":
150150
"""Create a new instance of the class."""
151+
# First check if all arguments are provided as integers
151152
if (
152153
isinstance(year, int)
153154
and isinstance(month, int)
@@ -170,10 +171,13 @@ def __new__(
170171
tzinfo,
171172
fold=fold,
172173
)
173-
else:
174-
dt = year
174+
175+
# Otherwise, year contains the datetime-like object
176+
dt = year
177+
175178
if dt is None:
176179
return NoTime
180+
177181
if isinstance(dt, datetime):
178182
return datetime.__new__(
179183
cls,
@@ -186,6 +190,7 @@ def __new__(
186190
dt.microsecond,
187191
dt.tzinfo,
188192
)
193+
189194
if isinstance(dt, bytes):
190195
try:
191196
tzflag: Optional[int]
@@ -219,25 +224,77 @@ def __new__(
219224
return datetime.__new__(
220225
cls, year, month, day, hour, minute, second, microsecond, tzinfo
221226
)
222-
else:
223-
msg = f"Unexpected type: {[type(part) for part in [year, month, day, hour, minute, second, microsecond]]}"
224-
raise TypeError(msg)
225-
except UnicodeDecodeError:
226-
error = f"Cannot decode bytes to string: {dt!r}"
227+
msg = f"Unexpected type: {[type(part) for part in [year, month, day, hour, minute, second, microsecond]]}"
228+
raise TypeError(msg)
229+
except (struct.error, IndexError) as e:
230+
error = f"Cannot unpack bytes to datetime: {dt!r} - {e}"
227231
raise TypeError(error)
232+
228233
if isinstance(dt, str):
229234
try:
230235
return DatetimeWithSafeNone(datetime.fromisoformat(dt))
231236
except (ValueError, TypeError):
232237
error = f"Invalid ISO-format datetime string: {dt}"
233-
else:
234-
error = f"Cannot convert {type(dt)} to datetime"
238+
raise TypeError(error)
239+
240+
error = f"Cannot convert {type(dt)} to datetime"
235241
raise TypeError(error)
236242

237243
def __bool__(self) -> bool:
238244
"""Return True if not NoTime."""
239245
return self is not NoTime
240246

247+
def __eq__(self, other: object) -> bool:
248+
"""Compare DatetimeWithSafeNone instances with tzinfo-aware logic.
249+
250+
If only one side has tzinfo, consider them equal if all other components match.
251+
"""
252+
if self is NoTime and other is NoTime:
253+
return True
254+
if self is NoTime or other is NoTime:
255+
return False
256+
if not isinstance(other, (datetime, DatetimeWithSafeNone)):
257+
return False
258+
259+
# Compare all datetime components except tzinfo
260+
components_match = (
261+
self.year == other.year
262+
and self.month == other.month
263+
and self.day == other.day
264+
and self.hour == other.hour
265+
and self.minute == other.minute
266+
and self.second == other.second
267+
and self.microsecond == other.microsecond
268+
)
269+
270+
if not components_match:
271+
return False
272+
273+
# If components match, check tzinfo:
274+
# - If either has None tzinfo, consider them equal
275+
# - If both have tzinfo, they must match
276+
if self.tzinfo is None or other.tzinfo is None:
277+
return True
278+
279+
return self.tzinfo == other.tzinfo
280+
281+
def __hash__(self) -> int:
282+
"""Return hash based on datetime components, ignoring tzinfo."""
283+
if self is NoTime:
284+
return hash(NoTime)
285+
# Hash based on datetime components only, not tzinfo
286+
return hash(
287+
(
288+
self.year,
289+
self.month,
290+
self.day,
291+
self.hour,
292+
self.minute,
293+
self.second,
294+
self.microsecond,
295+
)
296+
)
297+
241298
def __sub__(self, other: "DatetimeWithSafeNone | _NoTime") -> datetime | timedelta: # type: ignore[reportIncompatibleMethodOverride]
242299
"""Subtract between a datetime or timedelta or None."""
243300
return _safe_none_diff(self, other)

CPAC/utils/tests/test_utils.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Tests of CPAC utility functions."""
1818

19-
from datetime import datetime, timedelta
19+
from copy import deepcopy
20+
from datetime import datetime, timedelta, timezone
2021
import multiprocessing
2122
from unittest import mock
2223

@@ -240,3 +241,41 @@ def test_datetime_with_safe_none(t1: OptionalDatetime, t2: OptionalDatetime):
240241
assert isinstance(t2 - t1, timedelta)
241242
else:
242243
assert t2 - t1 == timedelta(0)
244+
245+
246+
def test_deepcopy_datetimewithsafenone_raises_error() -> None:
247+
"""Test bytestring TypeError during deepcopy operation."""
248+
# Create a node dictionary similar to what's used in the Gantt chart generation
249+
node = {
250+
"id": "test_node",
251+
"hash": "abc123",
252+
"start": DatetimeWithSafeNone(
253+
datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
254+
),
255+
"finish": DatetimeWithSafeNone(
256+
datetime(2024, 1, 1, 11, 30, 0, tzinfo=timezone.utc)
257+
),
258+
"runtime_threads": 4,
259+
"runtime_memory_gb": 2.5,
260+
"estimated_memory_gb": 3.0,
261+
"num_threads": 4,
262+
}
263+
264+
# This should raise: TypeError: Cannot convert <class 'bytes'> to datetime
265+
# with the original code because deepcopy pickles DatetimeWithSafeNone objects
266+
# as bytes, and the __new__ method doesn't properly handle the pickle protocol
267+
finish_node = deepcopy(node)
268+
269+
assert finish_node["start"] == node["start"]
270+
assert finish_node["finish"] == node["finish"]
271+
272+
273+
def test_deepcopy_datetimewithsafenone_direct():
274+
"""Test deepcopy directly on DatetimeWithSafeNone instance."""
275+
dt = DatetimeWithSafeNone(datetime(2024, 1, 1, 10, 0, 0, tzinfo=timezone.utc))
276+
277+
# This triggers the pickle/unpickle cycle which passes bytes to __new__
278+
dt_copy = deepcopy(dt)
279+
280+
assert dt_copy == dt
281+
assert isinstance(dt_copy, DatetimeWithSafeNone)

0 commit comments

Comments
 (0)