Coverage for sacred/sacred/initialize.py: 90%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import os
5from collections import OrderedDict, defaultdict
6from copy import copy, deepcopy
8from sacred.config import (
9 ConfigDict,
10 chain_evaluate_config_scopes,
11 dogmatize,
12 load_config_file,
13 undogmatize,
14)
15from sacred.config.config_summary import ConfigSummary
16from sacred.config.custom_containers import make_read_only
17from sacred.host_info import get_host_info
18from sacred.randomness import create_rnd, get_seed
19from sacred.run import Run
20from sacred.utils import (
21 convert_to_nested_dict,
22 create_basic_stream_logger,
23 get_by_dotted_path,
24 is_prefix,
25 rel_path,
26 iterate_flattened,
27 set_by_dotted_path,
28 recursive_update,
29 iter_prefixes,
30 join_paths,
31 NamedConfigNotFoundError,
32 ConfigAddedError,
33)
34from sacred.settings import SETTINGS
class Scaffold:
    """Per-ingredient bookkeeping used while the configuration of a run is built.

    ``create_scaffolding`` creates one instance for every ingredient and one for
    the experiment itself (with ``path == ""``).  ``create_run`` then drives each
    instance through collecting config updates / presets, evaluating the config
    scopes and config hooks, seeding, and finally wiring up the captured
    functions of its ingredient.
    """

    def __init__(
        self,
        config_scopes,
        subrunners,
        path,
        captured_functions,
        commands,
        named_configs,
        config_hooks,
        generate_seed,
    ):
        # What this scaffold wraps; these do not change after construction.
        self.path = path
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        self.config_hooks = config_hooks
        self.commands = commands
        self.subrunners = subrunners
        self.generate_seed = generate_seed
        self._captured_functions = captured_functions

        # State that is filled in step by step during create_run().
        self.config = {}
        self.config_updates = {}
        self.presets = {}
        self.named_configs_to_use = []
        self.summaries = []
        self.config_mods = None
        self.fallback = None
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None

        # Every (prefixed) argument name of a captured function is a config key
        # that is allowed to be added; "__doc__" is always accepted so that the
        # config docstring can be set.
        accepted = set()
        for func in captured_functions:
            accepted.update(
                join_paths(func.prefix, arg) for arg in func.signature.arguments
            )
        accepted.add("__doc__")
        self.captured_args = accepted

    def set_up_seed(self, rnd=None):
        """Pick the seed of this scaffold and seed its nested subrunners.

        Does nothing if a seed has already been set.  The seed is read from
        ``config["seed"]`` if present, otherwise drawn with ``get_seed(rnd)``.
        Subrunners for which ``is_prefix(self.path, path)`` holds are then
        seeded from this scaffold's random state, in reverse order.
        """
        if self.seed is not None:
            return

        chosen = self.config.get("seed")
        if chosen is None:
            chosen = get_seed(rnd)
        self.seed = chosen
        self.rnd = create_rnd(chosen)

        if self.generate_seed:
            self.config["seed"] = chosen

        mods = self.config_mods
        if "seed" in self.config and "seed" in mods.added:
            # Report the seed as modified rather than as a newly added entry.
            mods.modified.add("seed")
            mods.added -= {"seed"}

        for child_path, child in reversed(list(self.subrunners.items())):
            if is_prefix(self.path, child_path):
                child.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Build ``self.fallback`` from the logger and the subrunner configs.

        The fallback maps ``"_log"`` to this scaffold's logger and places the
        config of every subrunner at its path (relative to ``self.path`` when
        this scaffold is not the root and the subrunner is nested under it).
        """
        collected = {"_log": self.logger}
        for child_path, child in self.subrunners.items():
            if self.path and is_prefix(self.path, child_path):
                target = child_path[len(self.path) :].strip(".")
            else:
                target = child_path
            set_by_dotted_path(collected, target, child.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(collected)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate a named config (or a config file) for this scaffold.

        ``config_name`` is either the path of an existing file, which is loaded
        with ``load_config_file``, or a key of ``self.named_configs``.  The
        resulting values are returned undogmatized.

        Raises:
            NamedConfigNotFoundError: if ``config_name`` is neither.
        """
        if os.path.isfile(config_name):
            named = ConfigDict(load_config_file(config_name))
        elif config_name in self.named_configs:
            named = self.named_configs[config_name]
        else:
            raise NamedConfigNotFoundError(
                named_config=config_name,
                available_named_configs=tuple(self.named_configs.keys()),
            )

        evaluated = named(
            fixed=self.get_config_updates_recursive(),
            preset=self.presets,
            fallback=self.fallback,
        )
        return undogmatize(evaluated)

    def set_up_config(self):
        """Evaluate the config scopes and recompute ``self.config_mods``.

        ``config_updates`` are passed as fixed values, the current ``config``
        as preset and ``self.fallback`` as fallback.
        """
        evaluated = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback,
        )
        self.config, self.summaries = evaluated
        self.get_config_modifications()

    def run_config_hooks(self, config, command_name, logger):
        """Run all config hooks and return the merged updates.

        Every hook receives a deep copy of ``config``.  Non-empty hook results
        are merged in order with ``recursive_update``, and
        ``self.config_updates`` is merged in last.
        """
        merged = {}
        for hook in self.config_hooks:
            hook_updates = hook(deepcopy(config), command_name, logger)
            if hook_updates:
                recursive_update(merged, hook_updates)
        recursive_update(merged, self.config_updates)
        return merged

    def get_config_modifications(self):
        """Rebuild ``self.config_mods`` from the updates and scope summaries.

        Every flattened key of ``config_updates`` starts out as "added"; the
        summaries produced by the config scopes are then folded in.
        """
        flat_keys = {key for key, _ in iterate_flattened(self.config_updates)}
        self.config_mods = ConfigSummary(added=flat_keys)
        for summary in self.summaries:
            self.config_mods.update_from(summary)

    def get_config_updates_recursive(self):
        """Return a shallow copy of the config updates, including subrunners'.

        Non-empty updates of subrunners nested under this scaffold are added
        under their path relative to ``self.path``.
        """
        gathered = self.config_updates.copy()
        for child_path, child in self.subrunners.items():
            if is_prefix(self.path, child_path):
                nested = child.get_config_updates_recursive()
                if nested:
                    gathered[rel_path(self.path, child_path)] = nested
        return gathered

    def get_fixture(self):
        """Return (and cache) the config with all subrunner configs inserted.

        The result is computed once: a shallow copy of ``self.config`` into
        which a shallow copy of every (transitive) subrunner's config is
        written at its path, relative to ``self.path`` where applicable.
        """
        if self.fixture is not None:
            return self.fixture

        def insert_subrunners(runner):
            for child_path, child in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                child.get_fixture()
                insert_subrunners(child)
                if is_prefix(self.path, child_path):
                    target = child_path[len(self.path) :].strip(".")
                else:
                    target = child_path
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, target, copy(child.config))

        self.fixture = copy(self.config)
        insert_subrunners(self)
        return self.fixture

    def finalize_initialization(self, run):
        """Wire up the captured functions of this scaffold for ``run``.

        Each captured function gets a child logger, its own random state seeded
        from this scaffold, the run, and its slice of the fixture (read-only if
        ``SETTINGS.CONFIG.READ_ONLY_CONFIG`` is set).  Unless ``run.force`` is
        truthy, suspicious config changes are reported afterwards.
        """
        # The seed might have changed during the configuration process.
        if "seed" in self.config:
            self.seed = self.config["seed"]
            self.rnd = create_rnd(self.seed)

        for func in self._captured_functions:
            func.logger = self.logger.getChild(func.__name__)
            func.rnd = create_rnd(get_seed(self.rnd))
            func.run = run
            func.config = get_by_dotted_path(
                self.get_fixture(), func.prefix, default={}
            )
            # Make configuration read only if enabled in settings
            if SETTINGS.CONFIG.READ_ONLY_CONFIG:
                func.config = make_read_only(func.config)

        if run.force:
            return
        self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Raise or warn about config changes that look like mistakes.

        Raises ``ConfigAddedError`` for the first added key (in sorted order)
        that no captured function can receive; other added keys, non-numeric
        type changes and ignored fallback assignments are logged as warnings.
        """
        for key in sorted(self.config_mods.added):
            if self.captured_args.isdisjoint(iter_prefixes(key)):
                full_key = join_paths(self.path, key) if self.path else key
                raise ConfigAddedError(full_key, config=self.config)
            self.logger.warning('Added new config entry: "%s"' % key)

        numeric = (int, float)
        for key, (old_type, new_type) in self.config_mods.typechanged.items():
            if old_type in numeric and new_type in numeric:
                continue  # changes between int and float are not reported
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s'
                % (key, old_type.__name__, new_type.__name__)
            )

        ignored = (
            key for summary in self.summaries for key in summary.ignored_fallbacks
        )
        for key in ignored:
            self.logger.warning(
                'Ignored attempt to set value of "%s", because it is an '
                "ingredient." % key
            )

    def __repr__(self):
        template = "<Scaffold: '{}'>"
        return template.format(self.path)
def get_configuration(scaffolding):
    """Merge the configs of all scaffolds into one nested dict.

    Scaffolds are visited from last to first (``create_scaffolding`` puts the
    root scaffold last, so it is applied first).  The root config (path ``""``)
    is merged into the top level; every other non-empty config is placed at its
    dotted path.  Empty configs are skipped.
    """
    merged = {}
    for path, scaffold in reversed(list(scaffolding.items())):
        cfg = scaffold.config
        if cfg and path:
            set_by_dotted_path(merged, path, cfg)
        elif cfg:
            merged.update(cfg)
    return merged
def distribute_named_configs(scaffolding, named_configs):
    """Hand every named config to the scaffold it belongs to.

    An entry that is an existing filesystem path goes to the root scaffold
    (``""``); otherwise it is split at its last ``"."`` into ingredient path
    and config name.

    NOTE(review): ``Scaffold`` as defined in this module has no
    ``use_named_config`` method, and ``create_run`` resolves named configs via
    ``get_scaffolding_and_config_name`` instead of calling this function --
    confirm whether this helper is still needed.

    Raises:
        KeyError: if no scaffold exists for the ingredient path.
    """
    for ncfg in named_configs:
        if os.path.exists(ncfg):
            target, cfg_name = scaffolding[""], ncfg
        else:
            owner, _, cfg_name = ncfg.rpartition(".")
            if owner not in scaffolding:
                raise KeyError(
                    'Ingredient for named config "{}" not found'.format(ncfg)
                )
            target = scaffolding[owner]
        target.use_named_config(cfg_name)
def initialize_logging(experiment, scaffolding, log_level=None):
    """Attach loggers to all scaffolds and return ``(root_logger, run_logger)``.

    The root logger is ``experiment.logger``, or a fresh logger from
    ``create_basic_stream_logger()`` when that is ``None``.  The root scaffold
    (path ``""``) uses the root logger itself; every other scaffold gets
    ``root_logger.getChild(path)``.  If ``log_level`` is given it is converted
    with ``int()`` when possible (e.g. ``"10"``) and otherwise passed as-is
    (e.g. ``"DEBUG"``) to ``root_logger.setLevel``.  The run logger is
    ``root_logger.getChild(experiment.path)``.
    """
    root_logger = (
        create_basic_stream_logger()
        if experiment.logger is None
        else experiment.logger
    )

    for path, scaffold in scaffolding.items():
        scaffold.logger = root_logger.getChild(path) if path else root_logger

    if log_level is not None:
        try:
            level = int(log_level)
        except ValueError:
            level = log_level
        root_logger.setLevel(level)

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
def create_scaffolding(experiment, sorted_ingredients):
    """Create one ``Scaffold`` per ingredient and return them keyed by path.

    The last entry of ``sorted_ingredients`` is skipped; the experiment instead
    gets its own scaffold with path ``""`` and ``generate_seed=True``.  While a
    scaffold is built, the scaffolds of its direct ``ingredients`` are looked
    up, so those must appear earlier in ``sorted_ingredients`` (as produced by
    ``gather_ingredients_topological``, deepest first).

    Raises:
        ValueError: if two scaffolds end up with the same path.
    """
    by_ingredient = OrderedDict()

    def build(ingredient, path, generate_seed):
        # Direct sub-ingredients have already been scaffolded at this point.
        children = OrderedDict(
            (by_ingredient[sub].path, by_ingredient[sub])
            for sub in ingredient.ingredients
        )
        return Scaffold(
            config_scopes=ingredient.configurations,
            subrunners=children,
            path=path,
            captured_functions=ingredient.captured_functions,
            commands=ingredient.commands,
            named_configs=ingredient.named_configs,
            config_hooks=ingredient.config_hooks,
            generate_seed=generate_seed,
        )

    for ingredient in sorted_ingredients[:-1]:
        by_ingredient[ingredient] = build(ingredient, ingredient.path, False)
    by_ingredient[experiment] = build(experiment, "", True)

    by_path = OrderedDict(
        (scaffold.path, scaffold) for scaffold in by_ingredient.values()
    )
    if len(by_path) != len(by_ingredient):
        # The message lists the path of every ingredient, not only duplicates.
        raise ValueError(
            "The pathes of the ingredients are not unique. "
            "{}".format([ing.path for ing in by_ingredient])
        )
    return by_path
def gather_ingredients_topological(ingredient):
    """Return all ingredients reachable from ``ingredient``, deepest first.

    Each ingredient's depth is the maximum depth reported for it by
    ``ingredient.traverse_ingredients()`` (never below 0).  Ingredients with
    equal depth keep the order in which they were first encountered.
    """
    depth_of = {}
    for sub, depth in ingredient.traverse_ingredients():
        depth_of[sub] = max(depth_of.get(sub, 0), depth)
    # reverse=True keeps the sort stable for equal depths.
    return sorted(depth_of, key=depth_of.__getitem__, reverse=True)
def get_config_modifications(scaffolding):
    """Combine the ``config_mods`` of all scaffolds into one ``ConfigSummary``.

    Each scaffold's summary is added under that scaffold's path (``""`` for
    the root) using ``ConfigSummary.update_add``.
    """
    combined = ConfigSummary()
    for path, scaffold in scaffolding.items():
        mods = scaffold.config_mods
        combined.update_add(mods, path=path)
    return combined
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    ``command_path`` is split at its last ``"."`` into ingredient path and
    command name; a name without a dot refers to the root scaffold (``""``).

    Raises:
        KeyError: if the ingredient or the command does not exist.
    """
    ingredient_path, _, name = command_path.rpartition(".")
    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    available = scaffolding[ingredient_path].commands
    if name not in available:
        if ingredient_path:
            raise KeyError(
                'Command "%s" not found in ingredient "%s"' % (name, ingredient_path)
            )
        raise KeyError('Command "%s" not found' % name)
    return available[name]
def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    Args:
        path: dotted config path, e.g. ``"dataset.batch_size"``.
        prefixes: iterable of ingredient paths, each already split on ``"."``
            into a list of parts (as built in ``create_run``).

    Returns:
        ``(ingredient_path, remainder)``: the longest entry of ``prefixes``
        that matches ``path`` component-wise, re-joined with ``"."``, and the
        rest of ``path`` after it.  ``("", path)`` if no prefix matches.

    The longest match is chosen explicitly, so the result no longer depends on
    ``prefixes`` being sorted deepest-first (for sorted input it is identical
    to taking the first match).
    """
    path_parts = path.split(".")
    best = None
    for p in prefixes:
        # Compare whole components so that "ab.x" does not match the prefix ["a"].
        if p == path_parts[: len(p)] and (best is None or len(p) > len(best)):
            best = p
    if best is None:
        return "", path
    return ".".join(best), ".".join(path_parts[len(best) :])
def distribute_presets(sc_path, prefixes, scaffolding, config_updates):
    """Store every flattened value of ``config_updates`` as a preset.

    Keys are interpreted relative to ``sc_path`` (if non-empty), routed with
    ``find_best_match`` to the owning scaffold, and written into that
    scaffold's ``presets`` under the remaining relative path.
    """
    for flat_key, value in iterate_flattened(config_updates):
        full_key = sc_path + "." + flat_key if sc_path else flat_key
        owner, relative = find_best_match(full_key, prefixes)
        set_by_dotted_path(scaffolding[owner].presets, relative, value)
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Route every flattened config update to the scaffold owning its path.

    The value is written into that scaffold's ``config_updates`` under the
    path relative to the scaffold, as returned by ``find_best_match``.
    """
    for flat_key, value in iterate_flattened(config_updates):
        owner, relative = find_best_match(flat_key, prefixes)
        set_by_dotted_path(scaffolding[owner].config_updates, relative, value)
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Resolve a named config to ``(scaffold, config_name)``.

    An existing filesystem path is handled by the root scaffold (``""``) and
    returned unchanged as the config name; anything else is split at its last
    ``"."`` into ingredient path and config name.

    Raises:
        KeyError: if no scaffold exists for the resolved ingredient path.
    """
    if os.path.exists(named_config):
        owner, cfg_name = "", named_config
    else:
        owner, _, cfg_name = named_config.rpartition(".")

    if owner in scaffolding:
        return scaffolding[owner], cfg_name
    raise KeyError('Ingredient for named config "{}" not found'.format(named_config))
def create_run(
    experiment,
    command_name,
    config_updates=None,
    named_configs=(),
    force=False,
    log_level=None,
):
    """Assemble the configuration of ``experiment`` and build a ``Run``.

    The configuration is built in four phases (see the comments below):
    explicit config updates, named configs, the normal config scopes plus the
    config hooks of every scaffold, and finally seeding.  Afterwards the
    ``Run`` is created and every scaffold wires up its captured functions via
    ``Scaffold.finalize_initialization``.

    Args:
        experiment: the experiment to run; its ingredients are collected with
            ``gather_ingredients_topological``.
        command_name: dotted name of the command to execute, resolved with
            ``get_command`` (``"<name>"`` or ``"<ingredient path>.<name>"``).
        config_updates: optional mapping of config updates; passed through
            ``convert_to_nested_dict`` before being distributed (presumably
            expands dotted keys -- confirm against ``sacred.utils``).
        named_configs: named configs (``"<ingredient path>.<name>"``) or paths
            to existing config files, applied in the given order.
        force: stored as ``run.force``; when truthy,
            ``finalize_initialization`` skips the checks for suspicious config
            changes.
        log_level: optional level for the root logger, forwarded to
            ``initialize_logging``.

    Returns:
        The initialized ``Run``.

    Raises:
        KeyError: if a named config or the command refers to an unknown
            ingredient, or the command does not exist.
        ValueError: if two ingredients share the same path.
        Errors raised while evaluating configs (e.g.
        ``NamedConfigNotFoundError`` from ``Scaffold.run_named_config``, or
        ``ConfigAddedError`` from ``finalize_initialization`` when not
        forced) propagate unchanged.
    """

    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    # (so find_best_match hits the most specific ingredient first)
    prefixes = sorted(
        [s.split(".") for s in scaffolding if s != ""],
        reverse=True,
        key=lambda p: len(p),
    )

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding, log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    # The values produced by each named config become presets of the owning
    # scaffold(s) and are also merged into config_updates.
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(scaff.path, prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates, join_paths(scaff.path, ncfg_key), value)

    # Distribute again so the values merged in from named configs also reach
    # the scaffolds' config_updates.
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    # scaffolding is ordered deepest ingredient first (experiment last), so
    # gather_fallbacks() sees subrunner configs that were already set up.
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_hook_updates = scaffold.run_config_hooks(
            config, command_name, run_logger
        )
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    # Reversed order starts at the root; set_up_seed() seeds nested subrunners
    # itself and returns early for scaffolds that already have a seed.
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info(experiment.additional_host_info)
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(
        config,
        config_modifications,
        main_function,
        copy(experiment.observers),
        root_logger,
        run_logger,
        experiment_info,
        host_info,
        pre_runs,
        post_runs,
        experiment.captured_out_filter,
    )

    if hasattr(main_function, "unobserved"):
        run.unobserved = main_function.unobserved

    # Must be set before finalize_initialization(), which reads run.force.
    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run