Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/dependencies.py: 77%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import functools
5import hashlib
6import os.path
7import re
8import sys
9from pathlib import Path
11import pkg_resources
13import sacred.optional as opt
14from sacred import SETTINGS
15from sacred.utils import iter_prefixes
# Number of bytes in one mebibyte (2 ** 20).
MB = 1048576

# Names that the module iterators in this file never report.
# Starts from the interpreter's compiled-in modules.
MODULE_BLACKLIST = set(sys.builtin_module_names)
# sadly many builtins are missing from the above, so we list them manually:
MODULE_BLACKLIST |= {
    None,
    "__future__",
    "_abcoll",
    "_bootlocale",
    "_bsddb",
    "_bz2",
    "_codecs_cn",
    "_codecs_hk",
    "_codecs_iso2022",
    "_codecs_jp",
    "_codecs_kr",
    "_codecs_tw",
    "_collections_abc",
    "_compat_pickle",
    "_compression",
    "_crypt",
    "_csv",
    "_ctypes",
    "_ctypes_test",
    "_curses",
    "_curses_panel",
    "_dbm",
    "_decimal",
    "_dummy_thread",
    "_elementtree",
    "_gdbm",
    "_hashlib",
    "_hotshot",
    "_json",
    "_lsprof",
    "_LWPCookieJar",
    "_lzma",
    "_markupbase",
    "_MozillaCookieJar",
    "_multibytecodec",
    "_multiprocessing",
    "_opcode",
    "_osx_support",
    "_pydecimal",
    "_pyio",
    "_sitebuiltins",
    "_sqlite3",
    "_ssl",
    "_strptime",
    "_sysconfigdata",
    "_sysconfigdata_m",
    "_sysconfigdata_nd",
    "_testbuffer",
    "_testcapi",
    "_testimportmultiple",
    "_testmultiphase",
    "_threading_local",
    "_tkinter",
    "_weakrefset",
    "abc",
    "aifc",
    "antigravity",
    "anydbm",
    "argparse",
    "ast",
    "asynchat",
    "asyncio",
    "asyncore",
    "atexit",
    "audiodev",
    "audioop",
    "base64",
    "BaseHTTPServer",
    "Bastion",
    "bdb",
    "binhex",
    "bisect",
    "bsddb",
    "bz2",
    "calendar",
    "Canvas",
    "CDROM",
    "cgi",
    "CGIHTTPServer",
    "cgitb",
    "chunk",
    "cmath",
    "cmd",
    "code",
    "codecs",
    "codeop",
    "collections",
    "colorsys",
    "commands",
    "compileall",
    "compiler",
    "concurrent",
    "ConfigParser",
    "configparser",
    "contextlib",
    "Cookie",
    "cookielib",
    "copy",
    "copy_reg",
    "copyreg",
    "cProfile",
    "crypt",
    "csv",
    "ctypes",
    "curses",
    "datetime",
    "dbhash",
    "dbm",
    "decimal",
    "Dialog",
    "difflib",
    "dircache",
    "dis",
    "distutils",
    "DLFCN",
    "doctest",
    "DocXMLRPCServer",
    "dumbdbm",
    "dummy_thread",
    "dummy_threading",
    "easy_install",
    "email",
    "encodings",
    "ensurepip",
    "enum",
    "filecmp",
    "FileDialog",
    "fileinput",
    "FixTk",
    "fnmatch",
    "formatter",
    "fpectl",
    "fpformat",
    "fractions",
    "ftplib",
    "functools",
    "future_builtins",
    "genericpath",
    "getopt",
    "getpass",
    "gettext",
    "glob",
    "gzip",
    "hashlib",
    "heapq",
    "hmac",
    "hotshot",
    "html",
    "htmlentitydefs",
    "htmllib",
    "HTMLParser",
    "http",
    "httplib",
    "idlelib",
    "ihooks",
    "imaplib",
    "imghdr",
    "imp",
    "importlib",
    "imputil",
    "IN",
    "inspect",
    "io",
    "ipaddress",
    "json",
    "keyword",
    "lib2to3",
    "linecache",
    "linuxaudiodev",
    "locale",
    "logging",
    "lzma",
    "macpath",
    "macurl2path",
    "mailbox",
    "mailcap",
    "markupbase",
    "md5",
    "mhlib",
    "mimetools",
    # was "data": not a standard-library name, and it would have hidden any
    # user module called ``data`` from source/dependency discovery
    "mimetypes",
    "MimeWriter",
    "mimify",
    "mmap",
    "modulefinder",
    "multifile",
    "multiprocessing",
    "mutex",
    "netrc",
    "new",
    "nis",
    "nntplib",
    "ntpath",
    "nturl2path",
    "numbers",
    "opcode",
    "operator",
    "optparse",
    "os",
    "os2emxpath",
    "ossaudiodev",
    "parser",
    "pathlib",
    "pdb",
    "pickle",
    "pickletools",
    "pip",
    "pipes",
    "pkg_resources",
    "pkgutil",
    "platform",
    "plistlib",
    "popen2",
    "poplib",
    "posixfile",
    "posixpath",
    "pprint",
    "profile",
    "pstats",
    "pty",
    "py_compile",
    "pyclbr",
    "pydoc",
    "pydoc_data",
    "pyexpat",
    "Queue",
    "queue",
    "quopri",
    "random",
    "re",
    "readline",
    "repr",
    "reprlib",
    "resource",
    "rexec",
    "rfc822",
    "rlcompleter",
    "robotparser",
    "runpy",
    "sched",
    "ScrolledText",
    "selectors",
    "sets",
    "setuptools",
    "sgmllib",
    # was the corrupted entry "XXshaXX"
    "sha",
    "shelve",
    "shlex",
    "shutil",
    "signal",
    "SimpleDialog",
    "SimpleHTTPServer",
    "SimpleXMLRPCServer",
    "site",
    "sitecustomize",
    "smtpd",
    "smtplib",
    "sndhdr",
    "socket",
    "SocketServer",
    "socketserver",
    "sqlite3",
    "sre",
    "sre_compile",
    "sre_constants",
    "sre_parse",
    "ssl",
    "stat",
    "statistics",
    "statvfs",
    "string",
    "StringIO",
    "stringold",
    "stringprep",
    "struct",
    "subprocess",
    "sunau",
    "sunaudio",
    "symbol",
    "symtable",
    "sysconfig",
    "tabnanny",
    "tarfile",
    "telnetlib",
    "tempfile",
    "termios",
    "test",
    "textwrap",
    "this",
    "threading",
    "timeit",
    "Tix",
    "tkColorChooser",
    "tkCommonDialog",
    "Tkconstants",
    "Tkdnd",
    "tkFileDialog",
    "tkFont",
    "tkinter",
    "Tkinter",
    "tkMessageBox",
    "tkSimpleDialog",
    "toaiff",
    "token",
    "tokenize",
    "trace",
    "traceback",
    "tracemalloc",
    "ttk",
    "tty",
    "turtle",
    "types",
    "TYPES",
    "typing",
    "unittest",
    "urllib",
    "urllib2",
    "urlparse",
    "user",
    "UserDict",
    "UserList",
    "UserString",
    "uu",
    "uuid",
    "venv",
    "warnings",
    "wave",
    "weakref",
    "webbrowser",
    "wheel",
    "whichdb",
    "wsgiref",
    "xdrlib",
    "xml",
    "xmllib",
    "xmlrpc",
    "xmlrpclib",
    "xxlimited",
    "zipapp",
    "zipfile",
}
# The type shared by module objects, taken from the already-imported ``sys``.
module = type(sys)

# Verbose-mode pattern for a whole version string; each segment named in the
# inline comments gets its own capture group.
PEP440_VERSION_PATTERN = re.compile(
    r"""
^
(\d+!)? # epoch
(\d[.\d]*(?<= \d)) # release
((?:[abc]|rc)\d+)? # pre-release
(?:(\.post\d+))? # post-release
(?:(\.dev\d+))? # development release
$
""",
    re.VERBOSE,
)
def get_py_file_if_possible(pyc_name):
    """Return the X.py file that sits next to a X.pyc file, when it exists.

    Names ending in ``.py``, ``.so`` or ``.pyd`` are returned unchanged.
    Any other name is asserted to end in ``.pyc``; the name with its last
    character removed is returned if such a file exists, otherwise the
    given name is returned.
    """
    passthrough_suffixes = (".py", ".so", ".pyd")
    if not pyc_name.endswith(passthrough_suffixes):
        assert pyc_name.endswith(".pyc")
        # dropping the trailing "c" turns X.pyc into X.py
        source_candidate = pyc_name[:-1]
        if os.path.exists(source_candidate):
            return source_candidate
    return pyc_name
def get_digest(filename):
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is opened in binary mode and hashed in pieces of at most one
    ``MB`` each, so the whole file is never held in memory at once.
    """
    md5_hash = hashlib.md5()
    with open(filename, "rb") as stream:
        # read() returns b"" at end of file, which ends the iteration
        for piece in iter(lambda: stream.read(1 * MB), b""):
            md5_hash.update(piece)
    return md5_hash.hexdigest()
def get_commit_if_possible(filename, save_git_info):
    """Look up git information for the directory that holds a file.

    Only git is supported, through the GitPython package.

    Parameters
    ----------
    filename : str
        Path of the file; its directory is handed to ``git.Repo`` with
        ``search_parent_directories=True``.
    save_git_info : bool
        When this is ``False`` no lookup is attempted.

    Returns
    -------
    path: str
        The URL of the repository's remote, or ``"git:/"`` followed by the
        working directory when asking for the remote raises ``ValueError``.
    commit: str
        The hexadecimal hash of the commit that HEAD refers to.
    is_dirty: bool
        The value of ``repo.is_dirty()``.

    All three values are ``None`` when ``save_git_info`` is ``False`` or
    when ``InvalidGitRepositoryError`` is raised for the directory.

    Raises
    ------
    ValueError
        If the ``git`` module cannot be imported.
    """
    nothing_found = (None, None, None)
    if save_git_info is False:
        return nothing_found

    try:
        from git import Repo, InvalidGitRepositoryError
    except ImportError as import_error:
        message = (
            "Cannot import git (pip install GitPython).\n"
            "Either GitPython or the git executable is missing.\n"
            "You can disable git with:\n"
            "    sacred.Experiment(..., save_git_info=False)"
        )
        raise ValueError(message) from import_error

    try:
        repo = Repo(os.path.dirname(filename), search_parent_directories=True)
    except InvalidGitRepositoryError:
        return nothing_found

    try:
        path = repo.remote().url
    except ValueError:
        # no usable remote: fall back to the local working directory
        path = "git:/" + repo.working_dir

    dirty_flag = repo.is_dirty()
    head_sha = repo.head.commit.hexsha
    return path, head_sha, dirty_flag
@functools.total_ordering
class Source:
    """A source file described by its resolved path, digest and VCS data.

    Equality, hashing and ordering look at ``filename`` only; ``digest``,
    ``repo``, ``commit`` and ``is_dirty`` are stored as plain attributes.
    """

    def __init__(self, filename, digest, repo, commit, isdirty):
        # store the path with symbolic links resolved
        self.filename = os.path.realpath(filename)
        self.digest = digest
        self.repo = repo
        self.commit = commit
        self.is_dirty = isdirty

    @staticmethod
    def create(filename, save_git_info=True):
        """Build a Source for an existing file.

        Raises ``ValueError`` when ``filename`` is empty or does not exist.
        """
        if not (filename and os.path.exists(filename)):
            raise ValueError('invalid filename or file not found "{}"'.format(filename))

        main_file = get_py_file_if_possible(os.path.abspath(filename))
        repo, commit, is_dirty = get_commit_if_possible(main_file, save_git_info)
        digest = get_digest(main_file)
        return Source(main_file, digest, repo, commit, is_dirty)

    def to_json(self, base_dir=None):
        """Return ``(path, digest)``; the path is relative to ``base_dir`` if given."""
        if not base_dir:
            return self.filename, self.digest
        resolved_base = os.path.realpath(base_dir)
        return os.path.relpath(self.filename, resolved_base), self.digest

    def __hash__(self):
        return hash(self.filename)

    def __eq__(self, other):
        # a Source equals another Source or a plain string with the same path
        if isinstance(other, Source):
            other = other.filename
        elif not isinstance(other, str):
            return False
        return self.filename == other

    def __le__(self, other):
        other_name = other.filename
        return self.filename.__le__(other_name)

    def __repr__(self):
        return "<Source: {}>".format(self.filename)
@functools.total_ordering
class PackageDependency:
    """A package name together with its version (which may be ``None``).

    Equality, hashing and ordering look at ``name`` only.
    """

    # Class-level cache filled on the first call to ``create``; maps a
    # top-level module name to a ``(project_name, version)`` pair.
    modname_to_dist = {}

    def __init__(self, name, version):
        self.name = name
        self.version = version

    def fill_missing_version(self):
        """Look the version up in ``pkg_resources`` when it is still ``None``."""
        if self.version is None:
            dist = pkg_resources.working_set.by_key.get(self.name)
            self.version = dist.version if dist else None

    def to_json(self):
        """Return ``name==version``, with ``<unknown>`` for a missing version."""
        shown_version = self.version or "<unknown>"
        return "{}=={}".format(self.name, shown_version)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, PackageDependency):
            return False
        return self.name == other.name

    def __le__(self, other):
        other_name = other.name
        return self.name.__le__(other_name)

    def __repr__(self):
        return "<PackageDependency: {}={}>".format(self.name, self.version)

    @classmethod
    def create(cls, mod):
        """Build a PackageDependency for a module object.

        The module's ``__name__`` is looked up in ``modname_to_dist``; when
        it is not found, the module name itself is used and the version is
        ``None``.
        """
        if not cls.modname_to_dist:
            # some packagenames don't match the module names (e.g. PyYAML)
            # so we set up a dict to map from module name to package name
            for dist in pkg_resources.working_set:
                try:
                    for top_level_name in dist._get_metadata("top_level.txt"):
                        cls.modname_to_dist[top_level_name] = (
                            dist.project_name,
                            dist.version,
                        )
                except Exception:
                    # best effort: a distribution that fails here is skipped
                    pass

        module_name = mod.__name__
        name, version = cls.modname_to_dist.get(module_name, (module_name, None))
        return PackageDependency(name, version)
def convert_path_to_module_parts(path):
    """Turn the path of a python file into a list of module name parts.

    A trailing ``__init__.py`` / ``__init__.pyc`` component is dropped;
    otherwise the file extension of the last component is removed.
    """
    parts = list(path.parts)
    last = parts[-1]
    if last in ("__init__.py", "__init__.pyc"):
        # a package is named after its directory, not after __init__
        return parts[:-1]
    stem, _extension = os.path.splitext(last)
    parts[-1] = stem
    return parts
def is_local_source(filename, modname, experiment_path):
    """Decide whether a module was imported from below the experiment path.

    This is used to tell local source files apart from package
    dependencies.

    Parameters
    ----------
    filename: str
        The filename of the module in question.
        (Usually module.__file__)
    modname: str
        The full dotted name of the module including parent namespaces.
    experiment_path: str
        The base path of the experiment.

    Returns
    -------
    bool:
        False if ``experiment_path`` is not a parent directory of the file.
        True if the path relative to ``experiment_path`` spells out exactly
        the parts of ``modname``. False if that relative path has more
        parts than ``modname``. Otherwise the parts of the absolute path
        and of ``modname`` are compared pairwise starting from the end,
        and the result is True only if every pair matches.
    """
    resolved_file = Path(os.path.abspath(os.path.realpath(filename)))
    resolved_root = Path(os.path.abspath(os.path.realpath(experiment_path)))
    if resolved_root not in resolved_file.parents:
        return False

    relative_file = resolved_file.relative_to(resolved_root)
    path_parts = convert_path_to_module_parts(relative_file)
    mod_parts = modname.split(".")

    if path_parts == mod_parts:
        return True
    if len(path_parts) > len(mod_parts):
        return False

    abs_path_parts = convert_path_to_module_parts(resolved_file)
    trailing_pairs = zip(reversed(abs_path_parts), reversed(mod_parts))
    return all(path_part == mod_part for path_part, mod_part in trailing_pairs)
def get_main_file(globs, save_git_info):
    """Return ``(experiment_path, main)`` for a dict of globals.

    When ``globs`` has a ``"__file__"`` entry that is not ``None``, ``main``
    is the ``Source`` created for it and ``experiment_path`` is the
    directory of that source. Otherwise ``main`` is ``None`` and
    ``experiment_path`` is the absolute current directory.
    """
    filename = globs.get("__file__")
    if filename is not None:
        main = Source.create(filename, save_git_info)
        return os.path.dirname(main.filename), main
    return os.path.abspath(os.path.curdir), None
def iterate_imported_modules(globs):
    """Yield ``(name, module)`` pairs for the modules behind a dict of globals.

    For each value, the dotted name is its ``__name__`` if it is a module,
    else its ``__module__`` attribute if it has one. Every prefix of that
    name produced by ``iter_prefixes`` is reported once, provided it is not
    in ``MODULE_BLACKLIST`` and ``sys.modules`` holds a module for it.
    """
    seen = set(MODULE_BLACKLIST)
    for value in globs.values():
        if isinstance(value, module):
            dotted_name = value.__name__
        elif hasattr(value, "__module__"):
            dotted_name = value.__module__
        else:
            continue  # pragma: no cover

        if not dotted_name:
            continue

        for prefix in iter_prefixes(dotted_name):
            if prefix not in seen:
                seen.add(prefix)
                loaded = sys.modules.get(prefix)
                if loaded is not None:
                    yield prefix, loaded
def iterate_all_python_files(base_path):
    """Yield the path of every ``.py`` file found below ``base_path``.

    Directories whose path contains ``__pycache__`` are skipped. The
    yielded paths are built from the directory names reported by
    ``os.walk``, so they are relative exactly when ``base_path`` is.

    Parameters
    ----------
    base_path : str
        The directory to walk.
    """
    # TODO support ignored directories/files
    for dirname, _subdirs, filelist in os.walk(base_path):
        if "__pycache__" in dirname:
            continue
        for filename in filelist:
            if filename.endswith(".py"):
                # dirname already starts with base_path; joining base_path
                # in again would repeat a relative base_path in the result
                yield os.path.join(dirname, filename)
def iterate_sys_modules():
    """Yield ``(name, module)`` for the entries of ``sys.modules``.

    Entries whose name is in ``MODULE_BLACKLIST`` or whose module is
    ``None`` are left out.
    """
    # iterate over a copy of the entries rather than the live dict
    snapshot = list(sys.modules.items())
    for modname, mod in snapshot:
        if mod is None or modname in MODULE_BLACKLIST:
            continue
        yield modname, mod
def get_sources_from_modules(module_iterator, base_path, save_git_info):
    """Collect a ``Source`` for every local module in ``module_iterator``.

    ``module_iterator`` yields ``(modname, module)`` pairs. Modules without
    a truthy ``__file__`` are skipped, as are files already collected and
    files that ``is_local_source`` does not accept for ``base_path``.
    Returns a set of ``Source`` objects.
    """
    sources = set()
    for modname, mod in module_iterator:
        # hasattr doesn't work with python extensions
        module_file = getattr(mod, "__file__", None)
        if not module_file:
            continue

        filename = os.path.abspath(module_file)
        if filename in sources:
            continue
        if is_local_source(filename, modname, base_path):
            sources.add(Source.create(filename, save_git_info))
    return sources
def get_dependencies_from_modules(module_iterator, base_path):
    """Collect a ``PackageDependency`` for every non-local top-level module.

    ``module_iterator`` yields ``(modname, module)`` pairs. A module is
    skipped when it has a file that ``is_local_source`` accepts for
    ``base_path``, when its name starts with ``_`` or when its name
    contains a dot. Dependencies whose version is ``None`` are not
    collected. Returns a set of ``PackageDependency`` objects.
    """
    dependencies = set()
    for modname, mod in module_iterator:
        # hasattr doesn't work with python extensions
        module_file = getattr(mod, "__file__", None)
        if module_file and is_local_source(
            os.path.abspath(module_file), modname, base_path
        ):
            continue

        if not modname.startswith("_") and "." not in modname:
            try:
                candidate = PackageDependency.create(mod)
                if candidate.version is not None:
                    dependencies.add(candidate)
            except AttributeError:
                pass
    return dependencies
def get_sources_from_sys_modules(globs, base_path, save_git_info):
    """Collect local sources from the modules yielded by ``iterate_sys_modules``.

    ``globs`` is not used by this function.
    """
    all_modules = iterate_sys_modules()
    return get_sources_from_modules(all_modules, base_path, save_git_info)
def get_sources_from_imported_modules(globs, base_path, save_git_info):
    """Collect local sources from the modules that ``globs`` refers to."""
    imported_modules = iterate_imported_modules(globs)
    return get_sources_from_modules(imported_modules, base_path, save_git_info)
def get_sources_from_local_dir(globs, base_path, save_git_info):
    """Create a ``Source`` for each file from ``iterate_all_python_files``.

    ``globs`` is not used by this function. Returns a set.
    """
    sources = set()
    for filename in iterate_all_python_files(base_path):
        sources.add(Source.create(filename, save_git_info))
    return sources
def get_dependencies_from_sys_modules(globs, base_path):
    """Collect dependencies from the modules yielded by ``iterate_sys_modules``.

    ``globs`` is not used by this function.
    """
    all_modules = iterate_sys_modules()
    return get_dependencies_from_modules(all_modules, base_path)
def get_dependencies_from_imported_modules(globs, base_path):
    """Collect dependencies from the modules that ``globs`` refers to."""
    imported_modules = iterate_imported_modules(globs)
    return get_dependencies_from_modules(imported_modules, base_path)
def get_dependencies_from_pkg(globs, base_path):
    """Return a dependency for every distribution in ``pkg_resources.working_set``.

    Neither ``globs`` nor ``base_path`` is used by this function.
    Distributions reporting version ``"0.0.0"`` are left out.
    """
    return {
        PackageDependency(dist.project_name, dist.version)
        for dist in pkg_resources.working_set
        # ugly hack to deal with pkg-resource version bug
        if dist.version != "0.0.0"
    }
# Strategy name -> callable taking (globs, base_path, save_git_info) and
# returning a set of sources.
source_discovery_strategies = dict(
    none=lambda *_, **__: set(),
    imported=get_sources_from_imported_modules,
    sys=get_sources_from_sys_modules,
    dir=get_sources_from_local_dir,
)

# Strategy name -> callable taking (globs, base_path) and returning a set of
# dependencies.
dependency_discovery_strategies = dict(
    none=lambda *_, **__: set(),
    imported=get_dependencies_from_imported_modules,
    sys=get_dependencies_from_sys_modules,
    pkg=get_dependencies_from_pkg,
)
def gather_sources_and_dependencies(globs, save_git_info, base_dir=None):
    """Scan the given globals and return the main file, sources and dependencies.

    The strategies are chosen by ``SETTINGS["DISCOVER_SOURCES"]`` and
    ``SETTINGS["DISCOVER_DEPENDENCIES"]``. When ``base_dir`` is falsy, the
    experiment path from ``get_main_file`` is used instead.

    Returns
    -------
    tuple
        ``(main, sources, dependencies)``; ``main`` is included in
        ``sources`` whenever it is not ``None``.
    """
    experiment_path, main = get_main_file(globs, save_git_info)
    if not base_dir:
        base_dir = experiment_path

    source_strategy = SETTINGS["DISCOVER_SOURCES"]
    find_sources = source_discovery_strategies[source_strategy]
    sources = find_sources(globs, base_dir, save_git_info)
    if main is not None:
        sources.add(main)

    dependency_strategy = SETTINGS["DISCOVER_DEPENDENCIES"]
    find_dependencies = dependency_discovery_strategies[dependency_strategy]
    dependencies = find_dependencies(globs, base_dir)

    if opt.has_numpy:
        # Add numpy as a dependency because it might be used for randomness
        numpy_dependency = PackageDependency.create(opt.np)
        dependencies.add(numpy_dependency)

    return main, sources, dependencies