Coverage for sacred/sacred/experiment.py: 76%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""The Experiment class, which is central to sacred."""
2import inspect
3import os.path
4import sys
5import warnings
6from collections import OrderedDict
7from typing import Sequence, Optional, List
9from docopt import docopt, printable_usage
11from sacred import SETTINGS
12from sacred.arg_parser import format_usage, get_config_updates
13from sacred import commandline_options
14from sacred.commandline_options import CLIOption
15from sacred.commands import (
16 help_for_command,
17 print_config,
18 print_dependencies,
19 save_config,
20 print_named_configs,
21)
22from sacred.observers.file_storage import file_storage_option
23from sacred.observers.s3_observer import s3_option
24from sacred.config.signature import Signature
25from sacred.ingredient import Ingredient
26from sacred.initialize import create_run
27from sacred.observers.sql import sql_option
28from sacred.observers.tinydb_hashfs import tiny_db_option
29from sacred.run import Run
30from sacred.host_info import check_additional_host_info, HostInfoGetter
31from sacred.utils import (
32 print_filtered_stacktrace,
33 ensure_wellformed_argv,
34 SacredError,
35 format_sacred_error,
36 PathType,
37 get_inheritors,
38)
39from sacred.observers.mongo import mongo_db_option
41__all__ = ("Experiment",)
class Experiment(Ingredient):
    """
    The central class for each experiment in Sacred.

    It manages the configuration, the main function, captured methods,
    observers, commands, and further ingredients.

    An Experiment instance should be created as one of the first
    things in any experiment-file.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        additional_host_info: Optional[List[HostInfoGetter]] = None,
        additional_cli_options: Optional[Sequence[CLIOption]] = None,
        save_git_info: bool = True,
    ):
        """
        Create a new experiment with the given name and optional ingredients.

        Parameters
        ----------
        name
            Optional name of this experiment, defaults to the filename of the
            calling module with a trailing ``.py`` / ``.pyc`` removed.
            (Required in interactive mode)

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used with this experiment.

        interactive
            If set to True will allow the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it won't allow storing the
            source-code or reliable reproduction of the runs.

        base_dir
            Optional full path to the base directory of this experiment. This
            will set the scope for automatic source file discovery.

        additional_host_info
            Optional list of ``HostInfoGetter`` objects that collect
            additional pieces of host information.

        additional_cli_options
            Optional sequence of ``CLIOption`` objects that are offered on
            the command line in addition to the built-in options.

        save_git_info
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.

        Raises
        ------
        RuntimeError
            If ``name`` is omitted and either ``interactive`` is True or the
            calling module has no ``__file__``.
        """
        self.additional_host_info = additional_host_info or []
        check_additional_host_info(self.additional_host_info)
        self.additional_cli_options = additional_cli_options or []
        # list() so that a tuple (or any other Sequence) of extra options can
        # be concatenated with the list of built-in options.
        self.all_cli_options = gather_command_line_options() + list(
            self.additional_cli_options
        )
        # Globals of the module that instantiated the experiment; used to
        # derive the default name and passed on to the Ingredient.
        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            if interactive:
                raise RuntimeError("name is required in interactive mode.")
            mainfile = caller_globals.get("__file__")
            if mainfile is None:
                raise RuntimeError(
                    "No main-file found. Are you running in "
                    "interactive mode? If so please provide a "
                    "name and set interactive=True."
                )
            name = os.path.basename(mainfile)
            if name.endswith(".py"):
                name = name[:-3]
            elif name.endswith(".pyc"):
                name = name[:-4]
        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )
        self.default_command = None
        # Built-in commands; registered as unobserved so they do not create
        # observed runs.
        self.command(print_config, unobserved=True)
        self.command(print_dependencies, unobserved=True)
        self.command(save_config, unobserved=True)
        self.command(print_named_configs(self), unobserved=True)
        self.observers = []
        self.current_run = None
        self.captured_out_filter = None
        """Filter function to be applied to captured output of a run"""
        self.option_hooks = []

    # =========================== Decorators ==================================

    def main(self, function):
        """
        Decorator to define the main function of the experiment.

        The main function of an experiment is the default command that is being
        run when no command is specified, or when calling the run() method.

        Usually it is more convenient to use ``automain`` instead.
        """
        captured = self.command(function)
        self.default_command = captured.__name__
        return captured

    def automain(self, function):
        """
        Decorator that defines *and runs* the main function of the experiment.

        The decorated function is marked as the default command for this
        experiment, and the command-line interface is automatically run when
        the file is executed.

        The method decorated by this should be last in the file because it is
        equivalent to:

        Example
        -------
        ::

            @ex.main
            def my_main():
                pass

            if __name__ == '__main__':
                ex.run_commandline()

        Raises
        ------
        RuntimeError
            If used in an interactive session (stdin or an IPython cell).
        """
        captured = self.main(function)
        if function.__module__ == "__main__":
            # Ensure that automain is not used in interactive mode.
            main_filename = inspect.getfile(function)
            if main_filename == "<stdin>" or (
                main_filename.startswith("<ipython-input-")
                and main_filename.endswith(">")
            ):
                raise RuntimeError(
                    "Cannot use @ex.automain decorator in "
                    "interactive mode. Use @ex.main instead."
                )

            self.run_commandline()
        return captured

    def option_hook(self, function):
        """
        Decorator for adding an option hook function.

        An option hook is a function that is called right before a run
        is created. It receives (and potentially modifies) the options
        dictionary. That is, the dictionary of commandline options used for
        this run.

        Notes
        -----
        The decorated function MUST have an argument called options.

        The options also contain ``'COMMAND'`` and ``'UPDATE'`` entries,
        but changing them has no effect. Only modifications of
        flags (entries starting with ``'--'``) are considered.

        Raises
        ------
        KeyError
            If the decorated function has no argument called ``options``.
        """
        sig = Signature(function)
        if "options" not in sig.arguments:
            raise KeyError(
                "option_hook functions must have an argument called"
                " 'options', but got {}".format(sig.arguments)
            )
        self.option_hooks.append(function)
        return function

    # =========================== Public Interface ============================

    def get_usage(self, program_name=None):
        """Get the commandline usage string for this experiment.

        Returns
        -------
        A tuple ``(short_usage, long_usage, internal_usage)``.
        """
        program_name = os.path.relpath(
            program_name or sys.argv[0] or "Dummy", self.base_dir
        )
        commands = OrderedDict(self.gather_commands())
        long_usage = format_usage(
            program_name, self.doc, commands, self.all_cli_options
        )
        # internal usage is a workaround because docopt cannot handle spaces
        # in program names. So for parsing we use 'dummy' as the program name.
        # for printing help etc. we want to use the actual program name.
        internal_usage = format_usage("dummy", self.doc, commands, self.all_cli_options)
        short_usage = printable_usage(long_usage)
        return short_usage, long_usage, internal_usage

    def run(
        self,
        command_name: Optional[str] = None,
        config_updates: Optional[dict] = None,
        named_configs: Sequence[str] = (),
        info: Optional[dict] = None,
        meta_info: Optional[dict] = None,
        options: Optional[dict] = None,
    ) -> Run:
        """
        Run the main function of the experiment or a given command.

        Parameters
        ----------
        command_name
            Name of the command to be run. Defaults to main function.

        config_updates
            Changes to the configuration as a nested dictionary

        named_configs
            list of names of named_configs to use

        info
            Additional information for this run.

        meta_info
            Additional meta information for this run.

        options
            Dictionary of options to use

        Returns
        -------
        The Run object corresponding to the finished run.
        """
        run = self._create_run(
            command_name, config_updates, named_configs, info, meta_info, options
        )
        run()
        return run

    def run_commandline(self, argv=None) -> Optional[Run]:
        """
        Run the command-line interface of this experiment.

        If ``argv`` is omitted it defaults to ``sys.argv``.

        Parameters
        ----------
        argv
            Command-line as string or list of strings like ``sys.argv``.

        Returns
        -------
        The Run object corresponding to the finished run. On errors and
        after printing help this method calls ``sys.exit`` instead.
        """
        argv = ensure_wellformed_argv(argv)
        short_usage, usage, internal_usage = self.get_usage()
        args = docopt(internal_usage, [str(a) for a in argv[1:]], help=False)

        cmd_name = args.get("COMMAND") or self.default_command
        config_updates, named_configs = get_config_updates(args["UPDATE"])

        err = self._check_command(cmd_name)
        # A bare ``help`` (no COMMAND) must print the usage even if there is
        # no default command. ``help <unknown>`` however is reported as an
        # error here; previously it crashed with a KeyError in _handle_help.
        if err and not (args["help"] and args["COMMAND"] is None):
            print(short_usage)
            print(err)
            sys.exit(1)

        if self._handle_help(args, usage):
            sys.exit()

        try:
            return self.run(
                cmd_name,
                config_updates,
                named_configs,
                info={},
                meta_info={},
                options=args,
            )
        except Exception as e:
            if self.current_run:
                debug = self.current_run.debug
            else:
                # The usual command line options are applied after the run
                # object is built completely. Some exceptions (e.g.
                # ConfigAddedError) are raised before this. In these cases,
                # the debug flag must be checked manually.
                debug = args.get("--debug", False)

            if debug:
                # Debug: Don't change behavior, just re-raise exception
                raise
            elif self.current_run and self.current_run.pdb:
                # Print exception and attach pdb debugger
                import traceback
                import pdb

                traceback.print_exception(*sys.exc_info())
                pdb.post_mortem()
            else:
                # Handle pretty printing of exceptions. This includes
                # filtering the stacktrace and printing the usage, as
                # specified by the exceptions attributes
                if isinstance(e, SacredError):
                    print(format_sacred_error(e, short_usage), file=sys.stderr)
                else:
                    print_filtered_stacktrace()
                sys.exit(1)

    def open_resource(self, filename: PathType, mode: str = "r"):
        """Open a file and also save it as a resource.

        Opens a file, reports it to the observers as a resource, and returns
        the opened file.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        with its path and md5 sum.

        This function can only be called during a run, and just calls the
        :py:meth:`sacred.run.Run.open_resource` method.

        Parameters
        ----------
        filename
            name of the file that should be opened
        mode
            mode in which the file will be opened

        Returns
        -------
        The opened file-object.
        """
        assert self.current_run is not None, "Can only be called during a run."
        return self.current_run.open_resource(filename, mode)

    def add_resource(self, filename: PathType) -> None:
        """Add a file as a resource.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        with its path and md5 sum.

        This function can only be called during a run, and just calls the
        :py:meth:`sacred.run.Run.add_resource` method.

        Parameters
        ----------
        filename
            name of the file to be stored as a resource
        """
        assert self.current_run is not None, "Can only be called during a run."
        self.current_run.add_resource(filename)

    def add_artifact(
        self,
        filename: PathType,
        name: Optional[str] = None,
        metadata: Optional[dict] = None,
        content_type: Optional[str] = None,
    ) -> None:
        """Add a file as an artifact.

        In Sacred terminology an artifact is a file produced by the experiment
        run. In case of a MongoObserver that means storing the file in the
        database.

        This function can only be called during a run, and just calls the
        :py:meth:`sacred.run.Run.add_artifact` method.

        Parameters
        ----------
        filename
            name of the file to be stored as artifact
        name
            optionally set the name of the artifact.
            Defaults to the relative file-path.
        metadata
            optionally attach metadata to the artifact.
            This only has an effect when using the MongoObserver.
        content_type
            optionally attach a content-type to the artifact.
            This only has an effect when using the MongoObserver.
        """
        assert self.current_run is not None, "Can only be called during a run."
        self.current_run.add_artifact(filename, name, metadata, content_type)

    @property
    def info(self) -> dict:
        """Access the info-dict for storing custom information.

        Only works during a run and is essentially a shortcut to:

        Example
        -------
        ::

            @ex.capture
            def my_captured_function(_run):
                # [...]
                _run.info  # == ex.info

        Outside of a run ``current_run`` is None and accessing this property
        raises an AttributeError.
        """
        return self.current_run.info

    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
        """
        Add a new measurement.

        The measurement will be processed by the MongoDB* observer
        during a heartbeat event.
        Other observers are not yet supported.

        Only works during a run (otherwise ``current_run`` is None and an
        AttributeError is raised).

        Parameters
        ----------
        name
            The name of the metric, e.g. training.loss
        value
            The measured value
        step
            The step number (integer), e.g. the iteration number
            If not specified, an internal counter for each metric
            is used, incremented by one.
        """
        # Method added in change https://github.com/chovanecm/sacred/issues/4
        # The same as Run.log_scalar
        self.current_run.log_scalar(name, value, step)

    def post_process_name(self, name, ingredient):
        """Strip this experiment's path prefix from names it owns.

        Names gathered from other ingredients are returned unchanged.
        """
        if ingredient == self:
            # Removes the experiment's path (prefix) from the names
            # of the gathered items. This means that, for example,
            # 'experiment.print_config' becomes 'print_config'.
            return name[len(self.path) + 1 :]
        return name

    def get_default_options(self) -> dict:
        """Get a dictionary of default options as used with run.

        Returns
        -------
        A dictionary containing option keys of the form '--beat_interval'.
        The value is False if the option is a flag (takes no argument) and
        None otherwise.
        """
        default_options = {}
        for option in self.all_cli_options:
            if isinstance(option, CLIOption):
                if option.is_flag:
                    default_value = False
                else:
                    default_value = None
            else:  # legacy, should be removed later on.
                if option.arg is None:
                    default_value = False
                else:
                    default_value = None
            default_options[option.get_flag()] = default_value

        return default_options

    # =========================== Internal Interface ==========================

    def _create_run(
        self,
        command_name=None,
        config_updates=None,
        named_configs=(),
        info=None,
        meta_info=None,
        options=None,
    ):
        """Build (but do not start) a Run for ``command_name``.

        The given ``options`` are merged over the default options, passed
        through all option hooks, recorded in the run's meta info and finally
        applied to the run. The created run becomes ``self.current_run``.

        Raises
        ------
        RuntimeError
            If no command is given and no default (main) command is defined.
        """
        command_name = command_name or self.default_command
        if command_name is None:
            raise RuntimeError(
                "No command found to be run. Specify a command "
                "or define a main function."
            )

        default_options = self.get_default_options()
        if options:
            default_options.update(options)
        options = default_options

        # call option hooks
        for oh in self.option_hooks:
            oh(options=options)

        run = create_run(
            self,
            command_name,
            config_updates,
            named_configs=named_configs,
            force=options.get(commandline_options.force_option.get_flag(), False),
            log_level=options.get(commandline_options.loglevel_option.get_flag(), None),
        )
        if info is not None:
            run.info.update(info)

        run.meta_info["command"] = command_name
        run.meta_info["options"] = options
        run.meta_info["named_configs"] = list(named_configs)
        if config_updates is not None:
            run.meta_info["config_updates"] = config_updates

        if meta_info:
            run.meta_info.update(meta_info)

        # list() so that additional_cli_options may be any Sequence (e.g. a
        # tuple) and still be concatenated with the built-in option list.
        options_list = gather_command_line_options() + list(
            self.additional_cli_options
        )
        for option in options_list:
            option_value = options.get(option.get_flag(), False)
            if option_value:
                option.apply(option_value, run)

        self.current_run = run
        return run

    def _check_command(self, cmd_name):
        """Return an error message if ``cmd_name`` cannot be run, else None."""
        commands = dict(self.gather_commands())
        if cmd_name is not None and cmd_name not in commands:
            return (
                'Error: Command "{}" not found. Available commands are: '
                "{}".format(cmd_name, ", ".join(commands.keys()))
            )

        if cmd_name is None:
            return (
                "Error: No command found to be run. Specify a command"
                " or define main function. Available commands"
                " are: {}".format(", ".join(commands.keys()))
            )

    def _handle_help(self, args, usage):
        """Print help if ``help`` or ``--help`` was given.

        Prints the full usage when no COMMAND was named, otherwise the help
        text of that command. Returns True if help was printed, else False.
        """
        if args["help"] or args["--help"]:
            if args["COMMAND"] is None:
                print(usage)
                return True
            else:
                commands = dict(self.gather_commands())
                print(help_for_command(commands[args["COMMAND"]]))
                return True
        return False
def gather_command_line_options(filter_disabled=None):
    """Collect all available command-line options, sorted by name.

    The result holds every legacy ``CommandLineOption`` subclass plus all
    entries of ``DEFAULT_COMMAND_LINE_OPTIONS``. A deprecation message is
    emitted once for each legacy subclass found.

    Parameters
    ----------
    filter_disabled
        If truthy, legacy options whose ``_enabled`` attribute is falsy are
        left out. Defaults to the inverse of
        ``SETTINGS.COMMAND_LINE.SHOW_DISABLED_OPTIONS``.

    Returns
    -------
    A new list, sorted with ``commandline_options.get_name`` as the key.
    """
    if filter_disabled is None:
        filter_disabled = not SETTINGS.COMMAND_LINE.SHOW_DISABLED_OPTIONS

    collected = []
    for legacy_option in get_inheritors(commandline_options.CommandLineOption):
        warnings.warn(
            "Subclassing `CommandLineOption` is deprecated. Please "
            "use the `sacred.cli_option` decorator and pass the function "
            "to the Experiment constructor."
        )
        # ``_enabled`` is only consulted when filtering is active.
        keep = not filter_disabled or legacy_option._enabled
        if keep:
            collected.append(legacy_option)

    collected.extend(DEFAULT_COMMAND_LINE_OPTIONS)
    return sorted(collected, key=commandline_options.get_name)
# Built-in command-line options that gather_command_line_options() always
# includes, and that Experiment therefore offers on its command line.
# gather_command_line_options() sorts the combined options by
# commandline_options.get_name. sorted() is stable, so the order below only
# acts as a tie-breaker for options that share a name.
DEFAULT_COMMAND_LINE_OPTIONS = [
    s3_option,
    commandline_options.pdb_option,
    commandline_options.debug_option,
    file_storage_option,
    commandline_options.loglevel_option,
    mongo_db_option,
    sql_option,
    commandline_options.capture_option,
    commandline_options.help_option,
    commandline_options.print_config_option,
    commandline_options.name_option,
    commandline_options.id_option,
    commandline_options.priority_option,
    commandline_options.unobserved_option,
    commandline_options.beat_interval_option,
    commandline_options.queue_option,
    commandline_options.force_option,
    commandline_options.comment_option,
    commandline_options.enforce_clean_option,
    tiny_db_option,
]