@@ -9,7 +9,7 @@
 
 import yaml
 
-_default_data_dir = Path(__file__).resolve().parent
+from . import data
 
 
 class RemoteDatasetList:
@@ -60,9 +60,9 @@ def make_all_dirs(path: str) -> None:
 
 
 def fetch_remote_dataset(
-    dataset_name: str, files: dict[str, str], url: str, data_dir: str
+    dataset_name: str, files: dict[str, str], url: str, cache_dir: str
 ) -> None:
-    dataset_dir = Path(data_dir) / dataset_name
+    dataset_dir = Path(cache_dir) / dataset_name
 
     writefile = dataset_dir / Path(url).name
     if writefile.exists():
@@ -72,9 +72,9 @@ def fetch_remote_dataset(
     logging.warning("Downloading %s", url)
     urlretrieve(url, str(writefile))
 
-    if tarfile.is_tarfile(writefile):
+    if tarfile.is_tarfile(str(writefile)):
         logging.warning("Extracting %s", writefile)
-        with tarfile.open(writefile) as tar:
+        with tarfile.open(str(writefile)) as tar:
             members = [tar.getmember(f) for f in files.values()]
             tar.extractall(str(dataset_dir), members)
 
@@ -93,16 +93,17 @@ def is_known_remote(filename: str) -> bool:
 
 
 def remote_file(
-    filename: str, data_dir: str | Path = _default_data_dir, raise_missing: bool = False
+    filename: str, cache_dir: str | Path | None = None, raise_missing: bool = False
 ) -> str:
+    cache_dir = data.cache_path(cache_dir)
     config = RemoteDatasetList.get_config_for_file(filename)
     if not config and raise_missing:
         msg = f"Unknown {filename} cannot be found"
         raise RuntimeError(msg)
 
-    path = Path(data_dir) / filename
+    path = Path(cache_dir) / filename
     if not path.is_file():
-        config["data_dir"] = str(data_dir)
+        config["cache_dir"] = str(cache_dir)
         fetch_remote_dataset(**config)  # type: ignore[arg-type]
 
     if not path.is_file() and raise_missing:
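
The new default-path handling leans on a `data.cache_path` helper imported via `from . import data`; its implementation is outside this diff. A minimal sketch of a function with the shape the call site expects, assuming it resolves `None` to a per-user cache directory and creates it on demand (the fallback location shown is purely illustrative):

from pathlib import Path

def cache_path(cache_dir: str | Path | None = None) -> Path:
    """Resolve an explicit cache directory, or fall back to a default one."""
    if cache_dir is None:
        # Hypothetical default; the real fallback lives in the `data` module.
        cache_dir = Path.home() / ".cache" / "remote_datasets"
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir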
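With a helper like that in place, callers of `remote_file` no longer need to know where the package is installed on disk. A hedged usage sketch (the filename and cache location are made-up examples, and the import path depends on the package layout):

from mypackage.remote import remote_file  # hypothetical import path

# Download on demand into the resolved default cache directory.
local = remote_file("some_dataset/table.yaml", raise_missing=True)

# Or pin the cache explicitly, e.g. for tests or CI.
local = remote_file("some_dataset/table.yaml", cache_dir="/tmp/dataset-cache")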