Coverage for sacred/sacred/utils.py: 29%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import collections
5import contextlib
6import importlib
7import logging
8import pkgutil
9import re
10import shlex
11import sys
12import threading
13import traceback as tb
14from functools import partial
15from packaging import version
16from typing import Union
17from pathlib import Path
19import wrapt
__all__ = [
    # Constants
    "NO_LOGGER",
    "PYTHON_IDENTIFIER",
    # Exceptions and interrupts
    "CircularDependencyError",
    "ObserverError",
    "SacredInterrupt",
    "TimeoutInterrupt",
    # Logging
    "create_basic_stream_logger",
    # Nested dictionaries and dotted paths
    "recursive_update",
    "iterate_flattened",
    "iterate_flattened_separately",
    "set_by_dotted_path",
    "get_by_dotted_path",
    "iter_prefixes",
    "join_paths",
    "is_prefix",
    "convert_to_nested_dict",
    # Miscellaneous helpers
    "convert_camel_case_to_snake_case",
    "print_filtered_stacktrace",
    "optional_kwargs_decorator",
    "get_inheritors",
    "apply_backspaces_and_linefeeds",
    "rel_path",
    "IntervalTimer",
    "PathType",
]
# A disabled logger: every record sent to it is dropped.
NO_LOGGER = logging.getLogger("ignore")
NO_LOGGER.disabled = True

# Sentinel yielded by iterate_flattened_separately right before it descends
# into a nested dictionary; compare against it by identity.
PATHCHANGE = object()

# ASCII Python identifier. ``\Z`` (unlike ``$``) only matches at the very end
# of the string, so names with a trailing newline such as "foo\n" are rejected.
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")

# Type alias for path arguments: str, bytes or pathlib.Path.
PathType = Union[str, bytes, Path]
class ObserverError(Exception):
    """Raised by an observer for a problem that must not fail the run.

    Observers use it to report trouble on their side without turning the
    experiment run itself into a failure.
    """
class SacredInterrupt(Exception):
    """Common base class of Sacred's custom interrupts.

    Subclasses override ``STATUS``, the status given to a run that is
    interrupted by them. See :ref:`custom_interrupts` for details.
    """

    STATUS = "INTERRUPTED"
class TimeoutInterrupt(SacredInterrupt):
    """Interrupt signalling that the experiment ran out of time.

    Client code raises it when a run has exceeded its time limit and is being
    interrupted for that reason; the interrupted run then receives the status
    ``TIMEOUT``. See :ref:`custom_interrupts` for details.
    """

    STATUS = "TIMEOUT"
class SacredError(Exception):
    """Base class for Sacred errors that carry display options.

    Parameters
    ----------
    message : str
        The error message.
    print_traceback : bool
        Whether a traceback is shown when the error is reported.
    filter_traceback : {'always', 'default', 'never'}
        How Sacred-internal frames are filtered from that traceback
        (see ``format_filtered_stacktrace``).
    print_usage : bool
        Whether the short usage text is shown when the error is reported.

    Raises
    ------
    ValueError
        If ``filter_traceback`` is not one of the accepted values.
    """

    def __init__(
        self,
        message,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
    ):
        super().__init__(message)
        self.print_traceback = print_traceback
        if filter_traceback not in ["always", "default", "never"]:
            # format() instead of "+" so that non-string values (e.g. None)
            # still produce this ValueError rather than a TypeError.
            raise ValueError(
                "filter_traceback must be one of 'always', "
                "'default' or 'never', not {}".format(filter_traceback)
            )
        self.filter_traceback = filter_traceback
        self.print_usage = print_usage
class CircularDependencyError(SacredError):
    """Raised when the ingredients of an experiment depend on each other in a cycle."""

    @classmethod
    @contextlib.contextmanager
    def track(cls, ingredient):
        """Record ``ingredient`` on a CircularDependencyError leaving the block.

        Every enclosing ``track`` appends its ingredient to the error's chain
        until an ingredient shows up a second time; that closes the cycle and
        later ``track`` calls leave the chain untouched.
        """
        try:
            yield
        except CircularDependencyError as err:
            if not err.__circular_dependency_handled__:
                # Meeting an ingredient again means the cycle is complete.
                cycle_closed = ingredient in err.__ingredients__
                if cycle_closed:
                    err.__circular_dependency_handled__ = True
                err.__ingredients__.append(ingredient)
            raise err

    def __init__(
        self,
        message="Circular dependency detected:",
        ingredients=None,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
        self.__ingredients__ = [] if ingredients is None else ingredients
        self.__circular_dependency_handled__ = False

    def __str__(self):
        # Ingredients were appended innermost-first; show them outermost-first.
        chain = "->".join(
            ingredient.path for ingredient in reversed(self.__ingredients__)
        )
        return super().__str__() + chain
class ConfigError(SacredError):
    """Pretty prints the conflicting configuration values.

    Parameters
    ----------
    message : str
        The error message.
    conflicting_configs : str or iterable of str
        Dotted path(s) of the offending config entries. A single string is
        wrapped into a one-element tuple.
    print_conflicting_configs : bool
        Whether ``str(error)`` appends the conflicting entries together with
        their values looked up in ``config``.
    print_traceback, filter_traceback, print_usage
        Forwarded to :class:`SacredError`.
    config : dict, optional
        The configuration the conflicting paths refer to; defaults to ``{}``.
    """

    def __init__(
        self,
        message,
        conflicting_configs=(),
        print_conflicting_configs=True,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
        config=None,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
        self.print_conflicting_configs = print_conflicting_configs

        if isinstance(conflicting_configs, str):
            conflicting_configs = (conflicting_configs,)

        self.__conflicting_configs__ = conflicting_configs
        self.__prefix_handled__ = False

        if config is None:
            config = {}
        self.__config__ = config

    @classmethod
    @contextlib.contextmanager
    def track(cls, config, prefix=None):
        """Attach ``config`` (and ``prefix``) to a ConfigError leaving the block.

        Only the innermost ``track`` that sees the error updates it. If a
        ``prefix`` is given, it is prepended to every conflicting path.
        """
        try:
            yield
        except ConfigError as e:
            if not e.__prefix_handled__:
                if prefix:
                    # Materialize eagerly: a generator here would be exhausted
                    # by the first str(e), so later renderings would silently
                    # lose the list of conflicting values.
                    e.__conflicting_configs__ = tuple(
                        join_paths(prefix, str(c)) for c in e.__conflicting_configs__
                    )
                e.__config__ = config
                e.__prefix_handled__ = True
            raise

    def __str__(self):
        s = super().__str__()
        if self.print_conflicting_configs:
            # Add a list formatted as below to the string s:
            #
            # Conflicting configuration values:
            # a=3
            # b.c=4
            s += "\nConflicting configuration values:"
            for conflicting_config in self.__conflicting_configs__:
                s += "\n {}={}".format(
                    conflicting_config,
                    get_by_dotted_path(self.__config__, conflicting_config),
                )
        return s
class InvalidConfigError(ConfigError):
    """Raise from user code when the configuration turns out to be invalid.

    Examples
    --------
    >>> # Experiment definitions ...
    ... @ex.automain
    ... def main(a, b):
    ...     if a != b['a']:
    ...         raise InvalidConfigError(
    ...                     'Need to be equal',
    ...                     conflicting_configs=('a', 'b.a'))
    """
class MissingConfigError(SacredError):
    """Raised when a captured function needs config values that are not provided.

    The ``missing_configs`` collection is appended to ``message``. By default
    the usage text is shown but no traceback.
    """

    def __init__(
        self,
        message="Configuration values are missing:",
        missing_configs=(),
        print_traceback=False,
        filter_traceback="default",
        print_usage=True,
    ):
        full_message = "{} {}".format(message, missing_configs)
        super().__init__(
            full_message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
class NamedConfigNotFoundError(SacredError):
    """Raised when a requested named config does not exist.

    The message names the missing config and lists the available ones. By
    default neither a traceback nor the usage text is shown.
    """

    def __init__(
        self,
        named_config,
        message="Named config not found:",
        available_named_configs=(),
        print_traceback=False,
        filter_traceback="default",
        print_usage=False,
    ):
        full_message = '{} "{}". Available config values are: {}'.format(
            message, named_config, available_named_configs
        )
        super().__init__(
            full_message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
class ConfigAddedError(ConfigError):
    """Raised when a new config entry is added that no captured function uses.

    If ``print_suggestions`` is true, ``str(error)`` additionally lists the
    captured argument names (excluding ``SPECIAL_ARGS``) as possible keys.
    """

    # Special args that show up in the captured args but can never be set by
    # the user; they are therefore never suggested as config keys.
    SPECIAL_ARGS = {"_log", "_config", "_seed", "__doc__", "config_filename", "_run"}

    def __init__(
        self,
        conflicting_configs,
        message="Added new config entry that is not used anywhere",
        captured_args=(),
        print_conflicting_configs=True,
        print_traceback=False,
        filter_traceback="default",
        print_usage=False,
        print_suggestions=True,
        config=None,
    ):
        super().__init__(
            message,
            conflicting_configs=conflicting_configs,
            print_conflicting_configs=print_conflicting_configs,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
            config=config,
        )
        self.captured_args = captured_args
        self.print_suggestions = print_suggestions

    def __str__(self):
        s = super().__str__()
        if self.print_suggestions:
            possible_keys = set(self.captured_args) - self.SPECIAL_ARGS
            if possible_keys:
                # Sorted so the message is identical across runs (iteration
                # order of a set of strings depends on hash randomization),
                # while keeping the "{'a', 'b'}" look of a set repr.
                listed = ", ".join(repr(k) for k in sorted(possible_keys, key=str))
                s += "\nPossible config keys are: {{{}}}".format(listed)
        return s
class SignatureError(SacredError, TypeError):
    """Raised when the passed arguments do not fit a function's signature.

    It is also a :class:`TypeError`, so it can be caught like Python's own
    error for a bad call. Traceback filtering defaults to ``'always'``.
    """

    def __init__(
        self,
        message,
        print_traceback=True,
        filter_traceback="always",
        print_usage=False,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
class FilteredTracebackException(tb.TracebackException):
    """Filter out sacred internal tracebacks from an exception traceback.

    Frames belonging to Sacred's own modules are removed from the exception's
    traceback and from the tracebacks of every exception reachable through
    ``__cause__`` / ``__context__``. Note that the traceback objects are
    relinked in place (their ``tb_next`` attributes are rewritten).
    """

    def __init__(
        self,
        exc_type,
        exc_value,
        exc_traceback,
        *,
        limit=None,
        lookup_lines=True,
        capture_locals=False,
        _seen=None,
        **kwargs,
    ):
        exc_traceback = self._filter_tb(exc_traceback)
        if exc_value is not None:
            # Filter the whole cause/context chain up front so chained
            # exceptions are rendered without Sacred frames as well.
            self._walk_value(exc_value)
        super().__init__(
            exc_type,
            exc_value,
            exc_traceback,
            limit=limit,
            lookup_lines=lookup_lines,
            capture_locals=capture_locals,
            _seen=_seen,
            # Forward newer stdlib options (e.g. ``compact``) unchanged.
            **kwargs,
        )

    def _walk_value(self, obj):
        """Filter the tracebacks of all exceptions chained to ``obj``.

        Each chained exception is processed exactly once: ``__cause__`` and
        ``__context__`` may point to the same object, and a manually built
        chain may even be cyclic.
        """
        seen = {id(obj)}
        pending = [obj]
        while pending:
            current = pending.pop()
            for linked in (current.__cause__, current.__context__):
                if linked is None or id(linked) in seen:
                    continue
                seen.add(id(linked))
                linked.__traceback__ = self._filter_tb(linked.__traceback__)
                pending.append(linked)

    def _filter_tb(self, exc_tb):
        """Return ``exc_tb`` relinked to skip Sacred frames.

        Returns None if there is no traceback or every frame is Sacred's own.
        """
        kept = []
        while exc_tb is not None:
            if not _is_sacred_frame(exc_tb.tb_frame):
                kept.append(exc_tb)
            exc_tb = exc_tb.tb_next
        if not kept:
            return None
        for previous, following in zip(kept, kept[1:]):
            previous.tb_next = following
        kept[-1].tb_next = None
        return kept[0]

    def format(self, *, chain=True, **kwargs):
        """Yield the formatted lines with a header noting the filtering."""
        for line in super().format(chain=chain, **kwargs):
            if line == "Traceback (most recent call last):\n":
                yield "Traceback (most recent calls WITHOUT Sacred internals):\n"
            else:
                yield line
def create_basic_stream_logger():
    """Configure the root logger for INFO-level stream output and return it.

    The configuration goes through :func:`logging.basicConfig`. That function
    leaves the root logger alone if it already has handlers
    (``len(getLogger().handlers) > 0``). In that case this function only
    returns the root logger.
    """
    log_format = "%(levelname)s - %(name)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    root_logger = logging.getLogger("")
    return root_logger
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mappings in ``u`` are merged into the corresponding entries of ``d``. All
    other values overwrite. ``d`` is modified in place and returned. Suppose
    ``d`` holds a non-mapping value (e.g. a number or None) where ``u`` holds
    a mapping. The value is then replaced by the merged mapping.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, collections.abc.Mapping):
                # Missing entry or scalar overridden by a mapping: merge into a
                # fresh dict instead of failing on item assignment to a scalar.
                existing = {}
            d[key] = recursive_update(existing, value)
        else:
            d[key] = value
    return d
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively walk a dictionary and yield ``(dotted_path, value)`` pairs.

    The keys at every level are visited in three groups:

    1. Keys listed in ``manually_sorted_keys``, in the order of that list.
    2. Keys whose values are not non-empty dicts, sorted by key.
    3. Keys whose values are non-empty dicts, sorted by key.

    Before descending into a non-empty dict, ``(key, PATHCHANGE)`` is yielded.
    After that come its entries with full dotted paths.
    """
    priority_keys = manually_sorted_keys or []

    def sort_key(item):
        name, content = item
        if name in priority_keys:
            return 0, priority_keys.index(name)
        if is_non_empty_dict(content):
            return 2, name
        return 1, name

    for name, content in sorted(dictionary.items(), key=sort_key):
        if not is_non_empty_dict(content):
            yield name, content
            continue
        yield name, PATHCHANGE
        for sub_path, leaf in iterate_flattened_separately(content, priority_keys):
            yield join_paths(name, sub_path), leaf
def is_non_empty_dict(python_object):
    """Return True if ``python_object`` is a ``dict`` with at least one entry.

    Always returns a real bool (previously the dict itself was returned).
    """
    return isinstance(python_object, dict) and len(python_object) > 0
def iterate_flattened(d):
    """
    Yield ``(dotted_path, value)`` for every leaf of a nested dictionary.

    Keys are visited in sorted order at every level. Non-empty dicts are
    descended into; everything else (including empty dicts) is a leaf.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if not (isinstance(value, dict) and value):
            yield key, value
            continue
        for sub_key, sub_value in iterate_flattened(value):
            yield join_paths(key, sub_key), sub_value
def set_by_dotted_path(d, path, value):
    """
    Store ``value`` in a nested dict at the location given by a dotted path.

    Missing intermediate levels are created as empty dicts.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    *parent_keys, last_key = path.split(".")
    node = d
    for key in parent_keys:
        if key not in node:
            node[key] = {}
        node = node[key]
    node[last_key] = value
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    Returns ``d`` itself for an empty path. Returns ``default`` if a component
    of the path is missing, including when the path runs into a leaf value
    (e.g. an int, None or a string) that cannot be looked up further.

    Example
    -------
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d
    current_option = d
    for p in path.split("."):
        try:
            if p not in current_option:
                return default
            current_option = current_option[p]
        except TypeError:
            # current_option is a leaf that does not support membership
            # tests or lookups by string key (e.g. ``'b' in 1``).
            return default
    return current_option
def iter_prefixes(path):
    """
    Yield every prefix of a dotted path, shortest first, ending with the full path.

    Example
    -------
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    prefix_parts = []
    for part in path.split("."):
        prefix_parts.append(part)
        yield join_paths(*prefix_parts)
def join_paths(*parts):
    """Join ``parts`` into one dotted path.

    Falsy parts are skipped; leading and trailing dots are stripped from the
    string form of every remaining part.
    """
    cleaned = [str(part).strip(".") for part in parts if part]
    return ".".join(cleaned)
def is_prefix(pre_path, path):
    """Return True if pre_path is a path-prefix of path.

    Leading and trailing dots are ignored. An empty ``pre_path`` is a prefix
    of everything. Otherwise ``path`` must continue with a dot after
    ``pre_path``, so a path is not considered a prefix of itself.
    """
    stripped_prefix = pre_path.strip(".")
    if not stripped_prefix:
        return True
    return path.strip(".").startswith(stripped_prefix + ".")
def rel_path(base, path):
    """Return path relative to base.

    Leading and trailing dots of both arguments are ignored. ``""`` is
    returned when both denote the same path.

    Raises
    ------
    AssertionError
        If ``base`` is not a prefix of ``path``. It is raised explicitly
        rather than via ``assert`` so the check also runs under ``python -O``.
    """
    stripped_base = base.strip(".")
    stripped_path = path.strip(".")
    if stripped_base == stripped_path:
        return ""
    if not is_prefix(stripped_base, stripped_path):
        raise AssertionError("{} not a prefix of {}".format(base, path))
    # Slice by the length of the *stripped* base; slicing the raw path by
    # len(base) dropped characters when base carried extra dots.
    return stripped_path[len(stripped_base):].strip(".")
def convert_to_nested_dict(dotted_dict):
    """Turn a dict whose keys may be dotted paths into the equivalent nested dict."""
    result = {}
    for dotted_key, leaf_value in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf_value)
    return result
def _is_sacred_frame(frame):
    """Return True if ``frame`` runs code from the ``sacred`` package.

    The decision is based on the module name (``__name__``) in the frame's
    globals. Frames without a module name are not treated as Sacred frames.
    Code executed via ``exec`` with custom globals is one example.
    """
    module_name = frame.f_globals.get("__name__") or ""
    return str(module_name).split(".")[0] == "sacred"
def print_filtered_stacktrace(filter_traceback="default"):
    """Print the traceback of the exception being handled to stderr.

    ``filter_traceback`` is passed on to ``format_filtered_stacktrace``.
    """
    text = format_filtered_stacktrace(filter_traceback)
    print(text, file=sys.stderr)
def format_filtered_stacktrace(filter_traceback="default"):
    """
    Returns the traceback of the exception currently being handled as `string`.

    `filter_traceback` can be one of:
        - 'always': always filter out sacred internals
        - 'default': Default behaviour: filter out sacred internals
                if the exception did not originate from within sacred, and
                print just the internal stack trace otherwise
        - 'never': don't filter, always print full traceback

    Any other value raises a ValueError. If no exception is being handled,
    an empty string is returned.
    """
    if filter_traceback not in ("always", "default", "never"):
        raise ValueError(
            "Unknown value for filter_traceback: {}".format(filter_traceback)
        )
    exc_type, exc_value, exc_traceback = sys.exc_info()
    if exc_value is None:
        # Called outside of an ``except`` block: nothing to format.
        return ""

    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb is not None and current_tb.tb_next is not None:
        current_tb = current_tb.tb_next

    if (
        filter_traceback == "default"
        and current_tb is not None
        and _is_sacred_frame(current_tb.tb_frame)
    ):
        # just print sacred internal trace
        header = [
            "Exception originated from within Sacred.\n"
            "Traceback (most recent calls):\n"
        ]
        texts = tb.format_exception(exc_type, exc_value, current_tb)
        return "".join(header + texts[1:]).strip()
    if filter_traceback in ("default", "always"):
        # print filtered stacktrace
        tb_exception = FilteredTracebackException(
            exc_type, exc_value, exc_traceback, limit=None
        )
        return "".join(tb_exception.format())
    # 'never': print full stacktrace. The items of format_exception already
    # end in "\n", so concatenate them (joining with "\n" double-spaced them).
    return "".join(tb.format_exception(exc_type, exc_value, exc_traceback))
def format_sacred_error(e, short_usage):
    """Format a SacredError-like exception for display to the user.

    The function reads ``e.print_usage``, ``e.print_traceback`` and
    ``e.filter_traceback``. If usage printing is requested, ``short_usage``
    comes first. It is followed by one of two things. The first is the
    (filtered) traceback of the exception currently being handled. The
    second is just the exception summary line(s).
    """
    lines = []
    if e.print_usage:
        lines.append(short_usage)
    if e.print_traceback:
        lines.append(format_filtered_stacktrace(e.filter_traceback))
    else:
        # format_exception_only() returns lines that already end in "\n"
        # (several of them e.g. for SyntaxError), so concatenate them.
        lines.append("".join(tb.format_exception_only(type(e), e)))
    return "\n".join(lines)
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """Let the decorator ``wrapped`` be used both bare and with keyword arguments.

    When positional arguments are given (``@dec`` applied to an object),
    ``wrapped`` is called directly. With keyword arguments only
    (``@dec(key=value)``), ``wrapped`` is returned with those keywords bound,
    ready to be applied to the object.
    """
    if not args:
        # Only keyword arguments: hand back the partially applied decorator.
        return partial(wrapped, **kwargs)
    # Plain decorator usage: the decorated object is in args.
    return wrapped(*args, **kwargs)
def get_inheritors(cls):
    """Return the set of all direct and indirect subclasses of ``cls``."""
    found = set()
    stack = [cls]
    while stack:
        current = stack.pop()
        new_children = [c for c in current.__subclasses__() if c not in found]
        found.update(new_children)
        stack.extend(new_children)
    return found
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """Convert CamelCase to snake_case.

    The conversion has three steps. First, an underscore is inserted before
    every capitalised word (``HTTPResponse`` -> ``HTTP_Response``). Next,
    one is inserted between a lowercase letter or digit and a following
    capital. Finally, the result is lowercased.
    """
    words_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    boundaries_split = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", words_split)
    return boundaries_split.lower()
def apply_backspaces_and_linefeeds(text):
    """
    Render ``text`` the way a terminal would display it.

    Each line (split on newlines) is replayed character by character:

    * A backspace moves the cursor one cell left, never past the start.
    * A carriage return moves it back to the start of the line.
    * Any other character overwrites the cell under the cursor or extends
      the line.

    A carriage return that is the very last character of the text is kept
    at the end of the line, so the result can be concatenated with the next
    output chunk.
    """
    lines = text.split("\n")
    last_line_no = len(lines) - 1
    rendered_lines = []
    for line_no, line in enumerate(lines):
        screen = []
        pos = 0
        last_char_idx = len(line) - 1
        for char_idx, char in enumerate(line):
            is_final_cr = (
                char == "\r"
                and char_idx == last_char_idx
                and line_no == last_line_no
            )
            if char == "\r" and not is_final_cr:
                pos = 0
            elif char == "\b":
                pos = max(0, pos - 1)
            else:
                if is_final_cr:
                    # Keep the trailing carriage return after the line content.
                    pos = len(screen)
                if pos < len(screen):
                    screen[pos] = char
                else:
                    screen.append(char)
                pos += 1
        rendered_lines.append("".join(screen))
    return "\n".join(rendered_lines)
def module_exists(modname):
    """Checks if a module exists without actually importing it.

    For dotted names the parent packages are imported in order to locate the
    submodule. Returns True as well when locating the module raises an error
    (see the workaround note below).
    """
    # importlib.util.find_spec replaces pkgutil.find_loader, which is
    # deprecated since Python 3.12 and removed in 3.14.
    import importlib.util

    try:
        spec = importlib.util.find_spec(modname)
    except (ImportError, AttributeError, TypeError, ValueError):
        # TODO: Temporary fix for tf 1.14.0.
        # Should be removed once fixed in tf.
        # (These are the errors pkgutil.find_loader turned into ImportError.)
        return True
    # Mirrors find_loader, which returned spec.loader (or None without a spec).
    return spec is not None and spec.loader is not None
def modules_exist(*modnames):
    """Return True if every named module can be found (see ``module_exists``).

    Stops at the first module that cannot be found.
    """
    for name in modnames:
        if not module_exists(name):
            return False
    return True
def module_is_in_cache(modname):
    """Return True if ``modname`` has already been imported (is in ``sys.modules``)."""
    cached_modules = sys.modules
    return modname in cached_modules
def parse_version(version_string):
    """Parse ``version_string`` with ``packaging.version.parse`` and return the result."""
    parsed = version.parse(version_string)
    return parsed
def get_package_version(name):
    """Import the package ``name`` and return its parsed ``__version__``."""
    package = importlib.import_module(name)
    return parse_version(package.__version__)
def ensure_wellformed_argv(argv):
    """Normalise ``argv`` into a sequence of strings.

    ``None`` selects ``sys.argv``, and a string is split shell-style with
    ``shlex.split``. A list or tuple is returned unchanged once all of its
    elements are confirmed to be strings.

    Raises
    ------
    ValueError
        If ``argv`` has any other type or contains non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, str):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}".format(type(argv)))
    problems = [a for a in argv if not isinstance(a, str)]
    if problems:
        raise ValueError(
            "argv must be list of str but contained the "
            "following elements: {}".format(problems)
        )
    return argv
class IntervalTimer(threading.Thread):
    """Thread that calls ``func`` every ``interval`` seconds until stopped.

    Stopping is done by setting the event passed as ``event``. Once the event
    is set, ``func`` is called one final time before the thread finishes.
    """

    @classmethod
    def create(cls, func, interval=10):
        """Build a timer together with the event that stops it.

        Returns ``(stop_event, timer_thread)``. The thread is not started.
        """
        stop_event = threading.Event()
        return stop_event, cls(stop_event, func, interval)

    def __init__(self, event, func, interval=10.0):
        super().__init__()
        self.stopped = event
        self.func = func
        self.interval = interval

    def run(self):
        # Event.wait returns False on timeout (tick again) and True once set.
        while True:
            if self.stopped.wait(self.interval):
                break
            self.func()
        self.func()