Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/config/custom_containers.py: 68%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
3import copy
5import sacred.optional as opt
6from sacred.utils import join_paths, SacredError
def fallback_dict(fallback, **kwargs):
    """Return a copy of `fallback` with `kwargs` layered on top.

    The `fallback` mapping itself is left untouched; the result comes from
    its own ``copy()`` method and is then updated with the keyword arguments.
    """
    merged = fallback.copy()
    merged.update(kwargs)
    return merged
class DogmaticDict(dict):
    """A dict in which the entries listed in `fixed` cannot be overwritten.

    Writes to a fixed key store the fixed value instead of the given one and
    are recorded in `modified` / `typechanges`.  Keys that are not stored in
    the dict itself are looked up in `fallback`; writes to keys that exist in
    the fallback are recorded in `fallback_writes`.
    """

    def __init__(self, fixed=None, fallback=None):
        super().__init__()
        # key -> (type of attempted value, type of fixed value)
        self.typechanges = {}
        # keys present in the fallback that have been written to
        self.fallback_writes = []
        # fixed keys for which a different value was attempted
        self.modified = set()
        self.fixed = fixed or {}
        self._fallback = {}
        if fallback:
            # Go through the property setter so nested fixed dicts get
            # their matching nested fallback.
            self.fallback = fallback

    @property
    def fallback(self):
        """Mapping consulted for keys that are not stored in this dict."""
        return self._fallback

    @fallback.setter
    def fallback(self, newval):
        # For keys that are both fixed and in the new fallback, push the
        # nested fallback down into the (possibly wrapped) fixed dict.
        ffkeys = set(self.fixed.keys()).intersection(set(newval.keys()))
        for k in ffkeys:
            if isinstance(self.fixed[k], DogmaticDict):
                self.fixed[k].fallback = newval[k]
            elif isinstance(self.fixed[k], dict):
                self.fixed[k] = DogmaticDict(self.fixed[k])
                self.fixed[k].fallback = newval[k]

        self._fallback = newval

    def _log_blocked_setitem(self, key, value, fixed_value):
        """Record that writing `value` to the fixed `key` was blocked."""
        if type_changed(value, fixed_value):
            self.typechanges[key] = (type(value), type(fixed_value))

        if is_different(value, fixed_value):
            self.modified.add(key)

        # if both are dicts recursively collect modified and typechanges
        if isinstance(fixed_value, DogmaticDict) and isinstance(value, dict):
            for k, val in fixed_value.typechanges.items():
                self.typechanges[join_paths(key, k)] = val

            self.modified |= {join_paths(key, m) for m in fixed_value.modified}

    def __setitem__(self, key, value):
        if key not in self.fixed:
            if key in self.fallback:
                self.fallback_writes.append(key)
            return dict.__setitem__(self, key, value)

        # Fixed key: always store the fixed value, never the given one.
        fixed_value = self.fixed[key]
        dict.__setitem__(self, key, fixed_value)
        # if both are dicts do a recursive update
        if isinstance(fixed_value, DogmaticDict) and isinstance(value, dict):
            for k, val in value.items():
                fixed_value[k] = val

        self._log_blocked_setitem(key, value, fixed_value)

    def __getitem__(self, item):
        if dict.__contains__(self, item):
            return dict.__getitem__(self, item)
        elif item in self.fallback:
            # A fixed value takes precedence over the fallback value.
            if item in self.fixed:
                return self.fixed[item]
            else:
                return self.fallback[item]
        raise KeyError(item)

    def __contains__(self, item):
        return dict.__contains__(self, item) or (item in self.fallback)

    def get(self, k, d=None):
        """Return ``self[k]`` if the lookup succeeds, otherwise `d`.

        Delegates to ``__getitem__`` so that both access paths agree; in
        particular a key that is fixed and also present in the fallback
        yields the fixed value rather than the fallback one.
        """
        try:
            return self[k]
        except KeyError:
            return d

    def has_key(self, item):
        return self.__contains__(item)

    def __delitem__(self, key):
        # Deleting a fixed key is silently ignored.
        if key not in self.fixed:
            dict.__delitem__(self, key)

    def update(self, iterable=None, **kwargs):
        """Update via ``self[key] = value`` so fixed keys stay protected."""
        if iterable is not None:
            if hasattr(iterable, "keys"):
                for key in iterable:
                    self[key] = iterable[key]
            else:
                for (key, value) in iterable:
                    self[key] = value
        for key in kwargs:
            self[key] = kwargs[key]

    def revelation(self):
        """Store every fixed entry that is not stored yet.

        Returns the set of keys that had to be added, including dotted paths
        for keys added inside nested dogmatic containers.
        """
        missing = set()
        for key in self.fixed:
            if not dict.__contains__(self, key):
                self[key] = self.fixed[key]
                missing.add(key)

            if isinstance(self[key], (DogmaticDict, DogmaticList)):
                missing |= {key + "." + k for k in self[key].revelation()}
        return missing
class DogmaticList(list):
    """A list whose content cannot be changed after construction.

    All mutating operations are silently ignored, except `pop`, which raises
    a `TypeError`.
    """

    def append(self, p_object):
        pass

    def extend(self, iterable):
        pass

    def insert(self, index, p_object):
        pass

    def reverse(self):
        pass

    def sort(self, compare=None, key=None, reverse=False):
        pass

    def clear(self):
        # Without this override the inherited list.clear() would empty the
        # list, unlike every other mutator of this class.
        pass

    def __iadd__(self, other):
        return self

    def __imul__(self, other):
        return self

    def __setitem__(self, key, value):
        pass

    def __setslice__(self, i, j, sequence):
        pass

    def __delitem__(self, key):
        pass

    def __delslice__(self, i, j):
        pass

    def pop(self, index=None):
        raise TypeError("Cannot pop from DogmaticList")

    def remove(self, value):
        pass

    def revelation(self):
        """Call `revelation` on nested dogmatic containers.

        Always returns an empty set.
        """
        for obj in self:
            if isinstance(obj, (DogmaticDict, DogmaticList)):
                obj.revelation()
        return set()
class ReadOnlyContainer:
    """Mixin with the shared behaviour of the read-only containers."""

    def __reduce__(self):
        # Rebuild from the class and the result of `__copy__`, which the
        # concrete subclass has to provide.
        plain_copy = self.__copy__()
        return self.__class__, (plain_copy,)

    def _readonly(self, *args, **kwargs):
        """Reject any call by raising a `SacredError`."""
        message = "The configuration is read-only in a captured function!"
        raise SacredError(message, filter_traceback="always")
class ReadOnlyDict(ReadOnlyContainer, dict):
    """A read-only variant of a `dict`.

    Every mutating method raises via `ReadOnlyContainer._readonly`.  Copies
    (shallow and deep) are plain, writable `dict` objects.
    """

    # Overwrite all methods that can modify a dict
    clear = ReadOnlyContainer._readonly
    pop = ReadOnlyContainer._readonly
    popitem = ReadOnlyContainer._readonly
    setdefault = ReadOnlyContainer._readonly
    update = ReadOnlyContainer._readonly
    __setitem__ = ReadOnlyContainer._readonly
    __delitem__ = ReadOnlyContainer._readonly
    # `d |= other` (Python 3.9+) updates the dict in place through
    # dict.__ior__ without calling the `update` blocked above, so it has to
    # be blocked separately.
    __ior__ = ReadOnlyContainer._readonly

    def __copy__(self):
        return {**self}

    def __deepcopy__(self, memo):
        # Deep-copy a plain dict so the result is not a ReadOnlyDict.
        d = dict(self)
        return copy.deepcopy(d, memo=memo)
class ReadOnlyList(ReadOnlyContainer, list):
    """A read-only variant of a `list`.

    Every mutating method raises via `ReadOnlyContainer._readonly`.  Copies
    (shallow and deep) are plain, writable `list` objects.
    """

    append = ReadOnlyContainer._readonly
    clear = ReadOnlyContainer._readonly
    extend = ReadOnlyContainer._readonly
    insert = ReadOnlyContainer._readonly
    pop = ReadOnlyContainer._readonly
    remove = ReadOnlyContainer._readonly
    reverse = ReadOnlyContainer._readonly
    sort = ReadOnlyContainer._readonly
    __setitem__ = ReadOnlyContainer._readonly
    __delitem__ = ReadOnlyContainer._readonly
    # `lst += other` and `lst *= n` modify the list in place through
    # list.__iadd__ / list.__imul__ without calling `extend`, so they have
    # to be blocked separately.
    __iadd__ = ReadOnlyContainer._readonly
    __imul__ = ReadOnlyContainer._readonly

    def __copy__(self):
        return [*self]

    def __deepcopy__(self, memo):
        # Deep-copy a plain list so the result is not a ReadOnlyList.
        lst = list(self)
        return copy.deepcopy(lst, memo=memo)
def make_read_only(o):
    """Return a read-only version of a nested structure.

    Every plain `dict` becomes a `ReadOnlyDict` and every plain `list` a
    `ReadOnlyList`; plain `tuple`s are rebuilt with converted elements.  The
    input `o` is not modified.  Only exact types are converted: instances of
    subclasses and all other objects are returned as they are.
    """
    kind = type(o)
    if kind is dict:
        return ReadOnlyDict({key: make_read_only(val) for key, val in o.items()})
    if kind is list:
        return ReadOnlyList([make_read_only(item) for item in o])
    if kind is tuple:
        return tuple(make_read_only(item) for item in o)
    return o
if opt.has_yaml:
    # Teach yaml how to dump the read-only containers.

    def read_only_dict_representer(dumper, data):
        """Represent a `ReadOnlyDict` as a plain mapping."""
        return dumper.represent_dict(data)

    def read_only_list_representer(dumper, data):
        """Represent a `ReadOnlyList` as a plain sequence."""
        return dumper.represent_list(data)

    # Register each representer with the default dumper and the safe dumper.
    opt.yaml.add_representer(ReadOnlyDict, read_only_dict_representer)
    opt.yaml.SafeDumper.add_representer(ReadOnlyDict, read_only_dict_representer)
    opt.yaml.add_representer(ReadOnlyList, read_only_list_representer)
    opt.yaml.SafeDumper.add_representer(ReadOnlyList, read_only_list_representer)
# Maps a concrete type to the type it is treated as when checking for type
# changes (see `type_changed`): tuples count as lists and the dogmatic
# containers count as their builtin counterparts.
SIMPLIFY_TYPE = {
    type(None): type(None),
    bool: bool,
    float: float,
    int: int,
    str: str,
    list: list,
    tuple: list,
    dict: dict,
    DogmaticDict: dict,
    DogmaticList: list,
}

# if numpy is available we also want to ignore typechanges from numpy
# datatypes to the corresponding python datatype
if opt.has_numpy:
    from sacred.optional import np

    # "float" and "int" are deliberately not listed: `np.float` / `np.int`
    # were aliases of the builtins (already keyed above), and merely probing
    # them emits a DeprecationWarning on NumPy 1.20-1.23.  They were removed
    # in NumPy 1.24.
    NP_FLOATS = ["float16", "float32", "float64", "float128"]
    for npf in NP_FLOATS:
        # not every name exists on every platform / NumPy build
        if hasattr(np, npf):
            SIMPLIFY_TYPE[getattr(np, npf)] = float

    NP_INTS = [
        "int8",
        "int16",
        "int32",
        "int64",
        "uint",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
    ]
    for npi in NP_INTS:
        if hasattr(np, npi):
            SIMPLIFY_TYPE[getattr(np, npi)] = int

    SIMPLIFY_TYPE[np.bool_] = bool
def type_changed(old_value, new_value):
    """Tell whether the simplified types of the two values differ.

    Types are compared after mapping them through `SIMPLIFY_TYPE`; types
    without an entry are compared as they are.  A change away from `None`
    is never reported.
    """
    if old_value is None:
        # ignore typechanges from None
        return False
    old_type = type(old_value)
    new_type = type(new_value)
    simplified_old = SIMPLIFY_TYPE.get(old_type, old_type)
    simplified_new = SIMPLIFY_TYPE.get(new_type, new_type)
    return simplified_old != simplified_new
def is_different(old_value, new_value):
    """Compare two values for inequality, using numpy when it is available.

    With numpy the result is the negation of ``np.array_equal``; without it
    the result is that of the plain ``!=`` operator.
    """
    if not opt.has_numpy:
        return old_value != new_value
    return not opt.np.array_equal(old_value, new_value)