Coverage for sacred/sacred/host_info.py: 41%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Helps to collect information about the host of an experiment."""
3import os
4import platform
5import re
6import subprocess
7from xml.etree import ElementTree
8import warnings
9from typing import List
11import cpuinfo
13from sacred.utils import optional_kwargs_decorator
14from sacred.settings import SETTINGS
__all__ = ("host_info_gatherers", "get_host_info", "host_info_getter")

# Legacy global dict of functions that are used
# to collect the host information.
# Maps entry name -> zero-argument callable. Filled by the deprecated
# ``host_info_getter`` decorator; ``get_host_info`` copies it and merges the
# default and additional getters on top of it.
host_info_gatherers = {}
class IgnoreHostInfo(Exception):
    """Raised by a host info getter when its entry cannot be gathered.

    ``get_host_info`` catches this exception and simply leaves the
    corresponding entry out of the collected host info.
    """
class HostInfoGetter:
    """Wrap a zero-argument function that produces one host info entry.

    Parameters
    ----------
    getter_function : callable
        Called without arguments; its return value becomes the entry's value.
        It may raise ``IgnoreHostInfo`` to have the entry skipped by
        ``get_host_info``.
    name : str
        Key under which the result is stored in the host info dict.
    """

    def __init__(self, getter_function, name):
        self.getter_function = getter_function
        self.name = name

    def __call__(self):
        return self.getter_function()

    def get_info(self):
        """Return the result of ``getter_function()`` (same as calling the getter)."""
        return self.getter_function()

    def __repr__(self):
        # Makes getters identifiable in debug output and error messages.
        return "{}(name={!r}, getter_function={!r})".format(
            type(self).__name__, self.name, self.getter_function
        )
def host_info_gatherer(name):
    """Build a decorator that turns a function into a ``HostInfoGetter``.

    The resulting getter stores its function's result under ``name``.
    """
    return lambda function: HostInfoGetter(function, name)
def check_additional_host_info(additional_host_info: List[HostInfoGetter]):
    """Ensure no additional getter reuses the name of a default gatherer.

    Raises
    ------
    KeyError
        For the first getter in ``additional_host_info`` whose ``name`` is
        already used by one of the default host info gatherers.
    """
    reserved = [gatherer.name for gatherer in _host_info_gatherers_list]
    clash = next(
        (getter for getter in additional_host_info if getter.name in reserved),
        None,
    )
    if clash is not None:
        raise KeyError(
            (
                "Key {} used in `additional_host_info` already exists as a "
                "default gatherer function. Do not use the following keys: "
                "{}"
            ).format(clash.name, reserved)
        )
def get_host_info(additional_host_info: List[HostInfoGetter] = None):
    """Collect some information about the machine this experiment runs on.

    Parameters
    ----------
    additional_host_info : list of HostInfoGetter, optional
        Extra getters to run in addition to the defaults. Any iterable of
        getters is accepted; the argument itself is never modified.

    Returns
    -------
    dict
        A dictionary with information about the CPU, the OS and the
        Python version of this machine. It also contains entries from the
        legacy ``host_info_gatherers`` dict and from
        ``additional_host_info``. Getters that raise ``IgnoreHostInfo``
        are left out.

    """
    # Build a fresh list so the caller's argument is never mutated; list()
    # also accepts tuples and other iterables (a tuple + list would fail).
    getters = list(additional_host_info or []) + _host_info_gatherers_list
    # Legacy gatherers form the base; later getters (defaults last) win on
    # name clashes.
    all_host_info_gatherers = dict(host_info_gatherers)
    for getter in getters:
        all_host_info_gatherers[getter.name] = getter
    host_info = {}
    for name, gatherer in all_host_info_gatherers.items():
        try:
            host_info[name] = gatherer()
        except IgnoreHostInfo:
            pass
    return host_info
@optional_kwargs_decorator
def host_info_getter(func, name=None):
    """
    Register ``func`` as a legacy host info gatherer (deprecated).

    Emits a ``DeprecationWarning``, then stores ``func`` in the global
    ``sacred.host_info.host_info_gatherers`` dictionary, from which
    :py:func:`~sacred.host_info.get_host_info` collects the host info.

    Parameters
    ----------
    func : callable
        Zero-argument function returning json-serializable information.
    name : str, optional
        Key of the entry in host_info. When not given (or empty), the
        function's ``__name__`` is used.

    Returns
    -------
    The same ``func`` object, unchanged.
    """
    deprecation_message = (
        "The host_info_getter is deprecated. "
        "Please use the `additional_host_info` argument"
        " in the Experiment constructor."
    )
    warnings.warn(deprecation_message, DeprecationWarning)
    key = name if name else func.__name__
    host_info_gatherers[key] = func
    return func
118# #################### Default Host Information ###############################
@host_info_gatherer(name="hostname")
def _hostname():
    """Return this machine's network name as given by ``platform.node()``."""
    node_name = platform.node()
    return node_name
@host_info_gatherer(name="os")
def _os():
    """Return ``[platform.system(), platform.platform()]`` for this machine."""
    system_name = platform.system()
    platform_string = platform.platform()
    return [system_name, platform_string]
@host_info_gatherer(name="python_version")
def _python_version():
    """Return the running interpreter's version string, e.g. ``"3.11.4"``."""
    version = platform.python_version()
    return version
@host_info_gatherer(name="cpu")
def _cpu():
    """Return a human-readable CPU model string for this machine.

    On Windows the ``cpuinfo`` package is queried directly. On macOS
    (``sysctl``) and Linux (``/proc/cpuinfo``) a faster platform-specific
    lookup is tried first. ``cpuinfo`` is the fallback when that lookup
    raises, yields nothing, or when no platform-specific lookup exists for
    this OS.
    """
    system = platform.system()
    if system == "Windows":
        return _get_cpu_by_pycpuinfo()
    cpu = None
    try:
        if system == "Darwin":
            cpu = _get_cpu_by_sysctl()
        elif system == "Linux":
            cpu = _get_cpu_by_proc_cpuinfo()
    except Exception:
        # Best effort: any failure of the fast path falls through below.
        cpu = None
    # Use pycpuinfo only if other ways fail, since it takes about 1 sec
    return cpu if cpu else _get_cpu_by_pycpuinfo()
@host_info_gatherer(name="gpus")
def _gpus():
    """Collect NVIDIA GPU information by parsing ``nvidia-smi -q -x`` output.

    Returns
    -------
    dict or None
        ``None`` when ``SETTINGS.HOST_INFO.INCLUDE_GPU_INFO`` is falsy.
        Otherwise a dict with a ``"gpus"`` list and, if reported, a
        ``"driver_version"``. Each GPU entry has ``model``,
        ``total_memory`` (leading integer of the reported value, or ``None``
        if missing or non-numeric) and ``persistence_mode``.

    Raises
    ------
    IgnoreHostInfo
        If ``nvidia-smi`` cannot be run, exits with an error, or produces
        output that is not well-formed XML.
    """
    if not SETTINGS.HOST_INFO.INCLUDE_GPU_INFO:
        return

    try:
        xml = subprocess.check_output(["nvidia-smi", "-q", "-x"]).decode(
            "utf-8", "replace"
        )
    except (OSError, subprocess.CalledProcessError):
        # OSError also covers FileNotFoundError (nvidia-smi not installed).
        raise IgnoreHostInfo()

    try:
        root = ElementTree.fromstring(xml)
    except ElementTree.ParseError:
        # Malformed output must not abort the whole host info collection.
        raise IgnoreHostInfo()

    gpu_info = {"gpus": []}
    for child in root:
        if child.tag == "driver_version":
            gpu_info["driver_version"] = child.text
        if child.tag != "gpu":
            continue
        gpu = {
            "model": _xml_text(child, "product_name"),
            "total_memory": _parse_leading_int(
                _xml_text(child, "fb_memory_usage/total")
            ),
            "persistence_mode": _xml_text(child, "persistence_mode") == "Enabled",
        }
        gpu_info["gpus"].append(gpu)

    return gpu_info


def _xml_text(element, path):
    """Return the ``.text`` of the first sub-element at ``path``, or None if absent."""
    node = element.find(path)
    return None if node is None else node.text


def _parse_leading_int(text):
    """Return the integer at the start of ``text`` (``"16160 MiB"`` -> 16160).

    Returns None when ``text`` is None, blank, or does not start with an
    integer (e.g. a value such as ``"N/A"``).
    """
    if not text:
        return None
    parts = text.split()
    if not parts:
        return None
    try:
        return int(parts[0])
    except ValueError:
        return None
@host_info_gatherer(name="ENV")
def _environment():
    """Capture the environment variables named in ``SETTINGS.HOST_INFO.CAPTURED_ENV``.

    Variables that are not set in ``os.environ`` are skipped.
    """
    captured = {}
    for key in SETTINGS.HOST_INFO.CAPTURED_ENV:
        if key in os.environ:
            captured[key] = os.environ[key]
    return captured
# Default getters. ``get_host_info`` places them after any
# ``additional_host_info`` (so they win on name clashes), and
# ``check_additional_host_info`` uses their names as reserved keys.
# Kept as a list because ``get_host_info`` concatenates it with ``+``.
_host_info_gatherers_list = [_hostname, _os, _python_version, _cpu, _gpus, _environment]
188# ################### Get CPU Information ###############################
def _get_cpu_by_sysctl():
    """Return the CPU brand string reported by ``sysctl machdep.cpu.brand_string``.

    ``/usr/sbin`` (where ``sysctl`` is searched) is added to the PATH of the
    child process only. Appending it to the global ``os.environ["PATH"]``
    would leak into the whole process and grow PATH on every call.

    Errors from running ``sysctl`` (e.g. ``OSError``,
    ``subprocess.CalledProcessError``) propagate to the caller.
    """
    env = dict(os.environ)
    search_dirs = [d for d in (env.get("PATH"), "/usr/sbin") if d]
    env["PATH"] = os.pathsep.join(search_dirs)
    command = ["sysctl", "-n", "machdep.cpu.brand_string"]
    return subprocess.check_output(command, env=env).decode().strip()
def _get_cpu_by_proc_cpuinfo(path="/proc/cpuinfo"):
    """Return the first ``model name`` value from a ``/proc/cpuinfo``-style file.

    Parameters
    ----------
    path : str, optional
        File to read. Defaults to ``/proc/cpuinfo``.

    Returns
    -------
    str or None
        The model name with surrounding whitespace stripped, or ``None`` if
        no ``model name`` line is present.

    Errors from opening or decoding the file (``OSError``,
    ``UnicodeDecodeError``) propagate to the caller.
    """
    model_pattern = re.compile(r"^\s*model name\s*:")
    # Read the file directly instead of spawning ``cat`` in a subprocess.
    with open(path, encoding="utf-8") as cpuinfo_file:
        for line in cpuinfo_file:
            if model_pattern.match(line):
                return model_pattern.sub("", line, 1).strip()
    return None
def _get_cpu_by_pycpuinfo():
    """Return the CPU brand string reported by the ``cpuinfo`` package.

    Newer py-cpuinfo releases report the brand under ``"brand_raw"``; older
    ones used ``"brand"``. Both keys are checked so the result is not
    ``"Unknown"`` merely because of the installed package version.
    """
    info = cpuinfo.get_cpu_info()
    return info.get("brand_raw", info.get("brand", "Unknown"))