Coverage for sacred/sacred/observers/file_storage.py: 21%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import json
5import os
6import os.path
7from pathlib import Path
8from typing import Optional
9import warnings
11from shutil import copyfile, SameFileError
13from sacred.commandline_options import cli_option
14from sacred.dependencies import get_digest
15from sacred.observers.base import RunObserver
16from sacred import optional as opt
17from sacred.serializer import flatten
18from sacred.utils import PathType
21DEFAULT_FILE_STORAGE_PRIORITY = 20
class FileStorageObserver(RunObserver):
    """Observer that stores all run information as files in a local directory.

    Each run gets its own subdirectory of ``basedir`` (auto-numbered, or named
    after an explicit ``_id``) containing ``run.json``, ``config.json``,
    ``cout.txt``, ``metrics.json`` and any saved sources, resources and
    artifacts.
    """

    VERSION = "FileStorageObserver-0.7.0"

    @classmethod
    def create(cls, *args, **kwargs):
        """Deprecated alias for the constructor; use ``FileStorageObserver(...)``."""
        warnings.warn(
            "FileStorageObserver.create(...) is deprecated. "
            "Please use FileStorageObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(*args, **kwargs)

    def __init__(
        self,
        basedir: PathType,
        resource_dir: Optional[PathType] = None,
        source_dir: Optional[PathType] = None,
        template: Optional[PathType] = None,
        priority: int = DEFAULT_FILE_STORAGE_PRIORITY,
        copy_artifacts: bool = True,
        copy_sources: bool = True,
    ):
        """Create a FileStorageObserver rooted at ``basedir``.

        Parameters
        ----------
        basedir
            Directory under which run subdirectories are created.
        resource_dir
            Where resource files are stored (default ``basedir/_resources``).
        source_dir
            Where source files are stored (default ``basedir/_sources``).
        template
            Optional mako template used to render a report at run end.
            If not given, ``basedir/template.html`` is used when it exists.
        priority
            Observer priority.
        copy_artifacts
            If False, files already inside ``basedir`` are referenced instead
            of copied by ``find_or_save``.
        copy_sources
            Whether experiment source files are copied into storage.

        Raises
        ------
        FileNotFoundError
            If an explicitly given ``template`` path does not exist.
        """
        basedir = Path(basedir)
        resource_dir = resource_dir or basedir / "_resources"
        source_dir = source_dir or basedir / "_sources"
        if template is not None:
            if not os.path.exists(template):
                raise FileNotFoundError(
                    "Couldn't find template file '{}'".format(template)
                )
        else:
            # Fall back to a template that lives next to the runs, if any.
            template = basedir / "template.html"
            if not template.exists():
                template = None
        self.initialize(
            basedir,
            resource_dir,
            source_dir,
            template,
            priority,
            copy_artifacts,
            copy_sources,
        )

    def initialize(
        self,
        basedir,
        resource_dir,
        source_dir,
        template,
        priority=DEFAULT_FILE_STORAGE_PRIORITY,
        copy_artifacts=True,
        copy_sources=True,
    ):
        """Set all attributes; shared by ``__init__`` and ``create_from``."""
        self.basedir = str(basedir)
        # Normalize to Path so that find_or_save() (which calls .mkdir()) and
        # render_template() (which reads .suffix) work even when plain string
        # paths were supplied, e.g. via create_from().
        self.resource_dir = Path(resource_dir)
        self.source_dir = Path(source_dir)
        self.template = Path(template) if template is not None else None
        self.priority = priority
        self.copy_artifacts = copy_artifacts
        self.copy_sources = copy_sources
        # Per-run state, populated by queued_event()/started_event().
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""
        self.cout_write_cursor = 0

    @classmethod
    def create_from(cls, *args, **kwargs):
        """Instantiate the observer without running ``__init__``.

        Unlike ``__init__`` this performs no default-path computation or
        template existence checks; arguments go straight to ``initialize``.
        """
        self = cls.__new__(cls)  # skip __init__ call
        self.initialize(*args, **kwargs)
        return self

    def _maximum_existing_run_id(self):
        """Return the highest numeric run-directory name in basedir (0 if none)."""
        dir_nrs = [
            int(d)
            for d in os.listdir(self.basedir)
            if os.path.isdir(os.path.join(self.basedir, d)) and d.isdigit()
        ]
        if dir_nrs:
            return max(dir_nrs)
        else:
            return 0

    def _make_dir(self, _id):
        """Create the run directory for ``_id``; set self.dir only on success."""
        new_dir = os.path.join(self.basedir, str(_id))
        os.mkdir(new_dir)
        self.dir = new_dir  # set only if mkdir is successful

    def _make_run_dir(self, _id):
        """Create (and set ``self.dir`` to) the directory for this run.

        If ``_id`` is None, pick the next free number after the current
        maximum, retrying on FileExistsError to tolerate concurrent runs
        racing for the same number (up to 1000 attempts).
        """
        os.makedirs(self.basedir, exist_ok=True)
        self.dir = None
        if _id is None:
            fail_count = 0
            _id = self._maximum_existing_run_id() + 1
            while self.dir is None:
                try:
                    self._make_dir(_id)
                except FileExistsError:  # Catch race conditions
                    if fail_count < 1000:
                        fail_count += 1
                        _id += 1
                    else:  # expect that something else went wrong
                        raise
        else:
            # Explicit id: creation must succeed; an existing directory
            # raises FileExistsError to the caller.
            self.dir = os.path.join(self.basedir, str(_id))
            os.mkdir(self.dir)

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        """Record a QUEUED run: write run.json/config.json and copy sources."""
        self._make_run_dir(_id)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "meta": meta_info,
            "status": "QUEUED",
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")

        if self.copy_sources:
            # NOTE(review): source paths here appear to be relative to
            # ex_info["base_dir"], so this copy relies on the current working
            # directory matching it — confirm against callers.
            for s, _ in ex_info["sources"]:
                self.save_file(s)

        return os.path.relpath(self.dir, self.basedir) if _id is None else _id

    def save_sources(self, ex_info):
        """Store (or reference) all experiment sources.

        Returns a list of ``[source, path-relative-to-basedir]`` pairs.
        """
        base_dir = ex_info["base_dir"]
        source_info = []
        for s, _ in ex_info["sources"]:
            abspath = os.path.join(base_dir, s)
            if self.copy_sources:
                store_path = self.find_or_save(abspath, self.source_dir)
            else:
                store_path = abspath
            relative_source = os.path.relpath(str(store_path), self.basedir)
            source_info.append([s, relative_source])
        return source_info

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        """Record a RUNNING run: create its directory and initial JSON files."""
        self._make_run_dir(_id)

        ex_info["sources"] = self.save_sources(ex_info)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "start_time": start_time.isoformat(),
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "heartbeat": None,
        }
        self.config = config
        self.info = {}
        self.cout = ""
        self.cout_write_cursor = 0

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")
        self.save_cout()

        return os.path.relpath(self.dir, self.basedir) if _id is None else _id

    def find_or_save(self, filename, store_dir: Path):
        """Store ``filename`` under a content-addressed name in ``store_dir``.

        If the file already lives inside ``basedir`` and ``copy_artifacts``
        is False, the original path is returned unchanged. Otherwise the file
        is copied to ``store_dir/<name>_<md5>.<ext>`` (skipping the copy when
        an identical store file already exists) and that path is returned.
        """
        try:
            Path(filename).resolve().relative_to(Path(self.basedir).resolve())
            is_relative_to = True
        except ValueError:
            is_relative_to = False

        if is_relative_to and not self.copy_artifacts:
            return filename
        else:
            store_dir.mkdir(parents=True, exist_ok=True)
            source_name, ext = os.path.splitext(os.path.basename(filename))
            md5sum = get_digest(filename)
            store_name = source_name + "_" + md5sum + ext
            store_path = store_dir / store_name
            if not store_path.exists():
                copyfile(filename, str(store_path))
            return store_path

    def save_json(self, obj, filename):
        """Serialize ``obj`` (via ``flatten``) into ``filename`` in the run dir."""
        with open(os.path.join(self.dir, filename), "w") as f:
            json.dump(flatten(obj), f, sort_keys=True, indent=2)
            f.flush()

    def save_file(self, filename, target_name=None):
        """Copy ``filename`` into the run directory as ``target_name``.

        Raises
        ------
        FileExistsError
            If the target would overwrite one of the observer's own files.
        """
        target_name = target_name or os.path.basename(filename)
        blacklist = ["run.json", "config.json", "cout.txt", "metrics.json"]
        blacklist = [os.path.join(self.dir, x) for x in blacklist]
        dest_file = os.path.join(self.dir, target_name)
        if dest_file in blacklist:
            raise FileExistsError(
                "You are trying to overwrite a file necessary for the "
                "FileStorageObserver. "
                "The list of blacklisted files is: {}".format(blacklist)
            )
        try:
            copyfile(filename, dest_file)
        except SameFileError:
            # Source and destination are the same file; nothing to do.
            pass

    def save_cout(self):
        """Append captured stdout since the last write cursor to cout.txt."""
        with open(os.path.join(self.dir, "cout.txt"), "ab") as f:
            f.write(self.cout[self.cout_write_cursor :].encode("utf-8"))
            self.cout_write_cursor = len(self.cout)

    def render_template(self):
        """Render the configured mako template into ``report.<ext>``, if any."""
        if opt.has_mako and self.template:
            from mako.template import Template

            # str(): mako's ``filename`` expects a string; self.template is a
            # pathlib.Path (normalized in initialize()).
            template = Template(filename=str(self.template))
            report = template.render(
                run=self.run_entry,
                config=self.config,
                info=self.info,
                cout=self.cout,
                savedir=self.dir,
            )
            ext = self.template.suffix
            with open(os.path.join(self.dir, "report" + ext), "w") as f:
                f.write(report)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Persist periodic run state: info, captured output and heartbeat time."""
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, "run.json")
        if self.info:
            self.save_json(self.info, "info.json")

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED, save run.json and render the report."""
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"

        self.save_json(self.run_entry, "run.json")
        self.render_template()

    def interrupted_event(self, interrupt_time, status):
        """Mark the run with the given interrupt ``status`` and save/render."""
        self.run_entry["stop_time"] = interrupt_time.isoformat()
        self.run_entry["status"] = status
        self.save_json(self.run_entry, "run.json")
        self.render_template()

    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED, record the traceback, and save/render."""
        self.run_entry["stop_time"] = fail_time.isoformat()
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save_json(self.run_entry, "run.json")
        self.render_template()

    def resource_event(self, filename):
        """Store a resource file and record it in run.json."""
        store_path = self.find_or_save(filename, self.resource_dir)
        self.run_entry["resources"].append([filename, str(store_path)])
        self.save_json(self.run_entry, "run.json")

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        """Store an artifact file and record its name in run.json.

        ``metadata`` and ``content_type`` are accepted for interface
        compatibility but are not persisted by this observer.
        """
        self.save_file(filename, name)
        self.run_entry["artifacts"].append(name)
        self.save_json(self.run_entry, "run.json")

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json."""
        try:
            metrics_path = os.path.join(self.dir, "metrics.json")
            with open(metrics_path, "r") as f:
                saved_metrics = json.load(f)
        except IOError:
            # We haven't recorded anything yet. Start Collecting.
            saved_metrics = {}

        for metric_name, metric_ptr in metrics_by_name.items():

            if metric_name not in saved_metrics:
                saved_metrics[metric_name] = {
                    "values": [],
                    "steps": [],
                    "timestamps": [],
                }

            saved_metrics[metric_name]["values"] += metric_ptr["values"]
            saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            # Manually convert them to avoid passing a datetime dtype handler
            # when we're trying to convert into json.
            timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]]
            saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(saved_metrics, "metrics.json")

    def __eq__(self, other):
        # Two observers are equal iff they write to the same base directory.
        if isinstance(other, FileStorageObserver):
            return self.basedir == other.basedir
        return False
@cli_option("-F", "--file_storage")
def file_storage_option(args, run):
    """Add a file-storage observer to the experiment.

    The value of the argument should be the
    base-directory to write the runs to.
    """
    # ``args`` is the CLI option value, used directly as the basedir.
    run.observers.append(FileStorageObserver(args))