Skip to content

Commit f883663

Browse files
committed
WIP: shared memory without tmpfs
1 parent bb17ff3 commit f883663

File tree

6 files changed

+365
-0
lines changed

6 files changed

+365
-0
lines changed

doc/source/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ API
5959
:members:
6060
.. autoclass:: LRU
6161
:members:
62+
.. autoclass:: SharedMemory
63+
:members:
6264
.. autoclass:: Sieve
6365
:members:
6466
.. autoclass:: Zip

zict/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from zict.func import Func as Func
77
from zict.lmdb import LMDB as LMDB
88
from zict.lru import LRU as LRU
9+
from zict.shared_memory import SharedMemory as SharedMemory
910
from zict.sieve import Sieve as Sieve
1011
from zict.utils import InsertionSortedSet as InsertionSortedSet
1112
from zict.zip import Zip as Zip

zict/shared_memory/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from zict.shared_memory.shared_memory import SharedMemory

zict/shared_memory/_linux.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Linux implementation of :class:`zict.SharedMemory`.
2+
3+
Wraps around glibc ``memfd_create``.
4+
"""
5+
from __future__ import annotations
6+
7+
import ctypes
8+
import mmap
9+
import os
10+
from collections.abc import Iterable
11+
12+
_memfd_create = None
13+
14+
15+
def _setitem(safe_key: str, value: Iterable[bytes | bytearray | memoryview]) -> int:
16+
global _memfd_create
17+
if _memfd_create is None:
18+
libc = ctypes.CDLL("libc.so.6")
19+
_memfd_create = libc.memfd_create
20+
21+
fd = _memfd_create(safe_key.encode("ascii"), 0)
22+
if fd == -1:
23+
raise OSError("Call to memfd_create failed") # pragma: nocover
24+
25+
with os.fdopen(fd, "wb", closefd=False) as fh:
26+
fh.writelines(value)
27+
28+
return fd
29+
30+
31+
def _getitem(fd: int) -> memoryview:
32+
# This opens a second fd for as long as the memory map is referenced.
33+
# Sadly there does not seem a way to extract the fd from the mmap, so we have to
34+
# keep the original fd open for the purpose of exporting.
35+
return memoryview(mmap.mmap(fd, 0))
36+
37+
38+
def _delitem(fd: int) -> None:
39+
# Close the original fd. There may be other fd's still open if the shared memory is
40+
# referenced somewhere else.
41+
# This is also called by SharedMemory.__del__.
42+
os.close(fd)
43+
44+
45+
def _export(safe_key: str, fd: int) -> tuple:
46+
return safe_key, os.getpid(), fd
47+
48+
49+
def _import(safe_key: str, pid: int, fd: int) -> int:
50+
# if fd has been closed, raise FileNotFoundError
51+
# if fd has been closed and reopened to something else, this may also raise a
52+
# generic OSError, e.g. if this is now a socket
53+
new_fd = os.open(f"/proc/{pid}/fd/{fd}", os.O_RDWR)
54+
55+
expect = f"/memfd:{safe_key} (deleted)"
56+
actual = os.readlink(f"/proc/{os.getpid()}/fd/{new_fd}")
57+
if actual != expect:
58+
# fd has been closed and reopened to something else
59+
os.close(new_fd)
60+
raise OSError()
61+
62+
return new_fd

zict/shared_memory/_windows.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Windows implementation of :class:`zict.SharedMemory`.
2+
3+
Conveniently, :class:`multiprocessing.shared_memory.SharedMemory` already wraps around
4+
the Windows API we want to use, so this is implemented as a hack on top of it.
5+
"""
6+
from __future__ import annotations
7+
8+
import mmap
9+
import multiprocessing.shared_memory
10+
from collections.abc import Collection
11+
from typing import cast
12+
13+
14+
class _PySharedMemoryNoClose(multiprocessing.shared_memory.SharedMemory):
15+
def __del__(self) -> None:
16+
pass
17+
18+
19+
def _setitem(
20+
safe_key: str, value: Collection[bytes | bytearray | memoryview]
21+
) -> memoryview:
22+
nbytes = sum(v.nbytes if isinstance(v, memoryview) else len(v) for v in value)
23+
shm = _PySharedMemoryNoClose(safe_key, create=True, size=nbytes)
24+
mm = cast(mmap.mmap, shm.buf.obj)
25+
for v in value:
26+
mm.write(v)
27+
# This dereferences shm; if we hadn't overridden the __del__ method, it would cause
28+
# it to automatically close the memory map and deallocate the shared memory.
29+
return shm.buf
30+
31+
32+
def _getitem(mm: memoryview) -> memoryview:
33+
# Nothing to do. This is just for compatibility with the Linux implementation, which
34+
# instead creates a memory map on the fly.
35+
return mm
36+
37+
38+
def _delitem(mm: memoryview) -> None:
39+
# Nothing to do. The shared memory is released as soon as the last memory map
40+
# referencing it is destroyed.
41+
pass
42+
43+
44+
def _export(safe_key: str, mm: memoryview) -> tuple:
45+
return (safe_key,)
46+
47+
48+
def _import(safe_key: str) -> memoryview:
49+
# Raises OSError in case of invalid key
50+
shm = _PySharedMemoryNoClose(safe_key)
51+
return shm.buf
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from __future__ import annotations
2+
3+
import secrets
4+
import sys
5+
from collections.abc import Iterator, KeysView
6+
from typing import Any
7+
from urllib.parse import quote, unquote
8+
9+
from zict.common import ZictBase
10+
11+
if sys.platform == "linux":
12+
from zict.shared_memory._linux import _delitem, _export, _getitem, _import, _setitem
13+
elif sys.platform == "win32":
14+
from zict.shared_memory._windows import (
15+
_delitem,
16+
_export,
17+
_getitem,
18+
_import,
19+
_setitem,
20+
)
21+
22+
23+
class SharedMemory(ZictBase[str, memoryview]):
24+
"""Mutable Mapping interface to shared memory.
25+
26+
**Supported OSs:** Linux, Windows
27+
28+
Keys must be strings, values must be buffers.
29+
Keys are stored in private memory, and other SharedMemory objects by default won't
30+
see them - even in case of key collision, the two pieces of data remain separate.
31+
32+
In order to share the same buffer, one SharedMemory object must call
33+
:meth:`export` and the other :meth:`import_`.
34+
35+
**Resources usage**
36+
37+
On Linux, you will hold 1 file descriptor open for every key in the SharedMemory
38+
mapping, plus 1 file descriptor for every returned memoryview that is referenced
39+
somewhere else. Please ensure that your ``ulimit`` is high enough to cope with this.
40+
41+
If you expect to call ``__getitem__`` multiple times on the same key while the
42+
return value from the previous call is still in use, you should wrap this mapping in
43+
a :class:`~zict.Cache`:
44+
45+
>>> import zict
46+
>>> shm = zict.Cache(
47+
... zict.SharedMemory(),
48+
... zict.WeakValueMapping(),
49+
... update_on_set=False,
50+
... ) # doctest: +SKIP
51+
52+
The above will cap the amount of open file descriptors per key to 2.
53+
54+
**Lifecycle**
55+
56+
Memory is released when all the SharedMemory objects that were sharing the key have
57+
deleted it *and* the buffer returned by ``__getitem__`` is no longer referenced
58+
anywhere else.
59+
Process termination, including ungraceful termination (SIGKILL, SIGSEGV), also
60+
releases the memory; in other words you don't risk leaking memory to the
61+
OS if all processes that were sharing it crash or are killed.
62+
63+
Examples
64+
--------
65+
In process 1:
66+
67+
>>> import pickle, numpy, zict # doctest: +SKIP
68+
>>> shm = zict.SharedMemory() # doctest: +SKIP
69+
>>> a = numpy.random.random(2**27) # 1 GiB # doctest: +SKIP
70+
>>> buffers = [] # doctest: +SKIP
71+
>>> pik = pickle.dumps(a, protocol=5, buffer_callback=buffers.append)
72+
... # doctest: +SKIP
73+
>>> # This deep-copies the buffer, resulting in 1 GiB private + 1 GiB shared memory.
74+
>>> shm["a"] = buffers # doctest: +SKIP
75+
>>> # Release private memory, leaving only the shared memory allocated
76+
>>> del a, buffers # doctest: +SKIP
77+
>>> # Recreate array from shared memory. This requires no extra memory.
78+
>>> a = pickle.loads(pik, buffers=[shm["a"]]) # doctest: +SKIP
79+
>>> # Send trivially-sized metadata (<1 kiB) to the peer process somehow.
80+
>>> send_to_process_2((pik, shm.export("a"))) # doctest: +SKIP
81+
82+
In process 2:
83+
84+
>>> import pickle, zict # doctest: +SKIP
85+
>>> shm = zict.SharedMemory() # doctest: +SKIP
86+
>>> pik, metadata = receive_from_process_1() # doctest: +SKIP
87+
>>> key = shm.import_(metadata) # returns "a" # doctest: +SKIP
88+
>>> a = pickle.loads(pik, buffers=[shm[key]]) # doctest: +SKIP
89+
90+
Now process 1 and 2 hold a reference to the same memory; in-place changes on one
91+
process are reflected onto the other. The shared memory is released after you delete
92+
the key and dereference the buffer returned by ``__getitem__`` on *both* processes:
93+
94+
>>> del shm["a"] # doctest: +SKIP
95+
>>> del a # doctest: +SKIP
96+
97+
or alternatively when both processes are terminated.
98+
99+
**Implementation notes**
100+
101+
This mapping uses OS-specific shared memory, which
102+
103+
1. can be shared among already existing processes, e.g. unlike ``mmap(fd=-1)``, and
104+
2. is automatically cleaned up by the OS in case of ungraceful process termination,
105+
e.g. unlike ``shm_open`` (which is used by :mod:`multiprocessing.shared_memory`
106+
on all POSIX OS'es)
107+
108+
It is implemented on top of ``memfd_create`` on Linux and ``CreateFileMapping`` on
109+
Windows. Notably, there is no POSIX equivalent for these API calls, as it only
110+
implements ``shm_open`` which would inevitably cause memory leaks in case of
111+
ungraceful process termination.
112+
"""
113+
114+
# {key: (unique safe key, implementation-specific data)}
115+
_data: dict[str, tuple[str, Any]]
116+
117+
def __init__(self): # type: ignore[no-untyped-def]
118+
if sys.platform not in ("linux", "win32"):
119+
raise NotImplementedError(
120+
"SharedMemory is only available on Linux and Windows"
121+
)
122+
123+
self._data = {}
124+
125+
def __str__(self) -> str:
126+
return f"<SharedMemory: {len(self)} elements>"
127+
128+
__repr__ = __str__
129+
130+
def __setitem__(
131+
self,
132+
key: str,
133+
value: bytes
134+
| bytearray
135+
| memoryview
136+
| list[bytes | bytearray | memoryview]
137+
| tuple[bytes | bytearray | memoryview, ...],
138+
) -> None:
139+
try:
140+
del self[key]
141+
except KeyError:
142+
pass
143+
144+
if not isinstance(value, (tuple, list)):
145+
value = [value]
146+
safe_key = quote(key, safe="") + "#" + secrets.token_bytes(8).hex()
147+
impl_data = _setitem(safe_key, value)
148+
self._data[key] = safe_key, impl_data
149+
150+
def __getitem__(self, key: str) -> memoryview:
151+
_, impl_data = self._data[key]
152+
return _getitem(impl_data)
153+
154+
def __delitem__(self, key: str) -> None:
155+
_, impl_data = self._data.pop(key)
156+
_delitem(impl_data)
157+
158+
def __del__(self) -> None:
159+
try:
160+
data_values = self._data.values()
161+
except Exception:
162+
# Interpreter shutdown
163+
return # pragma: nocover
164+
165+
for _, impl_data in data_values:
166+
try:
167+
_delitem(impl_data)
168+
except Exception:
169+
pass # pragma: nocover
170+
171+
def close(self) -> None:
172+
# Implements ZictBase.close(). Also triggered by __exit__.
173+
self.clear()
174+
175+
def __contains__(self, key: object) -> bool:
176+
return key in self._data
177+
178+
def keys(self) -> KeysView[str]:
179+
return self._data.keys()
180+
181+
def __iter__(self) -> Iterator[str]:
182+
return iter(self._data)
183+
184+
def __len__(self) -> int:
185+
return len(self._data)
186+
187+
def export(self, key: str) -> tuple:
188+
"""Export metadata for a key, which can be fed into :meth:`import_` on
189+
another process.
190+
191+
Returns
192+
-------
193+
Opaque metadata object (implementation-specific) to be passed to
194+
:meth:`import_`. It is serializable with JSON, YAML, and msgpack.
195+
196+
See Also
197+
--------
198+
import_
199+
"""
200+
return _export(*self._data[key])
201+
202+
def import_(self, metadata: tuple | list) -> str:
203+
"""Import a key from another process, starting to share the memory area.
204+
205+
You should treat parameters as implementation details and just unpack the tuple
206+
that was generated by :meth:`export`.
207+
208+
Returns
209+
-------
210+
Key that was just added to the mapping
211+
212+
Raises
213+
------
214+
FileNotFoundError
215+
Either the key or the whole SharedMemory object were deleted on the process
216+
where you ran :meth:`export`, or the process was terminated.
217+
218+
Notes
219+
-----
220+
On Windows, this method will raise FileNotFoundError if the key has been deleted
221+
from the other SharedMemory mapping *and* it is no longer referenced anywhere.
222+
On Linux, this method will raise as soon as the key is deleted from the other
223+
SharedMemory mapping, even if it's still referenced.
224+
225+
e.g. this code is not portable, as it will work on Windows but not on Linux:
226+
227+
>>> buf = shm["x"] = buf # doctest: +SKIP
228+
>>> meta = shm.export("x") # doctest: +SKIP
229+
>>> del shm["x"] # doctest: +SKIP
230+
231+
See Also
232+
--------
233+
export
234+
"""
235+
safe_key = metadata[0]
236+
key = unquote(safe_key.split("#")[0])
237+
238+
try:
239+
del self[key]
240+
except KeyError:
241+
pass
242+
243+
try:
244+
impl_data = _import(*metadata)
245+
except OSError:
246+
raise FileNotFoundError(f"Peer process no longer holds the key: {key!r}")
247+
self._data[key] = safe_key, impl_data
248+
return key

0 commit comments

Comments
 (0)