Coverage for sacred/sacred/dependencies.py: 25%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

227 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import functools 

5import hashlib 

6import os.path 

7import re 

8import sys 

9from pathlib import Path 

10 

11import pkg_resources 

12 

13import sacred.optional as opt 

14from sacred import SETTINGS 

15from sacred.utils import iter_prefixes 

16 

# Chunk size (one mebibyte) used when hashing files.
MB = 1048576
# Names of modules that are never reported as sources or package
# dependencies: everything compiled into the interpreter ...
MODULE_BLACKLIST = set(sys.builtin_module_names)
# sadly many builtins are missing from the above, so we list them manually:
# (standard-library modules of Python 2 and 3, plus packaging tools; ``None``
# guards against modules whose ``__module__`` attribute is None)
MODULE_BLACKLIST |= {
    None,
    "__future__",
    "_abcoll",
    "_bootlocale",
    "_bsddb",
    "_bz2",
    "_codecs_cn",
    "_codecs_hk",
    "_codecs_iso2022",
    "_codecs_jp",
    "_codecs_kr",
    "_codecs_tw",
    "_collections_abc",
    "_compat_pickle",
    "_compression",
    "_crypt",
    "_csv",
    "_ctypes",
    "_ctypes_test",
    "_curses",
    "_curses_panel",
    "_dbm",
    "_decimal",
    "_dummy_thread",
    "_elementtree",
    "_gdbm",
    "_hashlib",
    "_hotshot",
    "_json",
    "_lsprof",
    "_LWPCookieJar",
    "_lzma",
    "_markupbase",
    "_MozillaCookieJar",
    "_multibytecodec",
    "_multiprocessing",
    "_opcode",
    "_osx_support",
    "_pydecimal",
    "_pyio",
    "_sitebuiltins",
    "_sqlite3",
    "_ssl",
    "_strptime",
    "_sysconfigdata",
    "_sysconfigdata_m",
    "_sysconfigdata_nd",
    "_testbuffer",
    "_testcapi",
    "_testimportmultiple",
    "_testmultiphase",
    "_threading_local",
    "_tkinter",
    "_weakrefset",
    "abc",
    "aifc",
    "antigravity",
    "anydbm",
    "argparse",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audiodev",
    "audioop",
    "base64",
    "BaseHTTPServer",
    "Bastion",
    "bdb",
    "binhex",
    "bisect",
    "bsddb",
    "bz2",
    "calendar",
    "Canvas",
    "CDROM",
    "cgi",
    "CGIHTTPServer",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "commands",
    "compileall",
    "compiler",
    "concurrent",
    "ConfigParser",
    "configparser",
    "contextlib",
    "Cookie",
    "cookielib",
    "copy",
    "copy_reg",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "datetime",
    "dbhash",
    "dbm",
    "decimal",
    "Dialog",
    "difflib",
    "dircache",
    "dis",
    "distutils",
    "DLFCN",
    "doctest",
    "DocXMLRPCServer",
    "dumbdbm",
    "dummy_thread",
    "dummy_threading",
    "easy_install",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "filecmp",
    "FileDialog",
    "fileinput",
    "FixTk",
    "fnmatch",
    "formatter",
    "fpectl",
    "fpformat",
    "fractions",
    "ftplib",
    "functools",
    "future_builtins",
    "genericpath",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "hotshot",
    "html",
    "htmlentitydefs",
    "htmllib",
    "HTMLParser",
    "http",
    "httplib",
    "idlelib",
    "ihooks",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "imputil",
    "IN",
    "inspect",
    "io",
    "ipaddress",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "linuxaudiodev",
    "locale",
    "logging",
    "lzma",
    "macpath",
    "macurl2path",
    "mailbox",
    "mailcap",
    "markupbase",
    "md5",
    "mhlib",
    "mimetools",
    # was erroneously "data", which hid any user module named `data`
    "mimetypes",
    "MimeWriter",
    "mimify",
    "mmap",
    "modulefinder",
    "multifile",
    "multiprocessing",
    "mutex",
    "netrc",
    "new",
    "nis",
    "nntplib",
    "ntpath",
    "nturl2path",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "os2emxpath",
    "ossaudiodev",
    "parser",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pip",
    "pipes",
    "pkg_resources",
    "pkgutil",
    "platform",
    "plistlib",
    "popen2",
    "poplib",
    "posixfile",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pydoc_data",
    "pyexpat",
    "Queue",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "repr",
    "reprlib",
    "resource",
    "rexec",
    "rfc822",
    "rlcompleter",
    "robotparser",
    "runpy",
    "sched",
    "ScrolledText",
    "selectors",
    "sets",
    "setuptools",
    "sgmllib",
    "sha",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "SimpleDialog",
    "SimpleHTTPServer",
    "SimpleXMLRPCServer",
    "site",
    "sitecustomize",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "SocketServer",
    "socketserver",
    "sqlite3",
    "sre",
    "sre_compile",
    "sre_constants",
    "sre_parse",
    "ssl",
    "stat",
    "statistics",
    "statvfs",
    "string",
    "StringIO",
    "stringold",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "sunaudio",
    "symbol",
    "symtable",
    "sysconfig",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "test",
    "textwrap",
    "this",
    "threading",
    "timeit",
    "Tix",
    "tkColorChooser",
    "tkCommonDialog",
    "Tkconstants",
    "Tkdnd",
    "tkFileDialog",
    "tkFont",
    "tkinter",
    "Tkinter",
    "tkMessageBox",
    "tkSimpleDialog",
    "toaiff",
    "token",
    "tokenize",
    "trace",
    "traceback",
    "tracemalloc",
    "ttk",
    "tty",
    "turtle",
    "types",
    "TYPES",
    "typing",
    "unittest",
    "urllib",
    "urllib2",
    "urlparse",
    "user",
    "UserDict",
    "UserList",
    "UserString",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "wheel",
    "whichdb",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmllib",
    "xmlrpc",
    "xmlrpclib",
    "xxlimited",
    "zipapp",
    "zipfile",
}

362 

# The type of module objects (``type(sys)``); used in isinstance checks to
# recognise modules among the values of an experiment's globals.
module = type(sys)
# Regex for a restricted form of PEP 440 version strings: optional epoch
# ("N!"), a dotted numeric release, optional pre-release (a/b/c/rc + digits),
# optional ".postN" and optional ".devN" segments. Local version labels
# ("+...") are not matched. Compiled with re.VERBOSE, so whitespace and the
# "#" comments inside the pattern are ignored when matching.
PEP440_VERSION_PATTERN = re.compile(
    r"""
^
(\d+!)? # epoch
(\d[.\d]*(?<= \d)) # release
((?:[abc]|rc)\d+)? # pre-release
(?:(\.post\d+))? # post-release
(?:(\.dev\d+))? # development release
$
""",
    flags=re.VERBOSE,
)

376 

377 

def get_py_file_if_possible(pyc_name):
    """Try to retrieve a X.py file for a given X.py[c] file.

    Files ending in ``.py``, ``.so`` or ``.pyd`` are returned unchanged.
    For a ``.pyc`` file the matching source is looked up both next to it
    (legacy ``X.pyc`` -> ``X.py`` layout) and, for PEP 3147 cache files such
    as ``__pycache__/X.cpython-38.pyc``, in the directory above the cache.

    Parameters
    ----------
    pyc_name : str
        Path of a Python source, extension module or compiled file.

    Returns
    -------
    str
        The path of the source file if one exists, otherwise ``pyc_name``.

    Raises
    ------
    AssertionError
        If ``pyc_name`` has any other extension (only when assertions are
        enabled).
    """
    if pyc_name.endswith((".py", ".so", ".pyd")):
        return pyc_name
    assert pyc_name.endswith(".pyc")

    # Legacy layout: the source file sits right next to the compiled file.
    non_compiled_file = pyc_name[:-1]
    if os.path.exists(non_compiled_file):
        return non_compiled_file

    # PEP 3147 layout: <dir>/__pycache__/<name>.<tag>[.opt-N].pyc
    import importlib.util

    try:
        cached_source = importlib.util.source_from_cache(pyc_name)
    except (ValueError, NotImplementedError):
        # Not a __pycache__ path, or the interpreter defines no cache tag.
        return pyc_name
    if os.path.exists(cached_source):
        return cached_source
    return pyc_name

387 

388 

def get_digest(filename):
    """Return the hex MD5 digest of the file at ``filename``.

    The file is read in binary mode in chunks of ``MB`` bytes, so large
    files are never loaded into memory at once.
    """
    md5 = hashlib.md5()
    with open(filename, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (EOF).
        for chunk in iter(lambda: stream.read(1 * MB), b""):
            md5.update(chunk)
    return md5.hexdigest()

398 

399 

def get_commit_if_possible(filename, save_git_info):
    """Try to retrieve VCS information for a given file.

    Only git is supported, through the GitPython package.

    Parameters
    ----------
    filename : str
        Path of a file. The repository is searched for starting at the
        directory containing it and moving up through parent directories.
    save_git_info : bool
        When exactly ``False``, no lookup is performed at all.

    Returns
    -------
    path: str
        URL of ``repo.remote()``, or ``"git:/"`` + the working directory
        when that lookup raises ValueError (presumably: no default remote).
    commit: str
        Hex SHA of the commit HEAD points to.
    is_dirty: bool
        True if there are uncommitted changes in the repository.

    All three values are None when ``save_git_info`` is False or the file
    is not inside a git repository.

    Raises
    ------
    ValueError
        If GitPython cannot be imported.
    """
    no_info = (None, None, None)
    if save_git_info is False:
        return no_info

    try:
        from git import Repo, InvalidGitRepositoryError
    except ImportError as e:
        raise ValueError(
            "Cannot import git (pip install GitPython).\n"
            "Either GitPython or the git executable is missing.\n"
            "You can disable git with:\n"
            " sacred.Experiment(..., save_git_info=False)"
        ) from e

    try:
        repo = Repo(os.path.dirname(filename), search_parent_directories=True)
    except InvalidGitRepositoryError:
        return no_info

    try:
        location = repo.remote().url
    except ValueError:
        location = "git:/" + repo.working_dir
    dirty = repo.is_dirty()
    return location, repo.head.commit.hexsha, dirty

443 

444 

@functools.total_ordering
class Source:
    """A source file of an experiment, identified by its resolved path.

    Besides the path (passed through ``os.path.realpath``), it stores the
    file's digest and the git repository, commit and dirty flag. Equality,
    hashing and ordering only look at the path, and a Source also compares
    equal to a plain string holding the same path.
    """

    def __init__(self, filename, digest, repo, commit, isdirty):
        self.filename = os.path.realpath(filename)
        self.digest = digest
        self.repo = repo
        self.commit = commit
        self.is_dirty = isdirty

    @staticmethod
    def create(filename, save_git_info=True):
        """Build a Source for ``filename``, collecting digest and git info.

        Compiled ``.pyc`` paths are mapped to their ``.py`` source when it
        exists.

        Raises
        ------
        ValueError
            If ``filename`` is empty or does not exist.
        """
        if not filename or not os.path.exists(filename):
            raise ValueError('invalid filename or file not found "{}"'.format(filename))

        resolved = get_py_file_if_possible(os.path.abspath(filename))
        vcs_info = get_commit_if_possible(resolved, save_git_info)
        return Source(resolved, get_digest(resolved), *vcs_info)

    def to_json(self, base_dir=None):
        """Return ``(path, digest)``, with the path made relative to
        ``base_dir`` (after resolving it) when ``base_dir`` is truthy."""
        if not base_dir:
            return self.filename, self.digest
        relative = os.path.relpath(self.filename, os.path.realpath(base_dir))
        return relative, self.digest

    def __hash__(self):
        return hash(self.filename)

    def __eq__(self, other):
        if isinstance(other, Source):
            other = other.filename
        elif not isinstance(other, str):
            return False
        return self.filename == other

    def __le__(self, other):
        return self.filename.__le__(other.filename)

    def __repr__(self):
        return "<Source: {}>".format(self.filename)

488 

489 

@functools.total_ordering
class PackageDependency:
    """A package the experiment depends on, together with its version.

    Equality, hashing and ordering only consider the package name, so a set
    of dependencies holds at most one entry per name.
    """

    # Class-wide cache mapping top-level module names to the
    # (project_name, version) of the installed distribution that provides
    # them; filled lazily by create().
    modname_to_dist = {}

    def __init__(self, name, version):
        self.name = name
        self.version = version

    def fill_missing_version(self):
        """Look up the installed version of this package if none is set.

        ``pkg_resources.working_set.by_key`` is indexed by the distribution
        key (the lowercased project name), so a lowercase lookup is tried
        when the name as given is not found. The version stays None if no
        installed distribution matches.
        """
        if self.version is not None:
            return
        by_key = pkg_resources.working_set.by_key
        dist = by_key.get(self.name)
        if dist is None and isinstance(self.name, str):
            dist = by_key.get(self.name.lower())
        self.version = dist.version if dist else None

    def to_json(self):
        return "{}=={}".format(self.name, self.version or "<unknown>")

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, PackageDependency):
            return self.name == other.name
        else:
            return False

    def __le__(self, other):
        return self.name.__le__(other.name)

    def __repr__(self):
        return "<PackageDependency: {}={}>".format(self.name, self.version)

    @classmethod
    def create(cls, mod):
        """Create a dependency (an instance of ``cls``) for module ``mod``.

        The module name is mapped to the name and version of the installed
        distribution that lists it in its ``top_level.txt``; if no such
        distribution is known, the module name is used with version None.
        """
        if not cls.modname_to_dist:
            # some packagenames don't match the module names (e.g. PyYAML)
            # so we set up a dict to map from module name to package name
            for dist in pkg_resources.working_set:
                try:
                    toplevel_names = dist._get_metadata("top_level.txt")
                    for tln in toplevel_names:
                        cls.modname_to_dist[tln] = dist.project_name, dist.version
                except Exception:
                    # Best effort: distributions whose metadata is missing or
                    # unreadable are simply left out of the mapping.
                    pass

        name, version = cls.modname_to_dist.get(mod.__name__, (mod.__name__, None))

        # Use cls so that subclasses get instances of their own type.
        return cls(name, version)

538 

539 

def convert_path_to_module_parts(path):
    """Convert path to a python file into list of module names.

    A trailing ``__init__.py``/``__init__.pyc`` is dropped (it stands for
    its package). Any other file name loses its extension. ``path`` must
    provide ``.parts`` (a pathlib path).
    """
    parts = list(path.parts)
    leaf = parts.pop()
    if leaf not in ("__init__.py", "__init__.pyc"):
        parts.append(os.path.splitext(leaf)[0])
    return parts

550 

551 

def is_local_source(filename, modname, experiment_path):
    """Check if a module comes from the given experiment path.

    A module, given by its name and file, counts as local when its file lies
    inside (a subdirectory of) the experiment path. This is used to decide
    whether a module is a local source file or rather a package dependency.

    Parameters
    ----------
    filename: str
        The absolute filename of the module in question
        (usually ``module.__file__``).
    modname: str
        The full dotted name of the module, including parent namespaces.
    experiment_path: str
        The base path of the experiment.

    Returns
    -------
    bool:
        True if the module was imported locally from (a subdirectory of) the
        experiment_path, and False otherwise.
    """
    file_path = Path(os.path.abspath(os.path.realpath(filename)))
    base_path = Path(os.path.abspath(os.path.realpath(experiment_path)))
    if base_path not in file_path.parents:
        return False

    relative_parts = convert_path_to_module_parts(file_path.relative_to(base_path))
    module_parts = modname.split(".")
    if relative_parts == module_parts:
        return True
    if len(relative_parts) > len(module_parts):
        return False

    # Otherwise the dotted module name must match the trailing components
    # of the absolute file path.
    absolute_parts = convert_path_to_module_parts(file_path)
    return all(
        path_part == mod_part
        for path_part, mod_part in zip(reversed(absolute_parts), reversed(module_parts))
    )

590 

591 

def get_main_file(globs, save_git_info):
    """Determine the experiment directory and main source file.

    Returns ``(experiment_path, main)``. If ``globs`` has a non-None
    ``__file__``, ``main`` is its Source and the path is that file's
    directory. Otherwise ``main`` is None and the path is the absolute
    current directory.
    """
    filename = globs.get("__file__")
    if filename is not None:
        main = Source.create(filename, save_git_info)
        return os.path.dirname(main.filename), main
    return os.path.abspath(os.path.curdir), None

602 

603 

def iterate_imported_modules(globs):
    """Yield ``(name, module)`` for modules referenced by ``globs``.

    A global contributes the name of the module it is, or the value of its
    ``__module__`` attribute. Every dotted prefix of that name is checked
    once against ``sys.modules``. Blacklisted names are skipped.
    """
    seen = set(MODULE_BLACKLIST)
    for value in globs.values():
        if isinstance(value, module):
            dotted_name = value.__name__
        elif hasattr(value, "__module__"):
            dotted_name = value.__module__
        else:
            continue  # pragma: no cover

        if not dotted_name:
            continue

        for prefix in iter_prefixes(dotted_name):
            if prefix in seen:
                continue
            seen.add(prefix)
            loaded = sys.modules.get(prefix)
            if loaded is not None:
                yield prefix, loaded

624 

625 

def iterate_all_python_files(base_path):
    """Yield the path of every ``.py`` file below ``base_path``.

    ``__pycache__`` directories are not descended into. Each yielded path
    is ``os.path.join(dirpath, filename)`` for a directory produced by
    ``os.walk(base_path)``. Paths are therefore absolute when ``base_path``
    is absolute, and relative to the current directory otherwise.
    """
    # TODO support ignored directories/files
    for dirname, subdirlist, filelist in os.walk(base_path):
        # Prune in place so os.walk does not descend into bytecode caches.
        subdirlist[:] = [d for d in subdirlist if d != "__pycache__"]
        for filename in filelist:
            if filename.endswith(".py"):
                # dirname already starts with base_path; joining base_path in
                # again would duplicate it whenever base_path is relative.
                yield os.path.join(dirname, filename)

634 

635 

def iterate_sys_modules():
    """Yield ``(name, module)`` for every loaded, non-blacklisted module.

    A snapshot of ``sys.modules`` is taken first, so imports triggered while
    the generator is consumed do not disturb the iteration.
    """
    snapshot = list(sys.modules.items())
    yield from (
        (name, mod)
        for name, mod in snapshot
        if name not in MODULE_BLACKLIST and mod is not None
    )

641 

642 

def get_sources_from_modules(module_iterator, base_path, save_git_info):
    """Return the set of Sources for the local modules in ``module_iterator``.

    Modules without a ``__file__`` are skipped. So are modules whose file is
    not local to ``base_path`` according to ``is_local_source``.
    """
    found = set()
    for modname, mod in module_iterator:
        # hasattr doesn't work with python extensions
        if not getattr(mod, "__file__", None):
            continue

        path = os.path.abspath(mod.__file__)
        # Source compares equal to its path string, so this skips files
        # that were already recorded.
        if path in found:
            continue
        if is_local_source(path, modname, base_path):
            found.add(Source.create(path, save_git_info))
    return found

655 

656 

def get_dependencies_from_modules(module_iterator, base_path):
    """Return PackageDependencies for the non-local top-level modules.

    The following are skipped:
    - modules local to ``base_path``;
    - private names (leading underscore) and submodules (dotted names);
    - dependencies whose version could not be determined.
    An AttributeError while creating a dependency also skips that module.
    """
    found = set()
    for modname, mod in module_iterator:
        # hasattr doesn't work with python extensions
        mod_file = getattr(mod, "__file__", None)
        if mod_file and is_local_source(os.path.abspath(mod.__file__), modname, base_path):
            continue
        private_or_submodule = modname.startswith("_") or "." in modname
        if private_or_submodule:
            continue

        try:
            dependency = PackageDependency.create(mod)
            if dependency.version is not None:
                found.add(dependency)
        except AttributeError:
            pass
    return found

675 

676 

def get_sources_from_sys_modules(globs, base_path, save_git_info):
    """Collect local Sources among all modules currently in ``sys.modules``.

    ``globs`` is unused; it is accepted so all source strategies share one
    signature.
    """
    loaded = iterate_sys_modules()
    return get_sources_from_modules(loaded, base_path, save_git_info)

679 

680 

def get_sources_from_imported_modules(globs, base_path, save_git_info):
    """Collect local Sources among the modules referenced by ``globs``."""
    referenced = iterate_imported_modules(globs)
    return get_sources_from_modules(referenced, base_path, save_git_info)

685 

686 

def get_sources_from_local_dir(globs, base_path, save_git_info):
    """Create a Source for every ``.py`` file below ``base_path``.

    ``globs`` is unused; it is accepted so all source strategies share one
    signature.
    """
    python_files = iterate_all_python_files(base_path)
    return set(Source.create(path, save_git_info) for path in python_files)

692 

693 

def get_dependencies_from_sys_modules(globs, base_path):
    """Collect package dependencies among all modules in ``sys.modules``.

    ``globs`` is unused; it is accepted so all dependency strategies share
    one signature.
    """
    loaded = iterate_sys_modules()
    return get_dependencies_from_modules(loaded, base_path)

696 

697 

def get_dependencies_from_imported_modules(globs, base_path):
    """Collect package dependencies among the modules referenced by ``globs``."""
    referenced = iterate_imported_modules(globs)
    return get_dependencies_from_modules(referenced, base_path)

700 

701 

def get_dependencies_from_pkg(globs, base_path):
    """Record every distribution in the ``pkg_resources`` working set.

    Distributions reporting version ``"0.0.0"`` are skipped. ``globs`` and
    ``base_path`` are unused and accepted only for a uniform signature.
    """
    return {
        PackageDependency(dist.project_name, dist.version)
        for dist in pkg_resources.working_set
        # ugly hack to deal with pkg-resource version bug
        if dist.version != "0.0.0"
    }

709 

710 

# Source discovery strategies, selected by SETTINGS["DISCOVER_SOURCES"].
# Each is called as strategy(globs, base_dir, save_git_info) and returns a set.
source_discovery_strategies = dict(
    none=lambda *args, **kwargs: set(),
    imported=get_sources_from_imported_modules,
    sys=get_sources_from_sys_modules,
    dir=get_sources_from_local_dir,
)

# Dependency discovery strategies, selected by
# SETTINGS["DISCOVER_DEPENDENCIES"]. Each is called as
# strategy(globs, base_dir) and returns a set.
dependency_discovery_strategies = dict(
    none=lambda *args, **kwargs: set(),
    imported=get_dependencies_from_imported_modules,
    sys=get_dependencies_from_sys_modules,
    pkg=get_dependencies_from_pkg,
)

724 

725 

def gather_sources_and_dependencies(globs, save_git_info, base_dir=None):
    """Scan the given globals for source files and package dependencies.

    The discovery strategies are chosen by ``SETTINGS["DISCOVER_SOURCES"]``
    and ``SETTINGS["DISCOVER_DEPENDENCIES"]``. Both search ``base_dir``, or
    the main file's directory (the current directory if there is no main
    file) when ``base_dir`` is falsy.

    Returns
    -------
    tuple
        ``(main, sources, dependencies)``, where ``main`` is the Source of
        ``globs["__file__"]`` (or None) and is always included in
        ``sources`` when present.
    """
    experiment_path, main = get_main_file(globs, save_git_info)
    search_dir = base_dir or experiment_path

    discover_sources = source_discovery_strategies[SETTINGS["DISCOVER_SOURCES"]]
    sources = discover_sources(globs, search_dir, save_git_info)
    if main is not None:
        sources.add(main)

    discover_dependencies = dependency_discovery_strategies[
        SETTINGS["DISCOVER_DEPENDENCIES"]
    ]
    dependencies = discover_dependencies(globs, search_dir)

    if opt.has_numpy:
        # numpy is always recorded because it might be used for randomness
        dependencies.add(PackageDependency.create(opt.np))

    return main, sources, dependencies