Coverage for sacred/sacred/utils.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

315 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import collections 

5import contextlib 

6import importlib 

7import logging 

8import pkgutil 

9import re 

10import shlex 

11import sys 

12import threading 

13import traceback as tb 

14from functools import partial 

15from packaging import version 

16from typing import Union 

17from pathlib import Path 

18 

19import wrapt 

20 

21 

# Public names exported by this module through ``import *``.
__all__ = [
    "NO_LOGGER",
    "PYTHON_IDENTIFIER",
    "CircularDependencyError",
    "ObserverError",
    "SacredInterrupt",
    "TimeoutInterrupt",
    "create_basic_stream_logger",
    "recursive_update",
    "iterate_flattened",
    "iterate_flattened_separately",
    "set_by_dotted_path",
    "get_by_dotted_path",
    "iter_prefixes",
    "join_paths",
    "is_prefix",
    "convert_to_nested_dict",
    "convert_camel_case_to_snake_case",
    "print_filtered_stacktrace",
    "optional_kwargs_decorator",
    "get_inheritors",
    "apply_backspaces_and_linefeeds",
    "rel_path",
    "IntervalTimer",
    "PathType",
]

48 

# Logger named "ignore" whose ``disabled`` flag is set.
NO_LOGGER = logging.getLogger("ignore")
NO_LOGGER.disabled = 1

# Unique marker object; iterate_flattened_separately yields it as the value
# of a key whose entry is a non-empty dict.
PATHCHANGE = object()

# Pattern for ASCII identifier syntax, anchored at both ends.
PYTHON_IDENTIFIER = re.compile("^[a-zA-Z_][_a-zA-Z0-9]*$")

# Type alias for path-like values: str, bytes or pathlib.Path.
PathType = Union[str, bytes, Path]

57 

58 

class ObserverError(Exception):
    """Raised by an observer for a failure that should not make the run fail."""

61 

62 

class SacredInterrupt(Exception):
    """Common base class of all custom interrupts.

    See :ref:`custom_interrupts` for more information.
    """

    # Status label associated with this interrupt type.
    STATUS = "INTERRUPTED"

70 

71 

class TimeoutInterrupt(SacredInterrupt):
    """Interrupt signalling that the experiment timed out.

    Client code can raise this exception to indicate that the run exceeded
    its time limit and was interrupted for that reason. The status of the
    interrupted run will then be set to ``TIMEOUT``.

    See :ref:`custom_interrupts` for more information.
    """

    # Overrides the status label of the base interrupt.
    STATUS = "TIMEOUT"

83 

84 

class SacredError(Exception):
    """Base exception that carries hints on how it should be reported.

    Parameters
    ----------
    message
        Text handed on to ``Exception``.
    print_traceback
        Stored unchanged as ``self.print_traceback``.
    filter_traceback
        One of ``'always'``, ``'default'`` or ``'never'``; stored as
        ``self.filter_traceback``.
    print_usage
        Stored unchanged as ``self.print_usage``.

    Raises
    ------
    ValueError
        If ``filter_traceback`` is not one of the three accepted values.
    """

    def __init__(
        self,
        message,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
    ):
        super().__init__(message)
        self.print_traceback = print_traceback
        if filter_traceback not in ("always", "default", "never"):
            # Format the value instead of concatenating it, so that a
            # non-string argument still yields the intended ValueError
            # rather than a TypeError from ``str + other``.
            raise ValueError(
                "filter_traceback must be one of 'always', "
                "'default' or 'never', not {}".format(filter_traceback)
            )
        self.filter_traceback = filter_traceback
        self.print_usage = print_usage

102 

103 

class CircularDependencyError(SacredError):
    """The ingredients of the current experiment form a circular dependency."""

    @classmethod
    @contextlib.contextmanager
    def track(cls, ingredient):
        """Context manager recording ``ingredient`` on a passing error.

        A ``CircularDependencyError`` escaping the managed block gets
        ``ingredient`` appended to its ingredient list, unless the error is
        already marked as handled. It is marked as handled once an
        ingredient that is already on the list is seen again. The error is
        always re-raised.
        """
        try:
            yield
        except CircularDependencyError as error:
            if not error.__circular_dependency_handled__:
                cycle_closed = ingredient in error.__ingredients__
                if cycle_closed:
                    error.__circular_dependency_handled__ = True
                error.__ingredients__.append(ingredient)
            raise error

    def __init__(
        self,
        message="Circular dependency detected:",
        ingredients=None,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )

        # A fresh list per instance when the caller supplies none.
        self.__ingredients__ = [] if ingredients is None else ingredients
        self.__circular_dependency_handled__ = False

    def __str__(self):
        """Return the message followed by the ingredient paths, last one first."""
        paths = (item.path for item in reversed(self.__ingredients__))
        chain = "->".join(paths)
        return super().__str__() + chain

143 

144 

class ConfigError(SacredError):
    """Error that pretty prints the conflicting configuration values.

    Parameters
    ----------
    message
        Error text.
    conflicting_configs
        Dotted config paths involved in the error; a single string is
        wrapped into a one-element tuple.
    print_conflicting_configs
        Whether ``__str__`` appends the list of conflicting values.
    print_traceback, filter_traceback, print_usage
        Forwarded to ``SacredError``.
    config
        Config the conflicting paths are looked up in; defaults to an
        empty dict.
    """

    def __init__(
        self,
        message,
        conflicting_configs=(),
        print_conflicting_configs=True,
        print_traceback=True,
        filter_traceback="default",
        print_usage=False,
        config=None,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
        self.print_conflicting_configs = print_conflicting_configs

        if isinstance(conflicting_configs, str):
            conflicting_configs = (conflicting_configs,)

        self.__conflicting_configs__ = conflicting_configs
        self.__prefix_handled__ = False

        if config is None:
            config = {}
        self.__config__ = config

    @classmethod
    @contextlib.contextmanager
    def track(cls, config, prefix=None):
        """Context manager attaching ``config`` to a passing ``ConfigError``.

        The first time an error passes through, its conflicting paths are
        prefixed with ``prefix`` (when given), ``config`` is stored on it
        and it is marked as handled. The error is always re-raised.
        """
        try:
            yield
        except ConfigError as e:
            if not e.__prefix_handled__:
                if prefix:
                    # Materialise as a tuple: a generator would be used up
                    # by the first ``str(e)`` and every later call would
                    # silently print no conflicting values.
                    e.__conflicting_configs__ = tuple(
                        join_paths(prefix, str(c)) for c in e.__conflicting_configs__
                    )
                e.__config__ = config
                e.__prefix_handled__ = True
            raise e

    def __str__(self):
        """Return the message, optionally followed by the conflicting values."""
        s = super().__str__()
        if self.print_conflicting_configs:
            # Append one "path=value" line per conflicting entry below a
            # "Conflicting configuration values:" header line.
            s += "\nConflicting configuration values:"
            for conflicting_config in self.__conflicting_configs__:
                s += "\n {}={}".format(
                    conflicting_config,
                    get_by_dotted_path(self.__config__, conflicting_config),
                )
        return s

206 

207 

class InvalidConfigError(ConfigError):
    """Error for user code to raise when it detects an invalid configuration.

    Examples
    --------
    >>> # Experiment definitions ...
    ... @ex.automain
    ... def main(a, b):
    ...     if a != b['a']:
    ...         raise InvalidConfigError(
    ...             'Need to be equal',
    ...             conflicting_configs=('a', 'b.a'))
    """

223 

224 

class MissingConfigError(SacredError):
    """A config value needed by a captured function is absent from the config."""

    def __init__(
        self,
        message="Configuration values are missing:",
        missing_configs=(),
        print_traceback=False,
        filter_traceback="default",
        print_usage=True,
    ):
        # The missing entries are appended to the message text itself.
        full_message = "{} {}".format(message, missing_configs)
        reporting = dict(
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
        super().__init__(full_message, **reporting)

243 

244 

class NamedConfigNotFoundError(SacredError):
    """Raised when a named config cannot be found."""

    def __init__(
        self,
        named_config,
        message="Named config not found:",
        available_named_configs=(),
        print_traceback=False,
        filter_traceback="default",
        print_usage=False,
    ):
        # The final text names the requested config and lists the known ones.
        template = '{} "{}". Available config values are: {}'
        full_message = template.format(message, named_config, available_named_configs)
        reporting = dict(
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )
        super().__init__(full_message, **reporting)

266 

267 

class ConfigAddedError(ConfigError):
    SPECIAL_ARGS = {"_log", "_config", "_seed", "__doc__", "config_filename", "_run"}
    """Special args that show up in the captured args but can never be set
    by the user"""

    def __init__(
        self,
        conflicting_configs,
        message="Added new config entry that is not used anywhere",
        captured_args=(),
        print_conflicting_configs=True,
        print_traceback=False,
        filter_traceback="default",
        print_usage=False,
        print_suggestions=True,
        config=None,
    ):
        base_options = dict(
            conflicting_configs=conflicting_configs,
            print_conflicting_configs=print_conflicting_configs,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
            config=config,
        )
        super().__init__(message, **base_options)
        self.captured_args = captured_args
        self.print_suggestions = print_suggestions

    def __str__(self):
        """Return the base text, plus the usable config keys when requested."""
        text = super().__str__()
        if not self.print_suggestions:
            return text
        # Suggest every captured argument except the special ones.
        candidates = set(self.captured_args) - self.SPECIAL_ARGS
        if candidates:
            text += "\nPossible config keys are: {}".format(candidates)
        return text

304 

305 

class SignatureError(SacredError, TypeError):
    """Raised when the passed arguments do not match the function's signature."""

    def __init__(
        self,
        message,
        print_traceback=True,
        filter_traceback="always",
        print_usage=False,
    ):
        super().__init__(
            message,
            print_traceback=print_traceback,
            filter_traceback=filter_traceback,
            print_usage=print_usage,
        )

317 

318 

class FilteredTracebackException(tb.TracebackException):
    """Filter out sacred internal tracebacks from an exception traceback.

    The traceback passed in, and the tracebacks of all exceptions reachable
    through ``__cause__`` / ``__context__``, are rewritten in place so that
    frames identified by ``_is_sacred_frame`` are skipped.
    """

    def __init__(
        self,
        exc_type,
        exc_value,
        exc_traceback,
        *,
        limit=None,
        lookup_lines=True,
        capture_locals=False,
        _seen=None,
    ):
        exc_traceback = self._filter_tb(exc_traceback)
        self._walk_value(exc_value)
        super().__init__(
            exc_type,
            exc_value,
            exc_traceback,
            limit=limit,
            lookup_lines=lookup_lines,
            capture_locals=capture_locals,
            _seen=_seen,
        )

    def _walk_value(self, obj, _visited=None):
        """Filter the tracebacks of all exceptions chained to ``obj``.

        ``_visited`` holds the ids of exceptions already processed, so each
        one is handled once even when it is reachable both as cause and as
        context, or when the chain contains a cycle.
        """
        if obj is None:
            return
        if _visited is None:
            _visited = set()
        _visited.add(id(obj))
        for linked in (obj.__cause__, obj.__context__):
            if linked is None or id(linked) in _visited:
                continue
            linked.__traceback__ = self._filter_tb(linked.__traceback__)
            self._walk_value(linked, _visited)

    def _filter_tb(self, tb):
        """Return the head of ``tb`` relinked without sacred frames.

        Returns ``None`` when ``tb`` is ``None`` or no frame is kept (an
        exception used as a cause without ever being raised has no
        traceback), instead of failing with ``IndexError``.
        """
        kept = []
        while tb is not None:
            if not _is_sacred_frame(tb.tb_frame):
                kept.append(tb)
            tb = tb.tb_next
        if not kept:
            return None
        for earlier, later in zip(kept, kept[1:]):
            earlier.tb_next = later
        # Always cut the chain after the last kept entry so that no trailing
        # filtered frame stays reachable.
        kept[-1].tb_next = None
        return kept[0]

    def format(self, *, chain=True):
        """Yield the formatted lines with a header noting the filtering."""
        for line in super().format(chain=chain):
            if line == "Traceback (most recent call last):\n":
                yield "Traceback (most recent calls WITHOUT Sacred internals):\n"
            else:
                yield line

373 

374 

def create_basic_stream_logger():
    """Set up basic stream logging and return the root logger.

    Calls `logging.basicConfig` with level `logging.INFO`, which configures
    the root logger to use a `logging.StreamHandler`.

    Notes
    -----
    The logger configuration is left unchanged if the root logger is
    already configured (i.e. `len(getLogger().handlers) > 0`).
    """
    line_layout = "%(levelname)s - %(name)s - %(message)s"
    logging.basicConfig(format=line_layout, level=logging.INFO)
    root_logger = logging.getLogger("")
    return root_logger

390 

391 

def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mapping values of ``u`` are merged into the corresponding mapping of
    ``d``; every other value replaces the entry in ``d``. When ``d`` holds a
    non-mapping value where ``u`` holds a mapping, that value is replaced by
    the merged mapping.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}

    Returns ``d``, which is modified in place.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            target = d.get(k)
            if not isinstance(target, collections.abc.Mapping):
                # Nothing mergeable at this key (missing, or a scalar that
                # cannot take item assignment): start from an empty dict.
                target = {}
            d[k] = recursive_update(target, v)
        else:
            d[k] = v
    return d

408 

409 

def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively walk the items of a dictionary in a special order.

    Manually sorted keys come first (in their given order), then all items
    whose value is not a non-empty dictionary (sorted by key), then the
    remaining ones (sorted by key). Every leaf is yielded with its full
    dotted path; a non-empty dictionary is announced by yielding its key
    together with the PATHCHANGE marker before its content.
    """
    manual_order = manually_sorted_keys or []

    def sort_rank(item):
        name, content = item
        if name in manual_order:
            return 0, manual_order.index(name)
        group = 2 if is_non_empty_dict(content) else 1
        return group, name

    ordered_items = sorted(dictionary.items(), key=sort_rank)
    for name, content in ordered_items:
        if not is_non_empty_dict(content):
            yield name, content
            continue
        yield name, PATHCHANGE
        for sub_path, leaf in iterate_flattened_separately(content, manual_order):
            yield join_paths(name, sub_path), leaf

436 

437 

def is_non_empty_dict(python_object):
    """Return a truthy value only for a dict with at least one entry.

    Non-dict objects give ``False``; a dict is returned as is, so an empty
    dict yields itself (falsy) and a non-empty one yields itself (truthy).
    """
    if not isinstance(python_object, dict):
        return False
    return python_object

440 

441 

def iterate_flattened(d):
    """
    Recursively walk the items of a dictionary in sorted key order.

    Yields a ``(dotted_path, value)`` pair for every leaf; an empty dict
    counts as a leaf.
    """
    for name in sorted(d.keys()):
        entry = d[name]
        is_branch = isinstance(entry, dict) and entry
        if not is_branch:
            yield name, entry
            continue
        for sub_path, leaf in iterate_flattened(entry):
            yield join_paths(name, sub_path), leaf

455 

456 

def set_by_dotted_path(d, path, value):
    """
    Store ``value`` in a nested dict at the location given by a dotted path.

    Missing intermediate dictionaries are created on the way.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    *parents, leaf = path.split(".")
    node = d
    for name in parents:
        if name not in node:
            node[name] = dict()
        node = node[name]
    node[leaf] = value

481 

482 

def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    Returns ``d`` itself for an empty path, and ``default`` when any part
    of the path is missing or leads into a value that cannot be looked up
    by key (e.g. asking for ``'a.b'`` when ``d['a']`` is a number).

    Example
    -------
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    >>> get_by_dotted_path({'foo': 1}, 'foo.a', default='n/a')
    'n/a'
    """
    if not path:
        return d
    current_option = d
    for p in path.split("."):
        try:
            if p not in current_option:
                return default
            current_option = current_option[p]
        except TypeError:
            # The value reached so far does not support membership tests or
            # lookup by string key, so the path cannot exist below it.
            return default
    return current_option

501 

502 

def iter_prefixes(path):
    """
    Yield every (non-empty) prefix of a dotted path, shortest first.

    Example
    -------
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    collected = []
    for part in path.split("."):
        collected.append(part)
        yield join_paths(*collected)

515 

516 

def join_paths(*parts):
    """Combine the given parts into one dotted path.

    Falsy parts are skipped; each remaining part is converted to ``str``
    and has its leading and trailing dots removed before joining.
    """
    cleaned = []
    for part in parts:
        if part:
            cleaned.append(str(part).strip("."))
    return ".".join(cleaned)

520 

521 

def is_prefix(pre_path, path):
    """Tell whether pre_path is a path-prefix of path.

    Surrounding dots are ignored on both arguments. An empty pre_path is a
    prefix of everything; otherwise path has to start with pre_path
    followed by a dot.
    """
    stem = pre_path.strip(".")
    candidate = path.strip(".")
    if not stem:
        return True
    return candidate.startswith(stem + ".")

527 

528 

def rel_path(base, path):
    """Return the part of path that follows base, without surrounding dots.

    Equal arguments give an empty string; otherwise base has to be a
    path-prefix of path.
    """
    if base == path:
        return ""
    assert is_prefix(base, path), "{} not a prefix of {}".format(base, path)
    remainder = path[len(base) :]
    return remainder.strip(".")

535 

536 

def convert_to_nested_dict(dotted_dict):
    """Build the nested dict described by a dict with dotted path keys."""
    result = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf)
    return result

543 

544 

def _is_sacred_frame(frame):
    """Return True if the frame's module is ``sacred`` or one of its submodules.

    The decision is based on the ``__name__`` entry of the frame's globals.
    Frames whose globals have no ``__name__`` (e.g. code run through
    ``exec`` with a custom globals dict) are treated as not belonging to
    sacred instead of raising ``KeyError``.
    """
    module_name = frame.f_globals.get("__name__") or ""
    return str(module_name).split(".")[0] == "sacred"

547 

548 

def print_filtered_stacktrace(filter_traceback="default"):
    """Write the output of format_filtered_stacktrace to standard error."""
    report = format_filtered_stacktrace(filter_traceback)
    print(report, file=sys.stderr)

551 

552 

def format_filtered_stacktrace(filter_traceback="default"):
    """
    Returns the traceback of the exception being handled as `string`.

    `filter_traceback` can be one of:
      - 'always': always filter out sacred internals
      - 'default': Default behaviour: filter out sacred internals
        if the exception did not originate from within sacred, and
        print just the internal stack trace otherwise
      - 'never': don't filter, always print full traceback

    When no exception is being handled, the text produced by
    ``traceback.format_exception`` for "no exception" is returned.

    Raises
    ------
    ValueError
        If `filter_traceback` is none of the values above.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    if exc_traceback is None:
        # Nothing is being handled, so there is no traceback to walk.
        return "".join(tb.format_exception(exc_type, exc_value, exc_traceback))

    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next

    if filter_traceback == "default" and _is_sacred_frame(current_tb.tb_frame):
        # just print sacred internal trace
        header = [
            "Exception originated from within Sacred.\n"
            "Traceback (most recent calls):\n"
        ]
        texts = tb.format_exception(exc_type, exc_value, current_tb)
        # texts[0] is the standard header line, replaced by the one above.
        return "".join(header + texts[1:]).strip()
    elif filter_traceback in ("default", "always"):
        # print filtered stacktrace
        tb_exception = FilteredTracebackException(
            exc_type, exc_value, exc_traceback, limit=None
        )
        return "".join(tb_exception.format())
    elif filter_traceback == "never":
        # print full stacktrace
        return "\n".join(tb.format_exception(exc_type, exc_value, exc_traceback))
    else:
        # Format rather than concatenate so a non-string value is reported
        # through the intended ValueError as well.
        raise ValueError(
            "Unknown value for filter_traceback: {}".format(filter_traceback)
        )

590 

591 

def format_sacred_error(e, short_usage):
    """Build the report text for error ``e`` according to its print flags.

    The text starts with ``short_usage`` when ``e.print_usage`` is set. It
    is followed by the filtered stack trace when ``e.print_traceback`` is
    set, otherwise by the formatted exception alone. The sections are
    joined with newlines.
    """
    sections = [short_usage] if e.print_usage else []
    if e.print_traceback:
        body = format_filtered_stacktrace(e.filter_traceback)
    else:
        body = "\n".join(tb.format_exception_only(type(e), e))
    sections.append(body)
    return "\n".join(sections)

601 

602 

# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """Let the decorator ``wrapped`` be used with or without keyword arguments.

    With positional arguments present, ``wrapped`` is called directly
    (normal decorator usage). Otherwise a partial of ``wrapped`` with the
    keyword arguments bound is returned, to be applied as the decorator.
    """
    if not args:
        # Used with kwargs only: hand back a decorator.
        return partial(wrapped, **kwargs)
    # Used as a plain decorator: just call it.
    return wrapped(*args, **kwargs)

611 

612 

def get_inheritors(cls):
    """Collect all direct and indirect subclasses of the given class in a set."""
    found = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        for sub in current.__subclasses__():
            if sub in found:
                continue
            found.add(sub)
            pending.append(sub)
    return found

624 

625 

# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """Turn a CamelCase name into its snake_case form."""
    # First pass: underscore before a capital that starts a lowercase run.
    partially_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Second pass: underscore between a lowercase letter/digit and a capital.
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partially_split)
    return fully_split.lower()

632 

633 

def apply_backspaces_and_linefeeds(text):
    """
    Render backspaces and carriage returns in text the way a terminal would.

    The text is processed line by line: a carriage return moves the cursor
    to the start of the line, a backspace moves it one position back, and
    any other character overwrites or extends the line at the cursor.

    A carriage return that is the very last character of the text is kept,
    so the result can be concatenated with the next output chunk.
    """
    lines = text.split("\n")
    final_line_no = len(lines) - 1
    rendered = []
    for line_no, line in enumerate(lines):
        final_col = len(line) - 1
        buffer = []
        pos = 0
        for col, char in enumerate(line):
            closing_cr = (
                char == "\r" and col == final_col and line_no == final_line_no
            )
            if char == "\r" and not closing_cr:
                pos = 0
                continue
            if char == "\b":
                pos = max(0, pos - 1)
                continue
            if closing_cr:
                # Keep the closing carriage return at the end of the line.
                pos = len(buffer)
            if pos == len(buffer):
                buffer.append(char)
            else:
                buffer[pos] = char
            pos += 1
        rendered.append("".join(buffer))
    return "\n".join(rendered)

672 

673 

def module_exists(modname):
    """Checks if a module exists without actually importing it.

    Uses ``importlib.util.find_spec`` in place of ``pkgutil.find_loader``,
    which is deprecated since Python 3.12 and removed in Python 3.14.
    Returns True if a spec with a loader is found, False if none is found,
    and True if the lookup itself fails.
    """
    import importlib.util

    try:
        spec = importlib.util.find_spec(modname)
    except (ImportError, AttributeError, TypeError, ValueError):
        # These are the failures ``pkgutil.find_loader`` reported as
        # ImportError, which this function has always answered with True.
        # TODO: Temporary fix for tf 1.14.0.
        # Should be removed once fixed in tf.
        return True
    return spec is not None and spec.loader is not None

682 

683 

def modules_exist(*modnames):
    """Return True if module_exists holds for every given module name."""
    for name in modnames:
        if not module_exists(name):
            return False
    return True

686 

687 

def module_is_in_cache(modname):
    """Tell whether the module is present in the import cache (``sys.modules``)."""
    import_cache = sys.modules
    return modname in import_cache

691 

692 

def parse_version(version_string):
    """Parse a version string with ``packaging.version`` and return the result."""
    parsed = version.parse(version_string)
    return parsed

696 

697 

def get_package_version(name):
    """Import the named package and return its parsed ``__version__``."""
    package = importlib.import_module(name)
    return parse_version(package.__version__)

702 

703 

def ensure_wellformed_argv(argv):
    """Normalise ``argv`` into a sequence of strings.

    ``None`` is replaced by ``sys.argv`` and a string is split with
    ``shlex.split``. A list or tuple is returned unchanged after checking
    that every element is a string.

    Raises
    ------
    ValueError
        If ``argv`` has any other type, or contains non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, str):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}".format(type(argv)))
    problems = [a for a in argv if not isinstance(a, str)]
    if problems:
        raise ValueError(
            "argv must be list of str but contained the "
            "following elements: {}".format(problems)
        )
    return argv

719 

720 

class IntervalTimer(threading.Thread):
    """Thread that calls ``func`` every ``interval`` seconds until stopped.

    The loop ends once the stop event is set; ``func`` is then called one
    final time before the thread finishes.
    """

    @classmethod
    def create(cls, func, interval=10):
        """Return a new stop event together with a timer thread bound to it."""
        stop_signal = threading.Event()
        return stop_signal, cls(stop_signal, func, interval)

    def __init__(self, event, func, interval=10.0):
        super().__init__()
        self.interval = interval
        self.func = func
        self.stopped = event

    def run(self):
        while True:
            # wait() returns True as soon as the stop event is set.
            if self.stopped.wait(self.interval):
                break
            self.func()
        self.func()