Skip to content

Commit 8c71761

Browse files
authored
Use the same cache directory for remote and local files (#197)
1 parent d60b25d commit 8c71761

File tree

5 files changed

+31
-25
lines changed

5 files changed

+31
-25
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
5+
6+
def cache_path(cache_dir: Path | str | None = None) -> Path:
7+
if cache_dir is None:
8+
skhepdir = Path.home() / ".local" / "skhepdata"
9+
skhepdir.mkdir(exist_ok=True, parents=True)
10+
return skhepdir
11+
12+
return Path(cache_dir)

src/skhep_testdata/local_files.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,21 @@
2222
known_files = {n.strip() for n in f if n.strip()}
2323

2424

25-
def _cache_path(cache_dir: str | None = None) -> Path:
26-
if cache_dir is None:
27-
skhepdir = Path.home() / ".local" / "skhepdata"
28-
skhepdir.mkdir(exist_ok=True, parents=True)
29-
return skhepdir
30-
31-
return Path(cache_dir)
32-
33-
3425
def data_path(
3526
filename: str, raise_missing: bool = True, cache_dir: str | None = None
3627
) -> str:
3728
if remote_files.is_known_remote(filename):
38-
return remote_files.remote_file(filename, raise_missing=raise_missing)
29+
return remote_files.remote_file(
30+
filename, cache_dir=cache_dir, raise_missing=raise_missing
31+
)
3932

4033
if filename not in known_files and raise_missing:
4134
raise FileNotFoundError(filename)
4235

4336
filepath = DIR / "data" / filename
4437

4538
if not filepath.exists() and filename in known_files:
46-
cached_path = _cache_path(cache_dir) / filename
39+
cached_path = data.cache_path(cache_dir) / filename
4740

4841
if not cached_path.exists():
4942
# Currently get from main branch
@@ -60,7 +53,7 @@ def data_path(
6053

6154

6255
def download_all(cache_dir: str | None = None) -> None:
63-
local_dir = _cache_path(cache_dir)
56+
local_dir = data.cache_path(cache_dir)
6457

6558
with tempfile.TemporaryFile() as f:
6659
with requests.get(zipurl, stream=True) as r:

src/skhep_testdata/remote_files.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import yaml
1111

12-
_default_data_dir = Path(__file__).resolve().parent
12+
from . import data
1313

1414

1515
class RemoteDatasetList:
@@ -60,9 +60,9 @@ def make_all_dirs(path: str) -> None:
6060

6161

6262
def fetch_remote_dataset(
63-
dataset_name: str, files: dict[str, str], url: str, data_dir: str
63+
dataset_name: str, files: dict[str, str], url: str, cache_dir: str
6464
) -> None:
65-
dataset_dir = Path(data_dir) / dataset_name
65+
dataset_dir = Path(cache_dir) / dataset_name
6666

6767
writefile = dataset_dir / Path(url).name
6868
if writefile.exists():
@@ -72,9 +72,9 @@ def fetch_remote_dataset(
7272
logging.warning("Downloading %s", url)
7373
urlretrieve(url, str(writefile))
7474

75-
if tarfile.is_tarfile(writefile):
75+
if tarfile.is_tarfile(str(writefile)):
7676
logging.warning("Extracting %s", writefile)
77-
with tarfile.open(writefile) as tar:
77+
with tarfile.open(str(writefile)) as tar:
7878
members = [tar.getmember(f) for f in files.values()]
7979
tar.extractall(str(dataset_dir), members)
8080

@@ -93,16 +93,17 @@ def is_known_remote(filename: str) -> bool:
9393

9494

9595
def remote_file(
96-
filename: str, data_dir: str | Path = _default_data_dir, raise_missing: bool = False
96+
filename: str, cache_dir: str | Path | None = None, raise_missing: bool = False
9797
) -> str:
98+
cache_dir = data.cache_path(cache_dir)
9899
config = RemoteDatasetList.get_config_for_file(filename)
99100
if not config and raise_missing:
100101
msg = f"Unknown {filename} cannot be found"
101102
raise RuntimeError(msg)
102103

103-
path = Path(data_dir) / filename
104+
path = Path(cache_dir) / filename
104105
if not path.is_file():
105-
config["data_dir"] = str(data_dir)
106+
config["cache_dir"] = str(cache_dir)
106107
fetch_remote_dataset(**config) # type: ignore[arg-type]
107108

108109
if not path.is_file() and raise_missing:

tests/test_local_files.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def test_data_path_cached():
3232

3333

3434
def test_delegate_to_remote(monkeypatch, tmpdir):
35-
def dummy_remote_file(filename, data_dir=None, raise_missing=False):
36-
if not data_dir:
37-
data_dir = str(tmpdir)
38-
return str(Path(data_dir) / filename)
35+
def dummy_remote_file(filename, cache_dir=None, raise_missing=False):
36+
if not cache_dir:
37+
cache_dir = str(tmpdir)
38+
return str(Path(cache_dir) / filename)
3939

4040
monkeypatch.setattr(skhtd.remote_files, "remote_file", dummy_remote_file)
4141
monkeypatch.setattr(skhtd.remote_files, "is_known_remote", lambda _: True)

tests/test_remote_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ def fake_urlretrieve(url, writefile):
3737

3838
monkeypatch.setattr(skhtd.remote_files, "urlretrieve", fake_urlretrieve)
3939

40-
path = skhtd.remote_files.remote_file(good_file_1, data_dir=str(tmpdir))
40+
path = skhtd.remote_files.remote_file(good_file_1, cache_dir=str(tmpdir))
4141
assert path == str(tmpdir / "dataset_1" / "file_1.root")

0 commit comments

Comments
 (0)