Skip to content

Commit e6a9447

Browse files
authored
fix: remove archive.download dest before downloading if it's a symlink (#191)
1 parent 5e6220a commit e6a9447

File tree

2 files changed

+112
-46
lines changed

2 files changed

+112
-46
lines changed

devenv/lib/archive.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,51 +38,60 @@ def download(
3838
dest = f"{cache_root}/{sha256}"
3939
os.makedirs(cache_root, exist_ok=True)
4040

41-
if not os.path.exists(dest):
42-
headers = {}
43-
if url.startswith("https://ghcr.io/v2/homebrew"):
44-
# downloading homebrew blobs requires auth
45-
# you can get an anonymous token from https://ghcr.io/token?service=ghcr.io&scope=repository%3Ahomebrew/core/go%3Apull
46-
# but there's also a special shortcut token QQ==
47-
# https://github.com/Homebrew/brew/blob/2184406bd8444e4de2626f5b0c749d4d08cb1aed/Library/Homebrew/brew.sh#L993
48-
headers["Authorization"] = "bearer QQ=="
49-
50-
req = urllib.request.Request(url, headers=headers)
51-
52-
retry_sleep = 1.0
53-
while retries >= 0:
54-
try:
55-
resp = urllib.request.urlopen(req)
56-
break
57-
except HTTPError as e:
58-
if retries == 0:
59-
raise RuntimeError(f"Error getting {url}: {e}")
60-
print(f"Error getting {url} ({retries} retries left): {e}")
61-
62-
time.sleep(retry_sleep)
63-
retries -= 1
64-
retry_sleep *= retry_exp
65-
66-
dest_dir = os.path.dirname(dest)
67-
os.makedirs(dest_dir, exist_ok=True)
68-
69-
with tempfile.NamedTemporaryFile(delete=False, dir=dest_dir) as tmpf:
70-
shutil.copyfileobj(resp, tmpf)
71-
tmpf.seek(0)
72-
checksum = hashlib.sha256()
41+
if os.path.islink(dest):
42+
# there are cases where dest can be an existing symlink
43+
# (docker desktop starts and puts symlinks into ~/.docker/cli-plugins)
44+
# such symlinks should be removed otherwise callers to download
45+
# usually try to chmod after and end up with PermissionError
46+
os.remove(dest)
47+
48+
if os.path.exists(dest):
49+
return dest
50+
51+
headers = {}
52+
if url.startswith("https://ghcr.io/v2/homebrew"):
53+
# downloading homebrew blobs requires auth
54+
# you can get an anonymous token from https://ghcr.io/token?service=ghcr.io&scope=repository%3Ahomebrew/core/go%3Apull
55+
# but there's also a special shortcut token QQ==
56+
# https://github.com/Homebrew/brew/blob/2184406bd8444e4de2626f5b0c749d4d08cb1aed/Library/Homebrew/brew.sh#L993
57+
headers["Authorization"] = "bearer QQ=="
58+
59+
req = urllib.request.Request(url, headers=headers)
60+
61+
retry_sleep = 1.0
62+
while retries >= 0:
63+
try:
64+
resp = urllib.request.urlopen(req)
65+
break
66+
except HTTPError as e:
67+
if retries == 0:
68+
raise RuntimeError(f"Error getting {url}: {e}")
69+
print(f"Error getting {url} ({retries} retries left): {e}")
70+
71+
time.sleep(retry_sleep)
72+
retries -= 1
73+
retry_sleep *= retry_exp
74+
75+
dest_dir = os.path.dirname(dest)
76+
os.makedirs(dest_dir, exist_ok=True)
77+
78+
with tempfile.NamedTemporaryFile(delete=False, dir=dest_dir) as tmpf:
79+
shutil.copyfileobj(resp, tmpf)
80+
tmpf.seek(0)
81+
checksum = hashlib.sha256()
82+
buf = tmpf.read(4096)
83+
while buf:
84+
checksum.update(buf)
7385
buf = tmpf.read(4096)
74-
while buf:
75-
checksum.update(buf)
76-
buf = tmpf.read(4096)
77-
78-
if not secrets.compare_digest(checksum.hexdigest(), sha256):
79-
raise RuntimeError(
80-
f"checksum mismatch for {url}:\n"
81-
f"- got: {checksum.hexdigest()}\n"
82-
f"- expected: {sha256}\n"
83-
)
84-
85-
atomic_replace(tmpf.name, dest)
86+
87+
if not secrets.compare_digest(checksum.hexdigest(), sha256):
88+
raise RuntimeError(
89+
f"checksum mismatch for {url}:\n"
90+
f"- got: {checksum.hexdigest()}\n"
91+
f"- expected: {sha256}\n"
92+
)
93+
94+
atomic_replace(tmpf.name, dest)
8695

8796
return dest
8897

tests/lib/test_archive.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import io
4+
import os
45
import pathlib
56
import tarfile
67
import time
@@ -169,6 +170,12 @@ def test_download(tmp_path: pathlib.Path, mock_sleep: mock.MagicMock) -> None:
169170
)
170171

171172
dest = f"{tmp_path}/a"
173+
174+
# if dest is already a valid symlink it should be paved over
175+
with open(f"{tmp_path}/hi", "wb"):
176+
pass
177+
os.symlink(f"{tmp_path}/hi", dest)
178+
172179
with mock.patch.object(
173180
urllib.request,
174181
"urlopen",
@@ -186,7 +193,22 @@ def test_download(tmp_path: pathlib.Path, mock_sleep: mock.MagicMock) -> None:
186193
with open(dest, "rb") as f:
187194
assert f.read() == data
188195

189-
dest = f"{tmp_path}/b"
196+
197+
def test_download_exceeded_retries(
198+
tmp_path: pathlib.Path, mock_sleep: mock.MagicMock
199+
) -> None:
200+
data_sha256 = (
201+
"b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
202+
)
203+
204+
err = urllib.error.HTTPError(
205+
"https://example.com/foo",
206+
503,
207+
"Service Unavailable",
208+
"", # type: ignore
209+
io.BytesIO(b""),
210+
)
211+
dest = f"{tmp_path}/a"
190212
with pytest.raises(RuntimeError) as excinfo:
191213
with mock.patch.object(
192214
urllib.request,
@@ -202,7 +224,42 @@ def test_download(tmp_path: pathlib.Path, mock_sleep: mock.MagicMock) -> None:
202224
== "Error getting https://example.com/foo: HTTP Error 503: Service Unavailable"
203225
)
204226

205-
dest = f"{tmp_path}/b"
227+
228+
def test_download_dest_is_broken_symlink(
229+
tmp_path: pathlib.Path, mock_sleep: mock.MagicMock
230+
) -> None:
231+
data = b"foo\n"
232+
data_sha256 = (
233+
"b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
234+
)
235+
236+
dest = f"{tmp_path}/a"
237+
238+
# if dest is already a dead symlink it should be paved over as well
239+
os.symlink(f"{tmp_path}/does-not-exist", dest)
240+
241+
with mock.patch.object(
242+
urllib.request,
243+
"urlopen",
244+
autospec=True,
245+
side_effect=(io.BytesIO(data),),
246+
):
247+
archive.download("https://example.com/foo", data_sha256, dest)
248+
249+
with open(dest, "rb") as f:
250+
assert f.read() == data
251+
252+
253+
def test_download_wrong_sha(
254+
tmp_path: pathlib.Path, mock_sleep: mock.MagicMock
255+
) -> None:
256+
data = b"foo\n"
257+
data_sha256 = (
258+
"b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
259+
)
260+
261+
dest = f"{tmp_path}/a"
262+
206263
with pytest.raises(RuntimeError) as excinfo:
207264
with mock.patch.object(
208265
urllib.request,

0 commit comments

Comments
 (0)