Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/host_info.py: 75%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

105 statements  

1"""Helps to collect information about the host of an experiment.""" 

2 

import os
import platform
import re
import subprocess
import warnings
from typing import List, Optional
from xml.etree import ElementTree

import cpuinfo

from sacred.utils import optional_kwargs_decorator
from sacred.settings import SETTINGS

15 

__all__ = (
    "host_info_gatherers",
    "get_host_info",
    "host_info_getter",
)

# Legacy, module-wide registry mapping a name to a zero-argument function.
# The deprecated ``host_info_getter`` decorator adds entries here, and
# ``get_host_info`` evaluates them alongside the other getters.
host_info_gatherers = {}

21 

22 

class IgnoreHostInfo(Exception):
    """Signal from a host-info getter that its value cannot be gathered.

    ``get_host_info`` catches this exception and simply leaves the getter's
    key out of the collected host info.
    """

25 

26 

class HostInfoGetter:
    """A zero-argument host-info function paired with the key it reports under.

    Calling the instance, or calling ``get_info``, runs the wrapped function
    and returns its result. ``get_host_info`` stores that result under
    ``name``.
    """

    def __init__(self, getter_function, name):
        self.name = name
        self.getter_function = getter_function

    def __call__(self):
        gather = self.getter_function
        return gather()

    def get_info(self):
        """Run the wrapped getter function and return what it produces."""
        gather = self.getter_function
        return gather()

37 

38 

def host_info_gatherer(name):
    """Build a decorator that wraps a function in a named ``HostInfoGetter``.

    Parameters
    ----------
    name : str
        Key under which the decorated function's result is stored in the
        host info.

    Returns
    -------
    callable
        A decorator mapping ``f`` to ``HostInfoGetter(f, name)``.
    """

    def decorate(getter_function):
        return HostInfoGetter(getter_function, name)

    return decorate

44 

45 

def check_additional_host_info(additional_host_info: List[HostInfoGetter]):
    """Reject additional getters whose name clashes with a default gatherer.

    Parameters
    ----------
    additional_host_info : list of HostInfoGetter
        Getters supplied on top of the defaults.

    Raises
    ------
    KeyError
        For the first getter whose ``name`` is already used by one of the
        default gatherers in ``_host_info_gatherers_list``.
    """
    names_taken = [default.name for default in _host_info_gatherers_list]
    clashes = (g for g in additional_host_info if g.name in names_taken)
    first_clash = next(clashes, None)
    if first_clash is None:
        return
    message = (
        "Key {} used in `additional_host_info` already exists as a "
        "default gatherer function. Do not use the following keys: "
        "{}"
    ).format(first_clash.name, names_taken)
    raise KeyError(message)

56 

57 

def get_host_info(additional_host_info: Optional[List[HostInfoGetter]] = None):
    """Collect some information about the machine this experiment runs on.

    Parameters
    ----------
    additional_host_info : list of HostInfoGetter, optional
        Extra getters to run besides the legacy ``host_info_gatherers``
        registry and the default gatherers. The argument is not modified;
        any iterable of getters is accepted.

    Returns
    -------
    dict
        Maps each getter's name to its result. With only the defaults this
        covers the hostname, the OS, the Python version, the CPU, the GPUs
        and selected environment variables. Getters that raise
        ``IgnoreHostInfo`` are omitted.
    """
    # Build a fresh list so the caller's argument is never modified. Names
    # resolve in order: legacy registry, then additional getters, then the
    # defaults. A later entry replaces an earlier one with the same name.
    getters = list(additional_host_info or []) + _host_info_gatherers_list
    all_host_info_gatherers = host_info_gatherers.copy()
    for getter in getters:
        all_host_info_gatherers[getter.name] = getter

    host_info = {}
    for name, gather in all_host_info_gatherers.items():
        try:
            host_info[name] = gather()
        except IgnoreHostInfo:
            # The getter reported that its value is unavailable: omit the key.
            pass
    return host_info

81 

82 

@optional_kwargs_decorator
def host_info_getter(func, name=None):
    """Register ``func`` in the legacy host-info registry (deprecated).

    The function is stored in the module-level
    ``sacred.host_info.host_info_gatherers`` dictionary. The entries of that
    dictionary are evaluated by :py:func:`~sacred.host_info.get_host_info`.
    New code should pass getters through the ``additional_host_info``
    argument of the Experiment constructor instead.

    Parameters
    ----------
    func : callable
        Zero-argument function returning json-serializable information.
    name : str, optional
        Key for the result in the host info. When empty or omitted,
        ``func.__name__`` is used.

    Returns
    -------
    callable
        ``func`` itself, unchanged.
    """
    warnings.warn(
        "The host_info_getter is deprecated. "
        "Please use the `additional_host_info` argument"
        " in the Experiment constructor.",
        DeprecationWarning,
    )
    if not name:
        name = func.__name__
    host_info_gatherers[name] = func
    return func

116 

117 

118# #################### Default Host Information ############################### 

119 

120 

@host_info_gatherer(name="hostname")
def _hostname():
    """Return this machine's network name, as given by ``platform.node()``."""
    node_name = platform.node()
    return node_name

124 

125 

@host_info_gatherer(name="os")
def _os():
    """Return ``[platform.system(), platform.platform()]`` for this machine."""
    system_name = platform.system()
    platform_string = platform.platform()
    return [system_name, platform_string]

129 

130 

@host_info_gatherer(name="python_version")
def _python_version():
    """Return the running interpreter's version string from ``platform``."""
    version = platform.python_version()
    return version

134 

135 

@host_info_gatherer(name="cpu")
def _cpu():
    """Return a description of this machine's CPU.

    On Windows py-cpuinfo is used directly. On macOS (``sysctl``) and Linux
    (``/proc/cpuinfo``) a fast platform-specific lookup is tried first. If
    that lookup raises or finds nothing, or the platform is neither of these,
    the function falls back to py-cpuinfo.
    """
    system = platform.system()
    if system == "Windows":
        return _get_cpu_by_pycpuinfo()

    cpu = None
    try:
        if system == "Darwin":
            cpu = _get_cpu_by_sysctl()
        elif system == "Linux":
            cpu = _get_cpu_by_proc_cpuinfo()
    except Exception:
        # Best effort: any failure of the fast path just triggers the fallback.
        cpu = None

    if cpu:
        return cpu
    # Use pycpuinfo only if other ways fail, since it takes about 1 sec
    return _get_cpu_by_pycpuinfo()

148 

149 

@host_info_gatherer(name="gpus")
def _gpus():
    """Return information about NVIDIA GPUs as reported by ``nvidia-smi``.

    Returns ``None`` when ``SETTINGS.HOST_INFO.INCLUDE_GPU_INFO`` is false.
    Otherwise returns a dict with a ``"gpus"`` list and, when reported, a
    ``"driver_version"`` entry. Each item of the list has ``model``,
    ``total_memory`` (the leading number of the ``fb_memory_usage/total``
    field) and ``persistence_mode`` (True when reported as "Enabled").

    Raises
    ------
    IgnoreHostInfo
        If ``nvidia-smi`` cannot be run, exits with an error, or prints
        output that is not valid XML.
    """
    if not SETTINGS.HOST_INFO.INCLUDE_GPU_INFO:
        return None

    try:
        xml = subprocess.check_output(["nvidia-smi", "-q", "-x"]).decode(
            "utf-8", "replace"
        )
    except (FileNotFoundError, OSError, subprocess.CalledProcessError):
        raise IgnoreHostInfo()

    try:
        root = ElementTree.fromstring(xml)
    except ElementTree.ParseError:
        # Collection is best effort: unparsable output means there is nothing
        # reliable to report, so skip this entry instead of aborting the
        # whole host-info collection.
        raise IgnoreHostInfo()

    gpu_info = {"gpus": []}
    for child in root:
        if child.tag == "driver_version":
            gpu_info["driver_version"] = child.text
        if child.tag != "gpu":
            continue
        gpu = {
            "model": child.find("product_name").text,
            "total_memory": int(
                child.find("fb_memory_usage").find("total").text.split()[0]
            ),
            "persistence_mode": (child.find("persistence_mode").text == "Enabled"),
        }
        gpu_info["gpus"].append(gpu)

    return gpu_info

178 

179 

@host_info_gatherer(name="ENV")
def _environment():
    """Return the environment variables listed in the captured-env setting.

    Only names from ``SETTINGS.HOST_INFO.CAPTURED_ENV`` that are actually set
    in ``os.environ`` appear in the result. The others are skipped silently.
    """
    captured = {}
    for key in SETTINGS.HOST_INFO.CAPTURED_ENV:
        if key in os.environ:
            captured[key] = os.environ[key]
    return captured

184 

185 

# Default gatherers that get_host_info always runs. check_additional_host_info
# forbids additional getters from reusing their names.
_host_info_gatherers_list = [
    _hostname,
    _os,
    _python_version,
    _cpu,
    _gpus,
    _environment,
]

187 

188# ################### Get CPU Information ############################### 

189 

190 

def _get_cpu_by_sysctl():
    """Return the CPU brand string reported by ``sysctl`` (macOS).

    ``sysctl`` may live in ``/usr/sbin``, which is not always on ``PATH``.
    That directory is appended to the PATH of the child process only.
    Mutating ``os.environ`` instead would change the whole process and make
    PATH grow on every call.

    Raises
    ------
    OSError, subprocess.CalledProcessError
        If ``sysctl`` cannot be run or fails.
    """
    env = dict(os.environ)
    path = env.get("PATH")
    env["PATH"] = path + ":/usr/sbin" if path else "/usr/sbin"
    command = ["sysctl", "-n", "machdep.cpu.brand_string"]
    return subprocess.check_output(command, env=env).decode().strip()

195 

196 

def _get_cpu_by_proc_cpuinfo(path="/proc/cpuinfo"):
    """Return the CPU model name listed in ``/proc/cpuinfo`` (Linux).

    The file is read directly rather than through a ``cat`` subprocess.

    Parameters
    ----------
    path : str, optional
        File to parse. Defaults to ``/proc/cpuinfo``.

    Returns
    -------
    str or None
        The value of the first ``model name`` entry, stripped of surrounding
        whitespace, or ``None`` if no line matches.

    Raises
    ------
    OSError
        If the file cannot be opened.
    """
    model_pattern = re.compile(r"^\s*model name\s*:")
    with open(path, encoding="utf-8", errors="replace") as cpuinfo_file:
        for line in cpuinfo_file:
            if model_pattern.match(line):
                return model_pattern.sub("", line, 1).strip()
    return None

204 

205 

def _get_cpu_by_pycpuinfo(cpu_info=None):
    """Return the CPU brand string reported by py-cpuinfo.

    This is the slow path, since ``cpuinfo.get_cpu_info()`` takes about a
    second.

    Parameters
    ----------
    cpu_info : dict, optional
        An already computed result of ``cpuinfo.get_cpu_info()``. If
        omitted, it is computed here.

    Returns
    -------
    str
        The brand string, or ``"Unknown"`` if none is reported.
    """
    if cpu_info is None:
        cpu_info = cpuinfo.get_cpu_info()
    # py-cpuinfo 6.0 renamed the "brand" key to "brand_raw". Accept both so
    # that current versions do not always yield "Unknown".
    return cpu_info.get("brand_raw") or cpu_info.get("brand", "Unknown")