diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index b9579c8aa..73818b0ad 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -3,6 +3,7 @@ import asyncio import os +import urllib.parse from typing import Generic, List, Sequence, TypeVar import jax @@ -48,7 +49,7 @@ class TreeStore(Generic[T]): tree: PyTree[JaggedArrayStore] def __init__(self, tree, path: str, mode: str): - self.path = path + self.path = _normalize_cache_dir(path) self.mode = mode self.tree = tree @@ -61,8 +62,9 @@ def open(exemplar: T, path: str, *, mode="a", cache_metadata: bool = False) -> " """ Open a TreeStoreBuilder from a file. """ - tree = _construct_builder_tree(exemplar, path, mode, cache_metadata) - return TreeStore(tree, path, mode) + resolved_path = _normalize_cache_dir(path) + tree = _construct_builder_tree(exemplar, resolved_path, mode, cache_metadata) + return TreeStore(tree, resolved_path, mode) def append(self, ex: T): return self.extend([ex]) @@ -179,6 +181,31 @@ async def async_len(self) -> int: return await jax.tree.leaves(self.tree)[0].num_rows_async() +def _normalize_cache_dir(path: os.PathLike[str] | str) -> str: + """Resolve relative file paths to absolute paths while leaving URLs untouched.""" + + path_str = os.fspath(path) + + if _is_probably_url(path_str): + return path_str + + expanded = os.path.expanduser(path_str) + return os.path.abspath(expanded) + + +def _is_probably_url(path: str) -> bool: + parsed = urllib.parse.urlparse(path) + + if parsed.scheme in ("", "file"): + return False + + # Handle Windows drive letters like "C:\\" which urllib parses as a scheme. + if len(parsed.scheme) == 1 and path[1:2] == ":": + return False + + return True + + def _construct_builder_tree(exemplar, path, mode, cache_metadata): def open_builder(tree_path, item): item = np.asarray(item) diff --git a/tests/test_tree_store.py b/tests/test_tree_store.py index 0b750129a..a0a3e7600 100644 --- a/tests/test_tree_store.py +++ b/tests/test_tree_store.py @@ -1,6 +1,7 @@ # Copyright 2025 The Levanter Authors # SPDX-License-Identifier: Apache-2.0 +import os import tempfile from typing import Any, Dict, Iterator, List, Sequence @@ -44,6 +45,21 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[List[int]]: return ([shard_num * 10 + i] * 10 for i in range(row, 10)) +def test_tree_store_resolves_relative_cache_dir(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + exemplar = {"data": np.array([0], dtype=np.int64)} + relative_path = "cache_dir" + + store = TreeStore.open(exemplar, relative_path, mode="w") + + expected_path = os.path.abspath(relative_path) + assert store.path == expected_path + assert os.path.isabs(store.path) + + reloaded = store.reload() + assert reloaded.path == expected_path + + def test_tree_builder_with_processor(): with tempfile.TemporaryDirectory() as tempdir: exemplar = {"data": np.array([0], dtype=np.int64)}