#!/usr/bin/env python
# coding=utf-8

from __future__ import division, print_function, unicode_literals, absolute_import

import os
import textwrap
import uuid
import warnings
from collections import OrderedDict

from sacred.__about__ import __version__
from sacred.commandline_options import cli_option
from sacred.observers import RunObserver


class TinyDbObserver(RunObserver):
    """Observer that stores run information in a local TinyDB database.

    Run metadata is kept in a TinyDB table named "runs", while sources,
    resources and artifacts are stored in an accompanying content-addressed
    file store on disk.
    """

    VERSION = "TinyDbObserver-{}".format(__version__)

    @classmethod
    def create(cls, path="./runs_db", overwrite=None):
        warnings.warn(
            "TinyDbObserver.create(...) is deprecated. "
            "Please use TinyDbObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(path, overwrite)

    def __init__(self, path="./runs_db", overwrite=None):
        from .bases import get_db_file_manager

        root_dir = os.path.abspath(path)
        os.makedirs(root_dir, exist_ok=True)

        db, fs = get_db_file_manager(root_dir)
        self.db = db
        self.runs = db.table("runs")
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = {}
        self.db_run_id = None
        self.root = root_dir
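
    # Typical usage (illustrative sketch; the experiment name "my_experiment"
    # is a placeholder and these lines are not executed as part of this module):
    #
    #     from sacred import Experiment
    #     from sacred.observers import TinyDbObserver
    #
    #     ex = Experiment("my_experiment")
    #     ex.observers.append(TinyDbObserver("./runs_db"))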

    @classmethod
    def create_from(cls, db, fs, overwrite=None, root=None):
        """Instantiate a TinyDbObserver with an existing db and filesystem."""
        self = cls.__new__(cls)  # skip __init__ call
        self.db = db
        self.runs = db.table("runs")
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = {}
        self.db_run_id = None
        self.root = root
        return self

    def save(self):
        """Insert or update the current run entry."""
        if self.db_run_id:
            self.runs.update(self.run_entry, doc_ids=[self.db_run_id])
        else:
            db_run_id = self.runs.insert(self.run_entry)
            self.db_run_id = db_run_id

    def save_sources(self, ex_info):
        from .bases import BufferedReaderWrapper

        source_info = []
        for source_name, md5 in ex_info["sources"]:

            # Substitute any HOME or Environment Vars to get absolute path
            abs_path = os.path.join(ex_info["base_dir"], source_name)
            abs_path = os.path.expanduser(abs_path)
            abs_path = os.path.expandvars(abs_path)
            handle = BufferedReaderWrapper(open(abs_path, "rb"))

            file = self.fs.get(md5)
            if file:
                id_ = file.id
            else:
                address = self.fs.put(abs_path)
                id_ = address.id
            source_info.append([source_name, id_, handle])
        return source_info

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        raise NotImplementedError(
            "queued_event is not implemented for the local TinyDbObserver."
        )

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        self.db_run_id = None

        self.run_entry = {
            "experiment": dict(ex_info),
            "format": self.VERSION,
            "command": command,
            "host": dict(host_info),
            "start_time": start_time,
            "config": config,
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "captured_out": "",
            "info": {},
            "heartbeat": None,
        }

        # set ID if not given
        if _id is None:
            _id = uuid.uuid4().hex

        self.run_entry["_id"] = _id

        # save sources
        self.run_entry["experiment"]["sources"] = self.save_sources(ex_info)
        self.save()
        return self.run_entry["_id"]

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.run_entry["info"] = info
        self.run_entry["captured_out"] = captured_out
        self.run_entry["heartbeat"] = beat_time
        self.run_entry["result"] = result
        self.save()

    def completed_event(self, stop_time, result):
        self.run_entry["stop_time"] = stop_time
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"
        self.save()

    def interrupted_event(self, interrupt_time, status):
        self.run_entry["stop_time"] = interrupt_time
        self.run_entry["status"] = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        self.run_entry["stop_time"] = fail_time
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save()

    def resource_event(self, filename):
        from .bases import BufferedReaderWrapper

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, "rb"))
        resource = [filename, id_, handle]

        if resource not in self.run_entry["resources"]:
            self.run_entry["resources"].append(resource)
            self.save()

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        from .bases import BufferedReaderWrapper

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, "rb"))
        artifact = [name, filename, id_, handle]

        if artifact not in self.run_entry["artifacts"]:
            self.run_entry["artifacts"].append(artifact)
            self.save()

    def __eq__(self, other):
        if isinstance(other, TinyDbObserver):
            return self.runs.all() == other.runs.all()
        return False


@cli_option("-t", "--tiny_db")
def tiny_db_option(args, run):
    """Add a TinyDB Observer to the experiment.

    The argument is the path to be given to the TinyDbObserver.
    """
    tinydb_obs = TinyDbObserver(path=args)
    run.observers.append(tinydb_obs)
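
# Example command-line usage (illustrative sketch; "my_experiment.py" is a
# placeholder for any script exposing the standard Sacred command line):
#
#     python my_experiment.py -t ./runs_db
#
# This attaches a TinyDbObserver that writes run data to ./runs_db.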


class TinyDbReader:
    """Read run metadata and associated files written by a TinyDbObserver."""

    def __init__(self, path):
        from .bases import get_db_file_manager

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            raise IOError("Path does not exist: %s" % path)

        db, fs = get_db_file_manager(root_dir)

        self.db = db
        self.runs = db.table("runs")
        self.fs = fs

    def search(self, *args, **kwargs):
        """Wrapper around TinyDB's search function."""
        return self.runs.search(*args, **kwargs)
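
    # Example (illustrative sketch; ``reader`` is a placeholder name, and the
    # "COMPLETED" status matches what TinyDbObserver writes in completed_event):
    #
    #     from tinydb import Query
    #
    #     reader = TinyDbReader("./runs_db")
    #     completed_runs = reader.search(Query().status == "COMPLETED")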

    def fetch_files(self, exp_name=None, query=None, indices=None):
        """Return the files belonging to each matching experiment.

        Returns a list with one dictionary per matched experiment. Each
        dictionary has the following structure:

            {
                'exp_name': '<experiment name>',
                'exp_id': '<run id>',
                'date': <start time as a datetime object>,
                'sources': {'filename': filehandle, ...},
                'resources': {'filename': filehandle, ...},
                'artifacts': {'filename': filehandle, ...}
            }

        """
        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            rec = dict(
                exp_name=ent["experiment"]["name"],
                exp_id=ent["_id"],
                date=ent["start_time"],
            )

            source_files = {x[0]: x[2] for x in ent["experiment"]["sources"]}
            resource_files = {x[0]: x[2] for x in ent["resources"]}
            artifact_files = {x[0]: x[3] for x in ent["artifacts"]}

            if source_files:
                rec["sources"] = source_files
            if resource_files:
                rec["resources"] = resource_files
            if artifact_files:
                rec["artifacts"] = artifact_files

            all_matched_entries.append(rec)

        return all_matched_entries
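
    # Example (illustrative sketch; the experiment name "my_experiment" and the
    # source filename "train.py" are placeholder values):
    #
    #     files = reader.fetch_files(exp_name="my_experiment")
    #     source_handle = files[0]["sources"]["train.py"]
    #     print(source_handle.read())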

    def fetch_report(self, exp_name=None, query=None, indices=None):

        template = """
-------------------------------------------------
Experiment: {exp_name}
-------------------------------------------------
ID: {exp_id}
Date: {start_date}  Duration: {duration}

Parameters:
{parameters}

Result:
{result}

Dependencies:
{dependencies}

Resources:
{resources}

Source Files:
{sources}

Outputs:
{artifacts}
"""

        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            date = ent["start_time"]
            weekdays = "Mon Tue Wed Thu Fri Sat Sun".split()
            w = weekdays[date.weekday()]
            date = " ".join([w, date.strftime("%d %b %Y")])

            duration = ent["stop_time"] - ent["start_time"]
            secs = duration.total_seconds()
            hours, remainder = divmod(secs, 3600)
            minutes, seconds = divmod(remainder, 60)
            duration = "%02d:%02d:%04.1f" % (hours, minutes, seconds)

            parameters = self._dict_to_indented_list(ent["config"])

            result = self._indent(repr(ent["result"]), prefix="    ")

            deps = ent["experiment"]["dependencies"]
            deps = self._indent("\n".join(deps), prefix="    ")

            resources = [x[0] for x in ent["resources"]]
            resources = self._indent("\n".join(resources), prefix="    ")

            sources = [x[0] for x in ent["experiment"]["sources"]]
            sources = self._indent("\n".join(sources), prefix="    ")

            artifacts = [x[0] for x in ent["artifacts"]]
            artifacts = self._indent("\n".join(artifacts), prefix="    ")

            none_str = "    None"

            rec = dict(
                exp_name=ent["experiment"]["name"],
                exp_id=ent["_id"],
                start_date=date,
                duration=duration,
                parameters=parameters if parameters else none_str,
                result=result if result else none_str,
                dependencies=deps if deps else none_str,
                resources=resources if resources else none_str,
                sources=sources if sources else none_str,
                artifacts=artifacts if artifacts else none_str,
            )

            report = template.format(**rec)

            all_matched_entries.append(report)

        return all_matched_entries
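
    # Example (illustrative sketch; prints the report for the first stored run):
    #
    #     for report in reader.fetch_report(indices=0):
    #         print(report)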

    def fetch_metadata(self, exp_name=None, query=None, indices=None):
        """Return all metadata for matching experiment name, index or query."""
        from tinydb import Query

        if exp_name or query:
            if query:
                q = query
            elif exp_name:
                q = Query().experiment.name.search(exp_name)

            entries = self.runs.search(q)

        elif indices or indices == 0:
            if not isinstance(indices, (tuple, list)):
                indices = [indices]

            num_recs = len(self.runs)

            for idx in indices:
                if idx >= num_recs:
                    raise ValueError(
                        "Index value ({}) must be less than "
                        "number of records ({})".format(idx, num_recs)
                    )

            entries = [self.runs.all()[ind] for ind in indices]

        else:
            raise ValueError(
                "Must specify an experiment name, indices, or a custom query"
            )

        return entries

    def _dict_to_indented_list(self, d):

        d = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        output_str = ""

        for k, v in d.items():
            output_str += "%s: %s" % (k, v)
            output_str += "\n"

        output_str = self._indent(output_str.strip(), prefix="    ")

        return output_str

    def _indent(self, message, prefix):
        """Wrapper for indenting strings in Python 2 and 3."""
        preferred_width = 150
        wrapper = textwrap.TextWrapper(
            initial_indent=prefix, width=preferred_width, subsequent_indent=prefix
        )

        lines = message.splitlines()
        formatted_lines = [wrapper.fill(lin) for lin in lines]
        formatted_text = "\n".join(formatted_lines)

        return formatted_text