Coverage for sacred/sacred/ingredient.py: 99%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import Generator, Tuple, Union
2import inspect
3import os.path
4from sacred.utils import PathType
5from typing import Sequence, Optional
7from collections import OrderedDict
9from sacred.config import (
10 ConfigDict,
11 ConfigScope,
12 create_captured_function,
13 load_config_file,
14)
15from sacred.dependencies import (
16 PEP440_VERSION_PATTERN,
17 PackageDependency,
18 Source,
19 gather_sources_and_dependencies,
20)
21from sacred.utils import CircularDependencyError, optional_kwargs_decorator, join_paths
# Only ``Ingredient`` is exported by ``from sacred.ingredient import *``.
__all__ = ("Ingredient",)
def collect_repositories(sources):
    """Summarise the version-control information of the given sources.

    Sources whose ``repo`` attribute is falsy are skipped. For every other
    source a dict with the keys ``url``, ``commit`` and ``dirty`` is produced,
    taken from the source's ``repo``, ``commit`` and ``is_dirty`` attributes.
    Input order is preserved and duplicates are not removed.

    :param sources: iterable of source objects.
    :return: list of repository description dicts.
    """
    repositories = []
    for source in sources:
        if not source.repo:
            # Source is not tracked in a repository; nothing to record.
            continue
        repositories.append(
            {
                "url": source.repo,
                "commit": source.commit,
                "dirty": source.is_dirty,
            }
        )
    return repositories
class Ingredient:
    """
    Ingredients are reusable parts of experiments.

    Each Ingredient can have its own configuration (visible as an entry in the
    parent's configuration), named configurations, captured functions and
    commands.

    Ingredients can themselves use ingredients.
    """

    def __init__(
        self,
        path: PathType,
        ingredients: Sequence["Ingredient"] = (),
        interactive: bool = False,
        _caller_globals: Optional[dict] = None,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """Create a new ingredient.

        :param path: name of this ingredient; it is joined in front of the
            names of its commands and named configs (see ``gather_commands``
            and ``gather_named_configs``).
        :param ingredients: sub-ingredients used by this ingredient.
        :param interactive: if True, do not raise when no main source file
            can be determined for the defining code.
        :param _caller_globals: globals of the defining module. If not given
            (or empty), the globals of the frame that called this constructor
            are used.
        :param base_dir: directory sources are resolved against. Defaults to
            the directory of the caller's ``__file__`` (or the current working
            directory if that cannot be determined).
        :param save_git_info: forwarded to source gathering and
            ``Source.create`` (presumably toggles recording git info).
        :raises RuntimeError: if no main file was found and ``interactive``
            is False.
        """
        self.path = path
        self.config_hooks = []
        self.configurations = []
        self.named_configs = dict()
        self.ingredients = list(ingredients)
        self.logger = None
        self.captured_functions = []
        self.post_run_hooks = []
        self.pre_run_hooks = []
        # Guard flag used by traverse_ingredients to detect cycles.
        self._is_traversing = False
        self.commands = OrderedDict()
        # capture some context information
        _caller_globals = _caller_globals or inspect.stack()[1][0].f_globals
        mainfile_dir = os.path.dirname(_caller_globals.get("__file__", "."))
        self.base_dir = os.path.abspath(base_dir or mainfile_dir)
        self.save_git_info = save_git_info
        # Docstring of the defining module, if any.
        self.doc = _caller_globals.get("__doc__", "")
        (
            self.mainfile,
            self.sources,
            self.dependencies,
        ) = gather_sources_and_dependencies(
            _caller_globals, save_git_info, self.base_dir
        )
        if self.mainfile is None and not interactive:
            raise RuntimeError(
                "Defining an experiment in interactive mode! "
                "The sourcecode cannot be stored and the "
                "experiment won't be reproducible. If you still"
                " want to run it pass interactive=True"
            )

    # =========================== Decorators ==================================
    @optional_kwargs_decorator
    def capture(self, function=None, prefix=None):
        """
        Decorator to turn a function into a captured function.

        The missing arguments of captured functions are automatically filled
        from the configuration if possible.
        See :ref:`captured_functions` for more information.

        If a ``prefix`` is specified, the search for suitable
        entries is performed in the corresponding subtree of the configuration.
        """
        # Capturing an already captured function is a no-op.
        if function in self.captured_functions:
            return function
        captured_function = create_captured_function(function, prefix=prefix)
        self.captured_functions.append(captured_function)
        return captured_function

    @optional_kwargs_decorator
    def pre_run_hook(self, func, prefix=None):
        """
        Decorator to add a pre-run hook to this ingredient.

        Pre-run hooks are captured functions that are run just before the
        main function is executed.
        """
        cf = self.capture(func, prefix=prefix)
        self.pre_run_hooks.append(cf)
        return cf

    @optional_kwargs_decorator
    def post_run_hook(self, func, prefix=None):
        """
        Decorator to add a post-run hook to this ingredient.

        Post-run hooks are captured functions that are run just after the
        main function is executed.
        """
        cf = self.capture(func, prefix=prefix)
        self.post_run_hooks.append(cf)
        return cf

    @optional_kwargs_decorator
    def command(self, function=None, prefix=None, unobserved=False):
        """
        Decorator to define a new command for this Ingredient or Experiment.

        The name of the command will be the name of the function. It can be
        called from the command-line or by using the run_command function.

        Commands are automatically also captured functions.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        """
        captured_f = self.capture(function, prefix=prefix)
        captured_f.unobserved = unobserved
        self.commands[function.__name__] = captured_f
        return captured_f

    def config(self, function):
        """
        Decorator to add a function to the configuration of the Experiment.

        The decorated function is turned into a
        :class:`~sacred.config.ConfigScope` and added to the
        Ingredient/Experiment.

        When the experiment is run, this function will also be executed and
        all json-serializable local variables inside it will end up as entries
        in the configuration of the experiment.
        """
        self.configurations.append(ConfigScope(function))
        return self.configurations[-1]

    def named_config(self, func):
        """
        Decorator to turn a function into a named configuration.

        The function's ``__name__`` becomes the name of the configuration.
        See :ref:`named_configurations`.

        :raises KeyError: if the name is already in use.
        """
        config_scope = ConfigScope(func)
        self._add_named_config(func.__name__, config_scope)
        return config_scope

    def config_hook(self, func):
        """
        Decorator to add a config hook to this ingredient.

        Config hooks need to be a function that takes 3 parameters and returns
        a dictionary:
        (config, command_name, logger) --> dict

        Config hooks are run after the configuration of this Ingredient, but
        before any further ingredient-configurations are run.
        The dictionary returned by a config hook is used to update the
        config updates.
        Note that they are not restricted to the local namespace of the
        ingredient.

        :raises ValueError: if ``func`` does not have exactly the positional
            parameters ``(config, command_name, logger)`` without defaults,
            ``*args`` or keyword-only parameters.
        """
        argspec = inspect.getfullargspec(func)
        args = ["config", "command_name", "logger"]
        if not (
            argspec.args == args
            and argspec.varargs is None
            and not argspec.kwonlyargs
            and argspec.defaults is None
        ):
            raise ValueError(
                "Wrong signature for config_hook. Expected: "
                "(config, command_name, logger)"
            )
        self.config_hooks.append(func)
        return self.config_hooks[-1]

    # =========================== Public Interface ============================

    def add_config(self, cfg_or_file=None, **kw_conf):
        """
        Add a configuration entry to this ingredient/experiment.

        Can be called with a filename, a dictionary xor with keyword arguments.
        Supported formats for the config-file so far are: ``json``, ``pickle``
        and ``yaml``.

        The resulting dictionary will be converted into a
        :class:`~sacred.config.ConfigDict`.

        :param cfg_or_file: Configuration dictionary or filename of config file
                            to add to this ingredient/experiment.
        :type cfg_or_file: dict, str or os.PathLike
        :param kw_conf: Configuration entries to be added to this
                        ingredient/experiment.
        """
        self.configurations.append(self._create_config_dict(cfg_or_file, kw_conf))

    def _add_named_config(self, name, conf):
        """Register ``conf`` under ``name``; raise KeyError on duplicates."""
        if name in self.named_configs:
            raise KeyError('Configuration name "{}" already in use!'.format(name))
        self.named_configs[name] = conf

    @staticmethod
    def _create_config_dict(cfg_or_file, kw_conf):
        """Build a ConfigDict from the arguments of ``add_(named_)config``.

        Exactly one of ``cfg_or_file`` and ``kw_conf`` must be given.
        ``cfg_or_file`` may be a dict or a path (``str`` or ``os.PathLike``)
        to a config file, which is loaded via ``load_config_file``.

        :raises ValueError: if both or neither source of configuration is given.
        :raises FileNotFoundError: if the given file path does not exist
            (a subclass of ``OSError``).
        :raises TypeError: for any other type of ``cfg_or_file``.
        """
        if cfg_or_file is not None and kw_conf:
            raise ValueError("cannot combine keyword config with positional argument")
        if cfg_or_file is None:
            if not kw_conf:
                raise ValueError("attempted to add empty config")
            return ConfigDict(kw_conf)
        if isinstance(cfg_or_file, dict):
            return ConfigDict(cfg_or_file)
        # Accept pathlib.Path and other path-like objects as well as plain
        # strings; os.path.abspath converts them to a str path.
        if isinstance(cfg_or_file, (str, os.PathLike)):
            if not os.path.exists(cfg_or_file):
                raise FileNotFoundError("File not found {}".format(cfg_or_file))
            return ConfigDict(load_config_file(os.path.abspath(cfg_or_file)))
        raise TypeError("Invalid argument type {}".format(type(cfg_or_file)))

    def add_named_config(self, name, cfg_or_file=None, **kw_conf):
        """
        Add a **named** configuration entry to this ingredient/experiment.

        Can be called with a filename, a dictionary xor with keyword arguments.
        Supported formats for the config-file so far are: ``json``, ``pickle``
        and ``yaml``.

        The resulting dictionary will be converted into a
        :class:`~sacred.config.ConfigDict`.

        See :ref:`named_configurations`

        :param name: name of the configuration
        :type name: str
        :param cfg_or_file: Configuration dictionary or filename of config file
                            to add to this ingredient/experiment.
        :type cfg_or_file: dict, str or os.PathLike
        :param kw_conf: Configuration entries to be added to this
                        ingredient/experiment.
        :raises KeyError: if ``name`` is already in use.
        """
        self._add_named_config(name, self._create_config_dict(cfg_or_file, kw_conf))

    def add_source_file(self, filename):
        """
        Add a file as source dependency to this experiment/ingredient.

        :param filename: filename of the source to be added as dependency
        :type filename: str
        """
        self.sources.add(Source.create(filename, self.save_git_info))

    def add_package_dependency(self, package_name, version):
        """
        Add a package to the list of dependencies.

        :param package_name: The name of the package dependency
        :type package_name: str
        :param version: The (minimum) version of the package
        :type version: str
        :raises ValueError: if ``version`` does not match the PEP 440 pattern.
        """
        if not PEP440_VERSION_PATTERN.match(version):
            raise ValueError('Invalid Version: "{}"'.format(version))
        self.dependencies.add(PackageDependency(package_name, version))

    def post_process_name(self, name, ingredient):
        """Can be overridden to change the command name."""
        return name

    def gather_commands(self):
        """Collect all commands from this ingredient and its sub-ingredients.

        Yields
        ------
        cmd_name: str
            The full (dotted) name of the command.
        cmd: function
            The corresponding captured function.
        """
        for ingredient, _ in self.traverse_ingredients():
            for command_name, command in ingredient.commands.items():
                cmd_name = join_paths(ingredient.path, command_name)
                cmd_name = self.post_process_name(cmd_name, ingredient)
                yield cmd_name, command

    def gather_named_configs(
        self,
    ) -> Generator[Tuple[str, Union[ConfigScope, ConfigDict, str]], None, None]:
        """Collect all named configs from this ingredient and its sub-ingredients.

        Yields
        ------
        config_name
            The full (dotted) name of the named config.
        config
            The corresponding named config.
        """
        for ingredient, _ in self.traverse_ingredients():
            for config_name, config in ingredient.named_configs.items():
                config_name = join_paths(ingredient.path, config_name)
                config_name = self.post_process_name(config_name, ingredient)
                yield config_name, config

    def get_experiment_info(self):
        """Get a dictionary with information about this experiment.

        Sources and dependencies are merged over this ingredient and all of
        its sub-ingredients. Missing package versions are filled in via
        ``fill_missing_version`` before serialization.

        Contains:
          * *name*: the name (``self.path``)
          * *base_dir*: the base directory
          * *sources*: the sources as produced by ``Source.to_json``, sorted
          * *dependencies*: the package dependencies as produced by
            ``to_json``, sorted case-insensitively by name
          * *repositories*: repository info of the sources (see
            ``collect_repositories``), in the same order as *sources*
          * *mainfile*: the main file as serialized by ``to_json``, or None

        :return: experiment information
        :rtype: dict
        """
        dependencies = set()
        sources = set()
        for ing, _ in self.traverse_ingredients():
            dependencies |= ing.dependencies
            sources |= ing.sources

        for dep in dependencies:
            dep.fill_missing_version()

        # Sort once so that both the source list and the repository list
        # derived from it come out in a deterministic order; iterating the
        # set directly would give an arbitrary order.
        sorted_sources = sorted(sources)
        mainfile = self.mainfile.to_json(self.base_dir)[0] if self.mainfile else None

        def name_lower(d):
            return d.name.lower()

        return dict(
            name=self.path,
            base_dir=self.base_dir,
            sources=[s.to_json(self.base_dir) for s in sorted_sources],
            dependencies=[d.to_json() for d in sorted(dependencies, key=name_lower)],
            repositories=collect_repositories(sorted_sources),
            mainfile=mainfile,
        )

    def traverse_ingredients(self):
        """Recursively traverse this ingredient and its sub-ingredients.

        Yields
        ------
        ingredient: sacred.Ingredient
            The ingredient as traversed in preorder.
        depth: int
            The depth of the ingredient starting from 0.

        Raises
        ------
        CircularDependencyError:
            If a circular structure among ingredients was detected.
        """
        if self._is_traversing:
            raise CircularDependencyError(ingredients=[self])
        self._is_traversing = True
        try:
            yield self, 0
            with CircularDependencyError.track(self):
                for ingredient in self.ingredients:
                    for ingred, depth in ingredient.traverse_ingredients():
                        yield ingred, depth + 1
        finally:
            # Always clear the guard flag: if the consumer stops iterating
            # early (generator close) or an exception propagates, leaving it
            # set would make every later traversal report a spurious
            # circular dependency.
            self._is_traversing = False