Skip to content

Commit 977089f

Browse files
committed
Enums
1 parent c6622e8 commit 977089f

File tree

3 files changed

+166
-61
lines changed

3 files changed

+166
-61
lines changed

dill/_dill.py

Lines changed: 150 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def _trace(boolean):
7373
GeneratorType, DictProxyType, XRangeType, SliceType, TracebackType, \
7474
NotImplementedType, EllipsisType, FrameType, ModuleType, \
7575
BufferType, BuiltinMethodType, TypeType
76-
from pickle import HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
76+
from pickle import HIGHEST_PROTOCOL, PickleError, PicklingError, \
77+
UnpicklingError
7778
try:
7879
from pickle import DEFAULT_PROTOCOL
7980
except ImportError:
@@ -263,6 +264,14 @@ def get_file_type(*args, **kwargs):
263264
except NameError: ExitType = None
264265
singletontypes = []
265266

267+
from collections import OrderedDict
268+
269+
try:
270+
from enum import Enum, EnumMeta
271+
except:
272+
Enum = None
273+
EnumMeta = None
274+
266275
import inspect
267276

268277
### Shims for different versions of Python and dill
@@ -1002,26 +1011,13 @@ def _get_attr(self, name):
10021011
return getattr(self, name, None) or getattr(__builtin__, name)
10031012

10041013
def _dict_from_dictproxy(dictproxy):
1014+
# Deprecated. Use _get_typedict_type instead.
10051015
_dict = dictproxy.copy() # convert dictproxy to dict
10061016
_dict.pop('__dict__', None)
10071017
_dict.pop('__weakref__', None)
10081018
_dict.pop('__prepare__', None)
10091019
return _dict
10101020

1011-
def _dict_from_dictproxy_abc(dictproxy):
1012-
_dict = dictproxy.copy() # convert dictproxy to dict
1013-
_dict.pop('__dict__', None)
1014-
_dict.pop('__weakref__', None)
1015-
_dict.pop('__prepare__', None)
1016-
if '_abc_registry' in _dict:
1017-
del _dict['_abc_registry']
1018-
del _dict['_abc_cache']
1019-
del _dict['_abc_negative_cache']
1020-
del _dict['_abc_negative_cache_version']
1021-
else:
1022-
del _dict['_abc_impl']
1023-
return _dict
1024-
10251021
def _import_module(import_name, safe=False):
10261022
try:
10271023
if '.' in import_name:
@@ -1036,13 +1032,40 @@ def _import_module(import_name, safe=False):
10361032
return None
10371033
raise
10381034

1035+
# https://github.com/python/cpython/blob/a8912a0f8d9eba6d502c37d522221f9933e976db/Lib/pickle.py#L322-L333
1036+
def _getattribute(obj, name):
1037+
for subpath in name.split('.'):
1038+
if subpath == '<locals>':
1039+
raise AttributeError("Can't get local attribute {!r} on {!r}"
1040+
.format(name, obj))
1041+
try:
1042+
parent = obj
1043+
obj = getattr(obj, subpath)
1044+
except AttributeError:
1045+
raise AttributeError("Can't get attribute {!r} on {!r}"
1046+
.format(name, obj))
1047+
return obj, parent
1048+
10391049
def _locate_function(obj, session=False):
1040-
if obj.__module__ in ['__main__', None]: # and session:
1050+
module_name = getattr(obj, '__module__', None)
1051+
if module_name in ['__main__', None]: # and session:
10411052
return False
1042-
found = _import_module(obj.__module__ + '.' + obj.__name__, safe=True)
1043-
return found is obj
1053+
if hasattr(obj, '__qualname__'):
1054+
module = _import_module(module_name, safe=True)
1055+
try:
1056+
found, _ = _getattribute(module, obj.__qualname__)
1057+
return found is obj
1058+
except:
1059+
return False
1060+
else:
1061+
found = _import_module(module_name + '.' + obj.__name__, safe=True)
1062+
return found is obj
10441063

10451064

1065+
def _setitems(dest, source):
1066+
for k, v in source.items():
1067+
dest[k] = v
1068+
10461069
def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None):
10471070
if obj is Getattr.NO_DEFAULT:
10481071
obj = Reduce(reduction) # pragma: no cover
@@ -1074,7 +1097,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO
10741097
postproc = pickler._postproc.pop(id(obj))
10751098
# assert postproc_list == postproc, 'Stack tampered!'
10761099
for reduction in reversed(postproc):
1077-
if reduction[0] is dict.update and type(reduction[1][0]) is dict:
1100+
if reduction[0] is _setitems:
10781101
# use the internal machinery of pickle.py to speedup when
10791102
# updating a dictionary in postproc
10801103
dest, source = reduction[1]
@@ -1612,32 +1635,77 @@ def save_module(pickler, obj):
16121635
return
16131636
return
16141637

1615-
@register(abc.ABCMeta)
1616-
def save_abc(pickler, obj):
1617-
"""Use StockePickler to ignore ABC internal state which should not be serialized"""
1618-
1619-
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
1620-
if not _locate_function(obj): # not a function, but the name was held over
1621-
log.info("ABC2: %s" % obj)
1622-
if hasattr(abc, '_get_dump'):
1623-
(registry, _, _, _) = abc._get_dump(obj)
1624-
register = obj.register
1625-
postproc_list = [(register, (reg(),)) for reg in registry]
1626-
elif hasattr(obj, '_abc_registry'):
1627-
registry = obj._abc_registry
1628-
register = obj.register
1629-
postproc_list = [(register, (reg,)) for reg in registry]
1630-
else:
1631-
postproc_list = None
1632-
save_type(pickler, obj, _dict_from_dictproxy_abc, postproc_list)
1633-
log.info("# ABC2")
1638+
# The following function is based on '_extract_class_dict' from 'cloudpickle'
1639+
# Copyright (c) 2012, Regents of the University of California.
1640+
# Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
1641+
# License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE
1642+
def _get_typedict_type(cls, clsdict, postproc_list):
1643+
"""Retrieve a copy of the dict of a class without the inherited methods"""
1644+
if len(cls.__bases__) == 1:
1645+
inherited_dict = cls.__bases__[0].__dict__
1646+
else:
1647+
inherited_dict = {}
1648+
for base in reversed(cls.__bases__):
1649+
inherited_dict.update(base.__dict__)
1650+
to_remove = []
1651+
for name, value in dict.items(clsdict):
1652+
try:
1653+
base_value = inherited_dict[name]
1654+
if value is base_value:
1655+
to_remove.append(name)
1656+
except KeyError:
1657+
pass
1658+
for name in to_remove:
1659+
dict.pop(clsdict, name)
1660+
1661+
if issubclass(type(cls), type):
1662+
clsdict.pop('__dict__', None)
1663+
clsdict.pop('__weakref__', None)
1664+
# clsdict.pop('__prepare__', None)
1665+
return clsdict
1666+
# return _dict_from_dictproxy(cls.__dict__)
1667+
1668+
def _get_typedict_abc(obj, _dict, state, postproc_list):
1669+
log.info("ABC: %s" % obj)
1670+
if hasattr(abc, '_get_dump'):
1671+
(registry, _, _, _) = abc._get_dump(obj)
1672+
register = obj.register
1673+
postproc_list.extend((register, (reg(),)) for reg in registry)
1674+
elif hasattr(obj, '_abc_registry'):
1675+
registry = obj._abc_registry
1676+
register = obj.register
1677+
postproc_list.extend((register, (reg,)) for reg in registry)
1678+
else:
1679+
raise PicklingError("Cannot find registry of ABC %s", obj)
1680+
1681+
if '_abc_registry' in _dict:
1682+
del _dict['_abc_registry']
1683+
del _dict['_abc_cache']
1684+
del _dict['_abc_negative_cache']
1685+
# del _dict['_abc_negative_cache_version']
16341686
else:
1635-
log.info("ABC1: %s" % obj)
1636-
save_type(pickler, obj)
1637-
log.info("# ABC1")
1687+
del _dict['_abc_impl']
1688+
log.info("# ABC")
1689+
return _dict, state
1690+
1691+
def _get_typedict_enum(obj, _dict, state, postproc_list):
1692+
log.info("E: %s" % obj)
1693+
metacls = type(obj)
1694+
original_dict = {}
1695+
for name, enum_value in obj.__members__.items():
1696+
original_dict[name] = enum_value.value
1697+
del _dict[name]
1698+
1699+
_dict.pop('_member_names_', None)
1700+
_dict.pop('_member_map_', None)
1701+
_dict.pop('_value2member_map_', None)
1702+
_dict.pop('_generate_next_value_', None)
1703+
1704+
log.info("# E")
1705+
return original_dict, (None, _dict)
16381706

16391707
@register(TypeType)
1640-
def save_type(pickler, obj, _dict_from_dictproxy_func=_dict_from_dictproxy, postproc_list=None):
1708+
def save_type(pickler, obj, postproc_list=None):
16411709
if obj in _typemap:
16421710
log.info("T1: %s" % obj)
16431711
pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj)
@@ -1673,23 +1741,46 @@ def save_type(pickler, obj, _dict_from_dictproxy_func=_dict_from_dictproxy, post
16731741
obj_recursive = id(obj) in getattr(pickler, '_postproc', ())
16741742
incorrectly_named = not _locate_function(obj)
16751743
if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over
1676-
if issubclass(type(obj), type):
1677-
# thanks to Tom Stepleton pointing out pickler._session unneeded
1678-
_t = 'T2'
1679-
log.info("%s: %s" % (_t, obj))
1680-
_dict = _dict_from_dictproxy_func(obj.__dict__)
1681-
else:
1682-
_t = 'T3'
1683-
log.info("%s: %s" % (_t, obj))
1684-
_dict = obj.__dict__
1744+
if postproc_list is None:
1745+
postproc_list = []
1746+
1747+
# thanks to Tom Stepleton pointing out pickler._session unneeded
1748+
_t = 'T3'
1749+
_dict = _get_typedict_type(obj, obj.__dict__.copy(), postproc_list) # copy dict proxy to a dict
1750+
state = None
1751+
1752+
for name in _dict.get("__slots__", []):
1753+
del _dict[name]
1754+
1755+
if isinstance(obj, abc.ABCMeta):
1756+
_dict, state = _get_typedict_abc(obj, _dict, state, postproc_list)
1757+
1758+
if EnumMeta and isinstance(obj, EnumMeta):
1759+
_dict, state = _get_typedict_enum(obj, _dict, state, postproc_list)
1760+
16851761
#print (_dict)
16861762
#print ("%s\n%s" % (type(obj), obj.__name__))
16871763
#print ("%s\n%s" % (obj.__bases__, obj.__dict__))
1688-
for name in _dict.get("__slots__", []):
1689-
del _dict[name]
1690-
_save_with_postproc(pickler, (_create_type, (
1691-
type(obj), obj_name, obj.__bases__, _dict
1692-
)), obj=obj, postproc_list=postproc_list)
1764+
1765+
if PY3 and type(obj) is not type or hasattr(obj, '__orig_bases__'):
1766+
from types import new_class
1767+
_metadict = {
1768+
'metaclass': type(obj)
1769+
}
1770+
1771+
if _dict:
1772+
_dict_update = PartialType(_setitems, source=_dict)
1773+
else:
1774+
_dict_update = None
1775+
1776+
bases = getattr(obj, '__orig_bases__', obj.__bases__)
1777+
_save_with_postproc(pickler, (new_class, (
1778+
obj_name, bases, _metadict, _dict_update
1779+
)), state, obj=obj, postproc_list=postproc_list)
1780+
else:
1781+
_save_with_postproc(pickler, (_create_type, (
1782+
type(obj), obj_name, obj.__bases__, _dict
1783+
)), state, obj=obj, postproc_list=postproc_list)
16931784
log.info("# %s" % _t)
16941785
else:
16951786
log.info("T4: %s" % obj)
@@ -1784,12 +1875,13 @@ def save_function(pickler, obj):
17841875
glob_ids = {id(g) for g in globs_copy.values()}
17851876
else:
17861877
glob_ids = {id(g) for g in globs_copy.itervalues()}
1878+
17871879
for stack_element in _postproc:
17881880
if stack_element in glob_ids:
1789-
_postproc[stack_element].append((dict.update, (globs, globs_copy)))
1881+
_postproc[stack_element].append((_setitems, (globs, globs_copy)))
17901882
break
17911883
else:
1792-
postproc_list.append((dict.update, (globs, globs_copy)))
1884+
postproc_list.append((_setitems, (globs, globs_copy)))
17931885

17941886
if PY3:
17951887
fkwdefaults = getattr(obj, '__kwdefaults__', None)

tests/test_abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def test_abc_non_local():
7979
# Set a property that StockPickle can't preserve
8080
instance.bar = lambda x: x**2
8181
depickled = dill.copy(instance)
82-
assert type(depickled) == type(instance)
83-
assert type(depickled.bar) == FunctionType
82+
assert type(depickled) is not type(instance)
83+
assert type(depickled.bar) is FunctionType
8484
assert depickled.bar(3) == 9
8585
assert depickled.sfoo() == "Static Method SFOO"
8686
assert depickled.cfoo() == "Class Method CFOO"

tests/test_classdef.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def ok(self):
5454
nc = _newclass2()
5555
m = _mclass()
5656

57+
if sys.hexversion >= 0x03070000:
58+
import typing
59+
class customIntList(typing.List[int]):
60+
pass
61+
elif sys.hexversion >= 0x03090000:
62+
class customIntList(list[int]):
63+
pass
64+
5765
# test pickles for class instances
5866
def test_class_instances():
5967
assert dill.pickles(o)
@@ -127,7 +135,7 @@ def test_dtype():
127135
def test_array_nested():
128136
try:
129137
import numpy as np
130-
138+
131139
x = np.array([1])
132140
y = (x,)
133141
dill.dumps(x)
@@ -202,6 +210,10 @@ def test_slots():
202210
assert dill.pickles(Y.y)
203211
assert dill.copy(y).y == value
204212

213+
def test_origbases():
214+
if sys.hexversion >= 0x03070000:
215+
assert dill.copy(customIntList).__orig_bases__ == customIntList.__orig_bases__
216+
205217

206218
if __name__ == '__main__':
207219
test_class_instances()
@@ -213,3 +225,4 @@ def test_slots():
213225
test_array_subclass()
214226
test_method_decorator()
215227
test_slots()
228+
test_origbases()

0 commit comments

Comments
 (0)