Coverage for sacred/sacred/initialize.py: 90%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

259 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import os 

5from collections import OrderedDict, defaultdict 

6from copy import copy, deepcopy 

7 

8from sacred.config import ( 

9 ConfigDict, 

10 chain_evaluate_config_scopes, 

11 dogmatize, 

12 load_config_file, 

13 undogmatize, 

14) 

15from sacred.config.config_summary import ConfigSummary 

16from sacred.config.custom_containers import make_read_only 

17from sacred.host_info import get_host_info 

18from sacred.randomness import create_rnd, get_seed 

19from sacred.run import Run 

20from sacred.utils import ( 

21 convert_to_nested_dict, 

22 create_basic_stream_logger, 

23 get_by_dotted_path, 

24 is_prefix, 

25 rel_path, 

26 iterate_flattened, 

27 set_by_dotted_path, 

28 recursive_update, 

29 iter_prefixes, 

30 join_paths, 

31 NamedConfigNotFoundError, 

32 ConfigAddedError, 

33) 

34from sacred.settings import SETTINGS 

35 

36 

class Scaffold:
    """Per-ingredient working state used while assembling a ``Run``.

    One Scaffold exists for the experiment (``path == ""``) and one for each
    ingredient. It carries the ingredient's config scopes, named configs,
    config hooks and captured functions, and accumulates the config updates,
    presets, evaluated config, seed, random state and logger computed for it
    during run creation.
    """

    def __init__(
        self,
        config_scopes,
        subrunners,
        path,
        captured_functions,
        commands,
        named_configs,
        config_hooks,
        generate_seed,
    ):
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        # Mapping of dotted path -> Scaffold of the sub-ingredients.
        self.subrunners = subrunners
        self.path = path
        # When True, the resolved seed is written back into ``self.config``.
        self.generate_seed = generate_seed
        self.config_hooks = config_hooks
        # Values that are fixed for this scaffold during config evaluation.
        self.config_updates = {}
        # Not read by any method of this class.
        self.named_configs_to_use = []
        self.config = {}
        self.fallback = None
        # Values handed to named configs as ``preset`` in run_named_config.
        self.presets = {}
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None
        self._captured_functions = captured_functions
        self.commands = commands
        self.config_mods = None
        self.summaries = []
        # Dotted names of all captured-function arguments; used to decide
        # whether a newly added config entry is legitimate.
        self.captured_args = {
            join_paths(cf.prefix, n)
            for cf in self._captured_functions
            for n in cf.signature.arguments
        }
        self.captured_args.add("__doc__")  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Resolve this scaffold's seed and random state, then seed subrunners.

        Does nothing if a seed is already set. The seed is taken from
        ``config["seed"]`` if present, otherwise drawn via ``get_seed(rnd)``.
        """
        if self.seed is not None:
            return

        self.seed = self.config.get("seed")
        if self.seed is None:
            self.seed = get_seed(rnd)

        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config["seed"] = self.seed

        # A "seed" supplied through config updates is reported as modified
        # rather than added (presumably because seed is always an accepted
        # entry).
        if "seed" in self.config and "seed" in self.config_mods.added:
            self.config_mods.modified.add("seed")
            self.config_mods.added -= {"seed"}

        # Hierarchically set the seed of proper subrunners
        for subrunner_path, subrunner in reversed(list(self.subrunners.items())):
            if is_prefix(self.path, subrunner_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Build ``self.fallback`` from the logger and the subrunner configs.

        The fallback contains ``"_log"`` (this scaffold's logger) and each
        subrunner's config placed at its path, made relative to this
        scaffold's path when it is a prefix.
        """
        fallback = {"_log": self.logger}
        for sr_path, subrunner in self.subrunners.items():
            if self.path and is_prefix(self.path, sr_path):
                path = sr_path[len(self.path) :].strip(".")
                set_by_dotted_path(fallback, path, subrunner.config)
            else:
                set_by_dotted_path(fallback, sr_path, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(fallback)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate a named config (or config file) and return its values.

        If ``config_name`` is an existing file it is loaded as a ConfigDict;
        otherwise it is looked up in ``self.named_configs``.

        Raises:
            NamedConfigNotFoundError: if the name is neither a file nor a
                registered named config.
        """
        if os.path.isfile(config_name):
            nc = ConfigDict(load_config_file(config_name))
        else:
            if config_name not in self.named_configs:
                raise NamedConfigNotFoundError(
                    named_config=config_name,
                    available_named_configs=tuple(self.named_configs.keys()),
                )
            nc = self.named_configs[config_name]

        cfg = nc(
            fixed=self.get_config_updates_recursive(),
            preset=self.presets,
            fallback=self.fallback,
        )

        return undogmatize(cfg)

    def set_up_config(self):
        """Evaluate all config scopes into ``self.config`` and record mods."""
        # NOTE(review): ``preset`` is ``self.config`` (still ``{}`` on the
        # first call), not ``self.presets``; presets are only consumed by
        # run_named_config. Confirm this is intended.
        self.config, self.summaries = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback,
        )

        self.get_config_modifications()

    def run_config_hooks(self, config, command_name, logger):
        """Run every config hook and merge their returned updates.

        Each hook receives a deep copy of ``config``. Truthy results are
        merged in order, and ``self.config_updates`` is merged last so that
        explicit updates take precedence (assuming ``recursive_update``
        overwrites existing keys).
        """
        final_cfg_updates = {}
        for ch in self.config_hooks:
            cfg_upup = ch(deepcopy(config), command_name, logger)
            if cfg_upup:
                recursive_update(final_cfg_updates, cfg_upup)
        recursive_update(final_cfg_updates, self.config_updates)
        return final_cfg_updates

    def get_config_modifications(self):
        """Build ``self.config_mods`` from the updates and scope summaries."""
        self.config_mods = ConfigSummary(
            added={key for key, value in iterate_flattened(self.config_updates)}
        )
        for cfg_summary in self.summaries:
            self.config_mods.update_from(cfg_summary)

    def get_config_updates_recursive(self):
        """Return a copy of the config updates, including nested subrunners'.

        Non-empty subrunner updates are stored under the subrunner's path
        relative to this scaffold; subrunners whose path is not prefixed by
        this scaffold's path are skipped.
        """
        config_updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            if not is_prefix(self.path, sr_path):
                continue
            update = subrunner.get_config_updates_recursive()
            if update:
                config_updates[rel_path(self.path, sr_path)] = update
        return config_updates

    def get_fixture(self):
        """Return (and cache) this config merged with all subrunner configs.

        The result is a shallow copy of ``self.config`` with a shallow copy
        of every (transitive) subrunner's config set at its path.
        """
        if self.fixture is not None:
            return self.fixture

        def get_fixture_recursive(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                subrunner.get_fixture()
                get_fixture_recursive(subrunner)
                sub_fix = copy(subrunner.config)
                sub_path = sr_path
                if is_prefix(self.path, sub_path):
                    sub_path = sr_path[len(self.path) :].strip(".")
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, sub_path, sub_fix)

        self.fixture = copy(self.config)
        get_fixture_recursive(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Wire captured functions to ``run`` and check config changes.

        Each captured function gets a child logger, its own random state
        seeded from this scaffold's, the run, and its config slice (made
        read-only if ``SETTINGS.CONFIG.READ_ONLY_CONFIG``). Unless
        ``run.force`` is set, suspicious config changes are then reported.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if "seed" in self.config:
            self.seed = self.config["seed"]
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            # Setup the captured function
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            seed = get_seed(self.rnd)
            cfunc.rnd = create_rnd(seed)
            cfunc.run = run
            cfunc.config = get_by_dotted_path(
                self.get_fixture(), cfunc.prefix, default={}
            )

            # Make configuration read only if enabled in settings
            if SETTINGS.CONFIG.READ_ONLY_CONFIG:
                cfunc.config = make_read_only(cfunc.config)

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Raise on unknown added entries; log other suspicious changes.

        Raises:
            ConfigAddedError: for an added key none of whose prefixes is a
                captured-function argument.
        """
        for add in sorted(self.config_mods.added):
            if not set(iter_prefixes(add)).intersection(self.captured_args):
                if self.path:
                    add = join_paths(self.path, add)
                raise ConfigAddedError(add, config=self.config)
            else:
                self.logger.warning('Added new config entry: "%s"', add)

        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            # int <-> float changes are considered harmless.
            if type_old in (int, float) and type_new in (int, float):
                continue
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s',
                key,
                type_old.__name__,
                type_new.__name__,
            )

        for cfg_summary in self.summaries:
            for key in cfg_summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    "ingredient.",
                    key,
                )

    def __repr__(self):
        return "<Scaffold: '{}'>".format(self.path)

236 

237 

def get_configuration(scaffolding):
    """Merge every scaffold's config into a single nested dict.

    Scaffolds are visited from last to first. A scaffold with an empty
    config is skipped. The root scaffold (empty path) has its keys merged
    into the top level; every other scaffold's config is placed at its
    dotted path.
    """
    merged = {}
    for sc_path, scaff in list(scaffolding.items())[::-1]:
        cfg = scaff.config
        if not cfg:
            continue
        if not sc_path:
            merged.update(cfg)
        else:
            set_by_dotted_path(merged, sc_path, cfg)
    return merged

248 

249 

def distribute_named_configs(scaffolding, named_configs):
    """Record each requested named config on the scaffold that owns it.

    An entry that is an existing filesystem path belongs to the root
    scaffold (``""``). Any other entry is split at its last dot into
    ``ingredient_path.config_name``.

    The previous implementation called ``scaffold.use_named_config(...)``,
    but ``Scaffold`` defines no such method, so it always failed with
    ``AttributeError``. The name is now appended to the scaffold's
    ``named_configs_to_use`` list, which ``Scaffold.__init__`` initializes.

    Raises:
        KeyError: if the ingredient part of a dotted name has no scaffold.
    """
    for ncfg in named_configs:
        if os.path.exists(ncfg):
            scaff, cfg_name = scaffolding[""], ncfg
        else:
            path, _, cfg_name = ncfg.rpartition(".")
            if path not in scaffolding:
                raise KeyError(
                    'Ingredient for named config "{}" not found'.format(ncfg)
                )
            scaff = scaffolding[path]
        scaff.named_configs_to_use.append(cfg_name)

261 

262 

def initialize_logging(experiment, scaffolding, log_level=None):
    """Attach loggers to all scaffolds and optionally set the root level.

    The root logger is ``experiment.logger`` when set, otherwise a new one
    from ``create_basic_stream_logger()``. The root scaffold (empty path)
    gets the root logger itself; every other scaffold gets the child logger
    named after its path.

    Args:
        experiment: object exposing ``logger`` (may be ``None``) and ``path``.
        scaffolding: mapping of path -> scaffold; each scaffold's ``logger``
            attribute is overwritten.
        log_level: optional level for the root logger. Ints and integer
            strings (e.g. ``"10"``) are used numerically. Other values are
            passed to ``Logger.setLevel``; a level name that is not
            registered as given but is registered in upper case (e.g.
            ``"debug"``) is upper-cased first.

    Returns:
        ``(root_logger, run_logger)``, where ``run_logger`` is the root
        logger's child named ``experiment.path``.
    """
    if experiment.logger is None:
        root_logger = create_basic_stream_logger()
    else:
        root_logger = experiment.logger

    for sc_path, scaffold in scaffolding.items():
        scaffold.logger = root_logger.getChild(sc_path) if sc_path else root_logger

    # set log level
    if log_level is not None:
        try:
            lvl = int(log_level)
        except ValueError:
            import logging

            lvl = log_level
            # The logging module registers its built-in level names in upper
            # case, so setLevel("debug") would raise. Only retry upper-cased
            # when the name as given is unknown, so lowercase custom levels
            # still work and unknown names still raise in setLevel.
            if isinstance(lvl, str) and not isinstance(logging.getLevelName(lvl), int):
                upper = lvl.upper()
                if isinstance(logging.getLevelName(upper), int):
                    lvl = upper
        root_logger.setLevel(lvl)

    return root_logger, root_logger.getChild(experiment.path)

284 

285 

def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient plus one for the experiment.

    Args:
        experiment: the experiment; its scaffold gets ``path=""`` and
            ``generate_seed=True``.
        sorted_ingredients: ingredients ordered so that each comes after
            every ingredient it uses (otherwise looking up a sub-scaffold
            raises ``KeyError``). The last entry is skipped; the experiment
            scaffold is built separately from ``experiment``.

    Returns:
        OrderedDict mapping scaffold path -> Scaffold, in creation order
        (experiment last).

    Raises:
        ValueError: if two scaffolds share the same path.
    """
    scaffold_of = OrderedDict()

    def subrunners_for(owner):
        # Already-built scaffolds of the ingredients used by ``owner``.
        return OrderedDict(
            (scaffold_of[used].path, scaffold_of[used]) for used in owner.ingredients
        )

    for ing in sorted_ingredients[:-1]:
        scaffold_of[ing] = Scaffold(
            config_scopes=ing.configurations,
            subrunners=subrunners_for(ing),
            path=ing.path,
            captured_functions=ing.captured_functions,
            commands=ing.commands,
            named_configs=ing.named_configs,
            config_hooks=ing.config_hooks,
            generate_seed=False,
        )

    scaffold_of[experiment] = Scaffold(
        config_scopes=experiment.configurations,
        subrunners=subrunners_for(experiment),
        path="",
        captured_functions=experiment.captured_functions,
        commands=experiment.commands,
        named_configs=experiment.named_configs,
        config_hooks=experiment.config_hooks,
        generate_seed=True,
    )

    by_path = OrderedDict((sc.path, sc) for sc in scaffold_of.values())
    if len(by_path) == len(scaffold_of):
        return by_path
    raise ValueError(
        "The pathes of the ingredients are not unique. "
        "{}".format([s.path for s in scaffold_of])
    )

323 

324 

def gather_ingredients_topological(ingredient):
    """Return the ingredients from ``traverse_ingredients()``, deepest first.

    ``ingredient.traverse_ingredients()`` yields ``(sub_ingredient, depth)``
    pairs. An ingredient seen at several depths is ranked by its largest
    depth (never below 0). Ingredients with equal depth keep their
    first-seen order.
    """
    deepest = {}
    for sub, depth in ingredient.traverse_ingredients():
        deepest[sub] = max(deepest.get(sub, 0), depth)
    # reverse=True keeps sort stability, so ties retain first-seen order.
    return sorted(deepest, key=deepest.__getitem__, reverse=True)

330 

331 

def get_config_modifications(scaffolding):
    """Combine every scaffold's ``config_mods`` into one ConfigSummary.

    Each scaffold's summary is added under its path via
    ``ConfigSummary.update_add``.
    """
    combined = ConfigSummary()
    for sc_path in scaffolding:
        combined.update_add(scaffolding[sc_path].config_mods, path=sc_path)
    return combined

337 

338 

def get_command(scaffolding, command_path):
    """Look up a command by its dotted path ``[ingredient.]command``.

    Raises:
        KeyError: if the ingredient has no scaffold, or the scaffold has no
            command with that name.
    """
    path, _, command_name = command_path.rpartition(".")
    if path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    available = scaffolding[path].commands
    if command_name in available:
        return available[command_name]
    if not path:
        raise KeyError('Command "%s" not found' % command_name)
    raise KeyError('Command "%s" not found in ingredient "%s"' % (command_name, path))

353 

354 

def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    ``prefixes`` is a sequence of path-component lists. The first entry that
    matches the leading components of ``path`` wins, so the caller must
    order ``prefixes`` from deepest to shallowest to get the longest match.

    Returns:
        ``(matched_prefix, remainder)`` as dotted strings, or ``("", path)``
        when nothing matches.
    """
    parts = path.split(".")
    for candidate in prefixes:
        width = len(candidate)
        if width > len(parts) or parts[:width] != candidate:
            continue
        return ".".join(candidate), ".".join(parts[width:])
    return "", path

362 

363 

def distribute_presets(sc_path, prefixes, scaffolding, config_updates):
    """Store flattened ``config_updates`` as presets on their owning scaffolds.

    Each flattened key is first made absolute by prefixing ``sc_path`` (when
    non-empty). It is then routed with ``find_best_match`` to the scaffold
    whose path matches first in ``prefixes``, and the rest of the key is
    written into that scaffold's ``presets``.
    """
    for key, value in iterate_flattened(config_updates):
        absolute = key if not sc_path else sc_path + "." + key
        owner, remainder = find_best_match(absolute, prefixes)
        set_by_dotted_path(scaffolding[owner].presets, remainder, value)

371 

372 

def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Route flattened ``config_updates`` into each scaffold's ``config_updates``.

    Each dotted key goes to the scaffold chosen by ``find_best_match``, and
    the rest of the key is set in that scaffold's ``config_updates`` dict.
    """
    for key, value in iterate_flattened(config_updates):
        owner, remainder = find_best_match(key, prefixes)
        target = scaffolding[owner].config_updates
        set_by_dotted_path(target, remainder, value)

378 

379 

def get_scaffolding_and_config_name(named_config, scaffolding):
    """Split a named-config reference into ``(scaffold, config_name)``.

    An existing filesystem path belongs to the root scaffold (``""``) and is
    returned unchanged as the config name. Anything else is split at its
    last dot into ``ingredient_path.config_name``.

    Raises:
        KeyError: if the ingredient path has no scaffold.
    """
    if os.path.exists(named_config):
        path, cfg_name = "", named_config
    else:
        path, cfg_name = named_config.rpartition(".")[::2]

    if path in scaffolding:
        return scaffolding[path], cfg_name
    raise KeyError('Ingredient for named config "{}" not found'.format(named_config))

392 

393 

394def create_run( 

395 experiment, 

396 command_name, 

397 config_updates=None, 

398 named_configs=(), 

399 force=False, 

400 log_level=None, 

401): 

402 

403 sorted_ingredients = gather_ingredients_topological(experiment) 

404 scaffolding = create_scaffolding(experiment, sorted_ingredients) 

405 # get all split non-empty prefixes sorted from deepest to shallowest 

406 prefixes = sorted( 

407 [s.split(".") for s in scaffolding if s != ""], 

408 reverse=True, 

409 key=lambda p: len(p), 

410 ) 

411 

412 # --------- configuration process ------------------- 

413 

414 # Phase 1: Config updates 

415 config_updates = config_updates or {} 

416 config_updates = convert_to_nested_dict(config_updates) 

417 root_logger, run_logger = initialize_logging(experiment, scaffolding, log_level) 

418 distribute_config_updates(prefixes, scaffolding, config_updates) 

419 

420 # Phase 2: Named Configs 

421 for ncfg in named_configs: 

422 scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding) 

423 scaff.gather_fallbacks() 

424 ncfg_updates = scaff.run_named_config(cfg_name) 

425 distribute_presets(scaff.path, prefixes, scaffolding, ncfg_updates) 

426 for ncfg_key, value in iterate_flattened(ncfg_updates): 

427 set_by_dotted_path(config_updates, join_paths(scaff.path, ncfg_key), value) 

428 

429 distribute_config_updates(prefixes, scaffolding, config_updates) 

430 

431 # Phase 3: Normal config scopes 

432 for scaffold in scaffolding.values(): 

433 scaffold.gather_fallbacks() 

434 scaffold.set_up_config() 

435 

436 # update global config 

437 config = get_configuration(scaffolding) 

438 # run config hooks 

439 config_hook_updates = scaffold.run_config_hooks( 

440 config, command_name, run_logger 

441 ) 

442 recursive_update(scaffold.config, config_hook_updates) 

443 

444 # Phase 4: finalize seeding 

445 for scaffold in reversed(list(scaffolding.values())): 

446 scaffold.set_up_seed() # partially recursive 

447 

448 config = get_configuration(scaffolding) 

449 config_modifications = get_config_modifications(scaffolding) 

450 

451 # ---------------------------------------------------- 

452 

453 experiment_info = experiment.get_experiment_info() 

454 host_info = get_host_info(experiment.additional_host_info) 

455 main_function = get_command(scaffolding, command_name) 

456 pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks] 

457 post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks] 

458 

459 run = Run( 

460 config, 

461 config_modifications, 

462 main_function, 

463 copy(experiment.observers), 

464 root_logger, 

465 run_logger, 

466 experiment_info, 

467 host_info, 

468 pre_runs, 

469 post_runs, 

470 experiment.captured_out_filter, 

471 ) 

472 

473 if hasattr(main_function, "unobserved"): 

474 run.unobserved = main_function.unobserved 

475 

476 run.force = force 

477 

478 for scaffold in scaffolding.values(): 

479 scaffold.finalize_initialization(run=run) 

480 

481 return run