Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import os
from pathlib import Path

import setuptools.command.sdist
from setuptools import setup
Expand All @@ -12,7 +13,7 @@
data_files = {n for n in os.listdir(datafile) if any(n.endswith(ex) for ex in data_ex)}

if data_files:
with open(os.path.join(datafile, "file_list.txt"), "w") as f:
with open(Path(datafile) / "file_list.txt", "w") as f:
for d in sorted(data_files):
print(d.split("/")[-1], file=f)

Expand Down
51 changes: 21 additions & 30 deletions src/skhep_testdata/remote_files.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from __future__ import annotations

import errno
import logging
import os
import tarfile
from importlib import resources
from pathlib import Path
from typing import ClassVar
from urllib.request import urlretrieve

import yaml

_default_data_dir = os.path.realpath(os.path.dirname(__file__))
_default_data_dir = Path(__file__).resolve().parent


class RemoteDatasetList:
Expand All @@ -33,9 +32,7 @@ def load_remote_configs(cls, file_to_load: str | None = None) -> None:
return

if file_to_load is None:
dataset_yml = resources.files("skhep_testdata").joinpath(
"remote_datasets.yml"
)
dataset_yml = resources.files("skhep_testdata") / "remote_datasets.yml"
with dataset_yml.open() as infile:
datasets = yaml.load(infile, Loader=yaml.SafeLoader)
else:
Expand All @@ -49,7 +46,7 @@ def load_remote_configs(cls, file_to_load: str | None = None) -> None:
config["files"] = files
config["dataset_name"] = dataset
for filename in files:
scoped_name = os.path.join(dataset, filename)
scoped_name = str(Path(dataset) / filename)
cls._all_files[scoped_name] = config

@classmethod
Expand All @@ -59,40 +56,34 @@ def is_known(cls, filename: str) -> bool:


def make_all_dirs(path: str) -> None:
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
Path(path).mkdir(parents=True, exist_ok=True)


def fetch_remote_dataset(
dataset_name: str, files: dict[str, str], url: str, data_dir: str
) -> None:
dataset_dir = os.path.join(data_dir, dataset_name)
dataset_dir = Path(data_dir) / dataset_name

writefile = os.path.join(dataset_dir, os.path.basename(url))
if os.path.exists(writefile):
writefile = dataset_dir / Path(url).name
if writefile.exists():
return

make_all_dirs(dataset_dir)
make_all_dirs(str(dataset_dir))
logging.warning("Downloading %s", url)
urlretrieve(url, writefile)
urlretrieve(url, str(writefile))

if tarfile.is_tarfile(writefile):
logging.warning("Extracting %s", writefile)
with tarfile.open(writefile) as tar:
members = [tar.getmember(f) for f in files.values()]
tar.extractall(dataset_dir, members)
tar.extractall(str(dataset_dir), members)

for outfile, infile in files.items():
full_in = os.path.join(dataset_dir, infile)
full_out = os.path.join(dataset_dir, outfile)
os.rename(full_in, full_out)
full_in = dataset_dir / infile
full_out = dataset_dir / outfile
full_in.rename(full_out)

if not os.path.exists(writefile):
if not writefile.exists():
msg = "Problem obtaining remote dataset : %s"
raise RuntimeError(msg % dataset_name)

Expand All @@ -102,20 +93,20 @@ def is_known_remote(filename: str) -> bool:


def remote_file(
filename: str, data_dir: str = _default_data_dir, raise_missing: bool = False
filename: str, data_dir: str | Path = _default_data_dir, raise_missing: bool = False
) -> str:
config = RemoteDatasetList.get_config_for_file(filename)
if not config and raise_missing:
msg = f"Unknown {filename} cannot be found"
raise RuntimeError(msg)

path = os.path.join(data_dir, filename)
if not os.path.isfile(path):
config["data_dir"] = data_dir
path = Path(data_dir) / filename
if not path.is_file():
config["data_dir"] = str(data_dir)
fetch_remote_dataset(**config) # type: ignore[arg-type]

if not os.path.isfile(path) and raise_missing:
if not path.is_file() and raise_missing:
msg = f"{filename} cannot be found"
raise RuntimeError(msg)

return path
return str(path)
13 changes: 6 additions & 7 deletions tests/test_local_files.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from __future__ import annotations

import os
from pathlib import Path

import pytest
import requests

import skhep_testdata as skhtd

data_dir = os.path.dirname(skhtd.__file__)
data_dir = os.path.join(data_dir, "data")
data_dir = Path(skhtd.__file__).parent / "data"


def test_data_path():
assert os.path.exists(skhtd.data_path("uproot-Zmumu.root"))
assert Path(skhtd.data_path("uproot-Zmumu.root")).exists()


def test_data_path_missing():
path = skhtd.data_path("doesnt-exist.root", raise_missing=False)
assert path == os.path.join(data_dir, "doesnt-exist.root")
assert path == str(data_dir / "doesnt-exist.root")

with pytest.raises(IOError):
skhtd.data_path("doesnt-exist.root")
Expand All @@ -36,10 +35,10 @@ def test_delegate_to_remote(monkeypatch, tmpdir):
def dummy_remote_file(filename, data_dir=None, raise_missing=False):
if not data_dir:
data_dir = str(tmpdir)
return os.path.join(data_dir, filename)
return str(Path(data_dir) / filename)

monkeypatch.setattr(skhtd.remote_files, "remote_file", dummy_remote_file)
monkeypatch.setattr(skhtd.remote_files, "is_known_remote", lambda _: True)

path = skhtd.data_path(os.path.join("dataset", "a_remote_file.root"))
path = skhtd.data_path(str(Path("dataset") / "a_remote_file.root"))
assert path == str(tmpdir / "dataset" / "a_remote_file.root")
10 changes: 4 additions & 6 deletions tests/test_remote_files.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from __future__ import annotations

import os
from pathlib import Path

import skhep_testdata as skhtd

_remote_dataset_cfg = os.path.join(
os.path.dirname(__file__), "test_remote_datasets.yml"
)
_remote_dataset_cfg = str(Path(__file__).parent / "test_remote_datasets.yml")
skhtd.remote_files.RemoteDatasetList.load_remote_configs(_remote_dataset_cfg)

good_file_1 = os.path.join("dataset_1", "file_1.root")
bad_file_1 = os.path.join("bad_dataset_1", "file_1.root")
good_file_1 = str(Path("dataset_1") / "file_1.root")
bad_file_1 = str(Path("bad_dataset_1") / "file_1.root")


def test_is_known_remote():
Expand Down
Loading