Coverage for sacred/sacred/config/custom_containers.py: 43%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

190 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3import copy 

4 

5import sacred.optional as opt 

6from sacred.utils import join_paths, SacredError 

7 

8 

def fallback_dict(fallback, **kwargs):
    """Return a copy of ``fallback`` with the keyword arguments laid over it.

    ``fallback`` itself is left untouched; the result is produced by
    ``fallback.copy()`` and then has every keyword argument assigned into it
    (later values win over the copied ones).
    """
    merged = fallback.copy()
    for name, value in kwargs.items():
        merged[name] = value
    return merged

13 

14 

class DogmaticDict(dict):
    """A ``dict`` that refuses to overwrite its ``fixed`` entries.

    Assigning to a key present in ``fixed`` stores the fixed value instead of
    the assigned one. When both are dicts and the fixed one is a
    ``DogmaticDict``, the assigned items are pushed into it recursively.
    Every blocked assignment is recorded in ``typechanges`` / ``modified``.
    Keys that were never assigned are looked up in ``fallback``.

    Attributes:
        typechanges: dotted key -> ``(type of assigned value, type of fixed
            value)`` for blocked assignments whose simplified types differ.
        fallback_writes: keys that were assigned although they also exist in
            ``fallback``, in assignment order.
        modified: dotted keys whose blocked assignment differed from the
            fixed value.
        fixed: mapping of the values that cannot be overwritten.
    """

    def __init__(self, fixed=None, fallback=None):
        super().__init__()
        self.typechanges = {}
        self.fallback_writes = []
        self.modified = set()
        self.fixed = fixed or {}
        self._fallback = {}
        if fallback:
            self.fallback = fallback

    @property
    def fallback(self):
        """Mapping consulted for keys that have not been assigned."""
        return self._fallback

    @fallback.setter
    def fallback(self, newval):
        # Nested fixed dicts receive the matching nested fallback so that
        # lookups below them fall back correctly as well.
        shared_keys = set(self.fixed.keys()).intersection(set(newval.keys()))
        for key in shared_keys:
            fallback_value = newval[key]
            if not hasattr(fallback_value, "keys"):
                # A non-mapping fallback cannot act as a nested fallback: the
                # nested setter would call ``fallback_value.keys()`` and fail.
                continue
            fixed_value = self.fixed[key]
            if isinstance(fixed_value, DogmaticDict):
                fixed_value.fallback = fallback_value
            elif isinstance(fixed_value, dict):
                # Upgrade the plain fixed dict so it can carry a fallback.
                nested = DogmaticDict(fixed_value)
                self.fixed[key] = nested
                nested.fallback = fallback_value

        self._fallback = newval

    def _log_blocked_setitem(self, key, value, fixed_value):
        """Record that assigning ``value`` to fixed ``key`` was blocked."""
        if type_changed(value, fixed_value):
            self.typechanges[key] = (type(value), type(fixed_value))

        if is_different(value, fixed_value):
            self.modified.add(key)

        # if both are dicts recursively collect modified and typechanges
        if isinstance(fixed_value, DogmaticDict) and isinstance(value, dict):
            for k, val in fixed_value.typechanges.items():
                self.typechanges[join_paths(key, k)] = val

            self.modified |= {join_paths(key, m) for m in fixed_value.modified}

    def __setitem__(self, key, value):
        """Store ``value`` unless ``key`` is fixed; then keep the fixed value."""
        if key not in self.fixed:
            if key in self.fallback:
                self.fallback_writes.append(key)
            return dict.__setitem__(self, key, value)

        fixed_value = self.fixed[key]
        dict.__setitem__(self, key, fixed_value)
        # if both are dicts do a recursive update
        if isinstance(fixed_value, DogmaticDict) and isinstance(value, dict):
            for k, val in value.items():
                fixed_value[k] = val

        self._log_blocked_setitem(key, value, fixed_value)

    def __getitem__(self, item):
        """Return the assigned value, else the fixed or fallback value.

        An unassigned fixed key is only visible here if it also appears in
        ``fallback``; otherwise ``KeyError`` is raised.
        """
        if dict.__contains__(self, item):
            return dict.__getitem__(self, item)
        elif item in self.fallback:
            if item in self.fixed:
                return self.fixed[item]
            else:
                return self.fallback[item]
        raise KeyError(item)

    def __contains__(self, item):
        return dict.__contains__(self, item) or (item in self.fallback)

    def get(self, k, d=None):
        """Return the assigned value for ``k``, else ``fallback.get(k, d)``.

        Unlike ``__getitem__``, this does not substitute a fixed value for an
        unassigned key; it reads the fallback directly.
        """
        if dict.__contains__(self, k):
            return dict.__getitem__(self, k)
        else:
            return self.fallback.get(k, d)

    def has_key(self, item):
        """Python 2 style alias of the ``in`` operator."""
        return self.__contains__(item)

    def __delitem__(self, key):
        # Deleting a fixed key is silently ignored.
        if key not in self.fixed:
            dict.__delitem__(self, key)

    def update(self, iterable=None, **kwargs):
        """Like ``dict.update`` but routes every write through ``__setitem__``."""
        if iterable is not None:
            if hasattr(iterable, "keys"):
                for key in iterable:
                    self[key] = iterable[key]
            else:
                for (key, value) in iterable:
                    self[key] = value
        for key in kwargs:
            self[key] = kwargs[key]

    def revelation(self):
        """Assign every fixed key that has not been assigned yet.

        Recurses into nested ``DogmaticDict`` / ``DogmaticList`` values.

        Returns:
            The set of (dotted) keys that had to be added.
        """
        missing = set()
        for key in self.fixed:
            if not dict.__contains__(self, key):
                self[key] = self.fixed[key]
                missing.add(key)

            value = self[key]
            if isinstance(value, (DogmaticDict, DogmaticList)):
                # join_paths keeps dotted paths consistent with the ones
                # recorded by _log_blocked_setitem.
                missing |= {join_paths(key, k) for k in value.revelation()}
        return missing

118 

119 

class DogmaticList(list):
    """A ``list`` whose in-place modification methods do nothing.

    Every overridden mutator is a no-op. ``pop`` raises ``TypeError``
    instead, since it would have to return an element. The in-place
    operators ``+=`` and ``*=`` return the list unchanged.
    """

    def append(self, p_object):
        pass

    def extend(self, iterable):
        pass

    def insert(self, index, p_object):
        pass

    def reverse(self):
        pass

    def sort(self, compare=None, key=None, reverse=False):
        pass

    def clear(self):
        # Without this override the inherited list.clear would empty the list.
        pass

    def __iadd__(self, other):
        return self

    def __imul__(self, other):
        return self

    def __setitem__(self, key, value):
        pass

    # Python 2 slice hooks; on Python 3 slice writes/deletes go through
    # __setitem__/__delitem__, which are no-ops as well.
    def __setslice__(self, i, j, sequence):
        pass

    def __delitem__(self, key):
        pass

    def __delslice__(self, i, j):
        pass

    def pop(self, index=None):
        raise TypeError("Cannot pop from DogmaticList")

    def remove(self, value):
        pass

    def revelation(self):
        """Reveal nested dogmatic containers; a list itself reports nothing.

        Returns:
            Always an empty set.
        """
        for obj in self:
            if isinstance(obj, (DogmaticDict, DogmaticList)):
                obj.revelation()
        return set()

165 

166 

class ReadOnlyContainer:
    """Mixin shared by the read-only ``dict`` and ``list`` variants.

    Subclasses must provide ``__copy__`` returning a plain container with the
    same contents; it is used to pickle the object.
    """

    def __reduce__(self):
        # Pickle as "rebuild this class from a plain copy of the contents".
        cls = self.__class__
        contents = self.__copy__()
        return cls, (contents,)

    def _readonly(self, *args, **kwargs):
        """Reject any modification attempt by raising ``SacredError``."""
        message = "The configuration is read-only in a captured function!"
        raise SacredError(message, filter_traceback="always")

176 

177 

class ReadOnlyDict(ReadOnlyContainer, dict):
    """A read-only variant of a `dict`."""

    # Overwrite all methods that can modify a dict
    clear = ReadOnlyContainer._readonly
    pop = ReadOnlyContainer._readonly
    popitem = ReadOnlyContainer._readonly
    setdefault = ReadOnlyContainer._readonly
    update = ReadOnlyContainer._readonly
    __setitem__ = ReadOnlyContainer._readonly
    __delitem__ = ReadOnlyContainer._readonly
    # ``d |= other`` (Python 3.9+) dispatches to dict.__ior__, which updates
    # the dict in C without calling the overridden ``update``.
    __ior__ = ReadOnlyContainer._readonly

    def __copy__(self):
        """Return a shallow, mutable plain-``dict`` copy."""
        return {**self}

    def __deepcopy__(self, memo):
        """Return a deep, mutable plain-``dict`` copy."""
        d = dict(self)
        return copy.deepcopy(d, memo=memo)

196 

197 

class ReadOnlyList(ReadOnlyContainer, list):
    """A read-only variant of a `list`."""

    # Overwrite all methods that can modify a list
    append = ReadOnlyContainer._readonly
    clear = ReadOnlyContainer._readonly
    extend = ReadOnlyContainer._readonly
    insert = ReadOnlyContainer._readonly
    pop = ReadOnlyContainer._readonly
    remove = ReadOnlyContainer._readonly
    reverse = ReadOnlyContainer._readonly
    sort = ReadOnlyContainer._readonly
    __setitem__ = ReadOnlyContainer._readonly
    __delitem__ = ReadOnlyContainer._readonly
    # ``lst += x`` and ``lst *= n`` dispatch to list.__iadd__/__imul__, which
    # mutate in place without calling the overridden ``extend``.
    __iadd__ = ReadOnlyContainer._readonly
    __imul__ = ReadOnlyContainer._readonly

    def __copy__(self):
        """Return a shallow, mutable plain-``list`` copy."""
        return [*self]

    def __deepcopy__(self, memo):
        """Return a deep, mutable plain-``list`` copy."""
        lst = list(self)
        return copy.deepcopy(lst, memo=memo)

218 

219 

def make_read_only(o):
    """Makes objects read-only.

    Walks a nested structure of exactly-typed ``dict``, ``list`` and
    ``tuple`` values and returns a converted copy. Every ``dict`` becomes a
    ``ReadOnlyDict`` and every ``list`` becomes a ``ReadOnlyList``. Tuples
    stay tuples, with their items converted. Subclasses of these types and
    all other objects are returned unchanged. ``o`` itself is not modified.
    """
    kind = type(o)
    if kind is dict:
        return ReadOnlyDict({key: make_read_only(val) for key, val in o.items()})
    if kind is list:
        return ReadOnlyList(list(map(make_read_only, o)))
    if kind is tuple:
        return tuple(make_read_only(item) for item in o)
    return o

235 

236 

if opt.has_yaml:
    # Teach yaml to dump the read-only containers like the plain containers
    # they subclass.

    def read_only_dict_representer(dumper, data):
        """Saves `ReadOnlyDict` as `dict`."""
        return dumper.represent_dict(data)

    def read_only_list_representer(dumper, data):
        """Saves `ReadOnlyList` as `list`."""
        return dumper.represent_list(data)

    # Register with the module-level ``add_representer`` and with
    # ``SafeDumper``, dict before list for each.
    for _register in (
        opt.yaml.add_representer,
        opt.yaml.SafeDumper.add_representer,
    ):
        _register(ReadOnlyDict, read_only_dict_representer)
        _register(ReadOnlyList, read_only_list_representer)
    del _register

251 

252 

# Maps a value's type to the "simplified" type that ``type_changed`` compares.
# Container variants defined in this module (dogmatic and read-only) collapse
# onto the plain ``dict``/``list`` they subclass, so swapping one for the
# other is not reported as a type change.
SIMPLIFY_TYPE = {
    type(None): type(None),
    bool: bool,
    float: float,
    int: int,
    str: str,
    list: list,
    tuple: list,
    dict: dict,
    DogmaticDict: dict,
    DogmaticList: list,
    ReadOnlyDict: dict,
    ReadOnlyList: list,
}

265 

# if numpy is available we also want to ignore typechanges from numpy
# datatypes to the corresponding python datatype
if opt.has_numpy:
    from sacred.optional import np

    # Names are probed with ``hasattr`` because which of them exist (and
    # which are distinct type objects rather than aliases) depends on the
    # numpy version and platform, e.g. ``float128``.
    NP_FLOATS = [
        "float",
        "float16",
        "float32",
        "float64",
        "float128",
        "half",
        "single",
        "double",
        "longdouble",
    ]
    NP_INTS = [
        "int",
        "int8",
        "int16",
        "int32",
        "int64",
        "uint",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int_",
        "intc",
        "intp",
        "uintc",
        "uintp",
        "short",
        "ushort",
        "byte",
        "ubyte",
        "longlong",
        "ulonglong",
    ]
    # Comprehensions keep the loop variables out of the module namespace.
    SIMPLIFY_TYPE.update(
        {getattr(np, name): float for name in NP_FLOATS if hasattr(np, name)}
    )
    SIMPLIFY_TYPE.update(
        {getattr(np, name): int for name in NP_INTS if hasattr(np, name)}
    )

    SIMPLIFY_TYPE[np.bool_] = bool

293 

294 

def type_changed(old_value, new_value):
    """Tell whether replacing ``old_value`` by ``new_value`` changes its type.

    Both types are first mapped through ``SIMPLIFY_TYPE``; types without an
    entry map to themselves. Replacing a ``None`` is never reported as a
    change.
    """
    if old_value is None:
        # ignore typechanges from None
        return False
    old_type = type(old_value)
    new_type = type(new_value)
    return SIMPLIFY_TYPE.get(old_type, old_type) != SIMPLIFY_TYPE.get(
        new_type, new_type
    )

299 

300 

def is_different(old_value, new_value):
    """Numpy aware comparison between two values.

    Returns a truthy value when ``old_value`` and ``new_value`` should be
    considered different.

    Without numpy this is plain ``old_value != new_value``. With numpy,
    ``numpy.array_equal`` is used so that arrays compare element-wise.
    ``array_equal`` also answers "not equal" when numpy cannot build arrays
    from the values at all (e.g. ragged nested lists). A negative verdict is
    therefore double-checked with ``!=`` whenever that yields a plain
    ``bool``.
    """
    if not opt.has_numpy:
        return old_value != new_value

    if opt.np.array_equal(old_value, new_value):
        return False

    try:
        plain_result = old_value != new_value
    except Exception:
        # e.g. arrays of incompatible shapes cannot be compared element-wise
        return True
    if isinstance(plain_result, bool):
        return plain_result
    # Array-like (or other non-bool) result: keep numpy's verdict.
    return True