Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/observers/mongo.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import Optional, Union
2import mimetypes
3import os.path
4import pickle
5import re
6import sys
7import time
8from tempfile import NamedTemporaryFile
9import warnings
11import sacred.optional as opt
12from sacred.commandline_options import cli_option
13from sacred.dependencies import get_digest
14from sacred.observers.base import RunObserver
15from sacred.observers.queue import QueueObserver
16from sacred.serializer import flatten
17from sacred.utils import ObserverError, PathType
18import pkg_resources
# Default priority with which MongoObserver events are processed.
DEFAULT_MONGO_PRIORITY = 30

# Use the mime.types table bundled with sacred instead of the platform's.
# This ensures consistent mimetype detection across platforms.
mimetype_detector = mimetypes.MimeTypes(
    filenames=[pkg_resources.resource_filename("sacred", "data/mime.types")]
)
def force_valid_bson_key(key):
    """Return *key* coerced into a string MongoDB accepts as a field name.

    BSON field names may neither start with ``$`` nor contain ``.``, so a
    leading ``$`` is replaced by ``@`` and every ``.`` by ``,``.
    """
    text = str(key)
    if text.startswith("$"):
        text = "@" + text[1:]
    return text.replace(".", ",")
def force_bson_encodeable(obj):
    """Best-effort conversion of *obj* into something BSON can encode.

    Dicts that fail a strict encode are rebuilt with sanitized keys and
    recursively converted values; numpy arrays are passed through untouched;
    any other value BSON rejects is stringified.
    """
    import bson

    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        # Arrays are handed through unchanged.
        return obj

    if isinstance(obj, dict):
        try:
            # Strict check: also validates the keys themselves.
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            return {
                force_valid_bson_key(key): force_bson_encodeable(value)
                for key, value in obj.items()
            }
        return obj

    try:
        # Wrap in a throwaway dict because BSON can only encode documents.
        bson.BSON.encode({"dict_just_for_testing": obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
class MongoObserver(RunObserver):
    """Observer that records sacred runs in a MongoDB database.

    Run documents are written to a runs collection, metric rows to a
    companion metrics collection, and file content (sources, resources,
    artifacts) to GridFS.
    """

    # Collection names that would collide with GridFS internals or other
    # reserved collections; rejected for runs/metrics storage.
    COLLECTION_NAME_BLACKLIST = {
        "fs.files",
        "fs.chunks",
        "_properties",
        "system.indexes",
        "search_space",
        "search_spaces",
    }
    # Version tag stored in every run document's "format" field.
    VERSION = "MongoObserver-0.7.0"

    @classmethod
    def create(cls, *args, **kwargs):
        """Deprecated factory method; call ``MongoObserver(...)`` instead."""
        warnings.warn(
            "MongoObserver.create(...) is deprecated. "
            "Please use MongoObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(*args, **kwargs)

    def __init__(
        self,
        url: Optional[str] = None,
        db_name: str = "sacred",
        collection: str = "runs",
        collection_prefix: str = "",
        overwrite: Optional[Union[int, str]] = None,
        priority: int = DEFAULT_MONGO_PRIORITY,
        client: Optional["pymongo.MongoClient"] = None,
        failure_dir: Optional[PathType] = None,
        **kwargs,
    ):
        """Initializer for MongoObserver.

        Parameters
        ----------
        url
            Mongo URI to connect to.
        db_name
            Database to connect to.
        collection
            Collection to write the runs to. (default: "runs").
            **DEPRECATED**, please use collection_prefix instead.
        collection_prefix
            Prefix the runs and metrics collection,
            i.e. runs will be stored to PREFIX_runs, metrics to PREFIX_metrics.
            If empty runs are stored to 'runs', metrics to 'metrics'.
        overwrite
            _id of a run that should be overwritten.
        priority
            (default 30)
        client
            Client to connect to. Do not use client and URL together.
        failure_dir
            Directory to save the run of a failed observer to.
        **kwargs
            Forwarded to ``pymongo.MongoClient`` when no client is given.
        """
        # Imported lazily so pymongo/gridfs are only required when a
        # MongoObserver is actually constructed.
        import pymongo
        import gridfs

        if client is not None:
            if not isinstance(client, pymongo.MongoClient):
                raise ValueError(
                    "client needs to be a pymongo.MongoClient, "
                    "but is {} instead".format(type(client))
                )
            if url is not None:
                raise ValueError("Cannot pass both a client and a url.")
        else:
            client = pymongo.MongoClient(url, **kwargs)

        database = client[db_name]
        if collection != "runs":
            # the 'old' way of setting a custom collection name
            # still works as before for backward compatibility
            warnings.warn(
                'Argument "collection" is deprecated. '
                'Please use "collection_prefix" instead.',
                DeprecationWarning,
            )
            if collection_prefix != "":
                raise ValueError("Cannot pass both collection and a collection prefix.")
            runs_collection_name = collection
            metrics_collection_name = "metrics"
        else:
            if collection_prefix != "":
                # separate prefix from 'runs' / 'collections' by an underscore.
                collection_prefix = "{}_".format(collection_prefix)
            runs_collection_name = "{}runs".format(collection_prefix)
            metrics_collection_name = "{}metrics".format(collection_prefix)

        if runs_collection_name in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError(
                'Collection name "{}" is reserved. '
                "Please use a different one.".format(runs_collection_name)
            )

        if metrics_collection_name in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError(
                'Collection name "{}" is reserved. '
                "Please use a different one.".format(metrics_collection_name)
            )

        runs_collection = database[runs_collection_name]
        metrics_collection = database[metrics_collection_name]
        fs = gridfs.GridFS(database)
        self.initialize(
            runs_collection,
            fs,
            overwrite=overwrite,
            metrics_collection=metrics_collection,
            failure_dir=failure_dir,
            priority=priority,
        )

    def initialize(
        self,
        runs_collection,
        fs,
        overwrite=None,
        metrics_collection=None,
        failure_dir=None,
        priority=DEFAULT_MONGO_PRIORITY,
    ):
        """Set up observer state from already-constructed collection handles.

        If *overwrite* is given it must be (convertible to) the integer _id
        of an existing run; that run's document is loaded and reused.
        """
        self.runs = runs_collection
        self.metrics = metrics_collection
        self.fs = fs
        if overwrite is not None:
            # Accepts an int or a numeric string (see __init__ type hint).
            overwrite = int(overwrite)
            run = self.runs.find_one({"_id": overwrite})
            if run is None:
                raise RuntimeError(
                    "Couldn't find run to overwrite with " "_id='{}'".format(overwrite)
                )
            else:
                overwrite = run
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority
        self.failure_dir = failure_dir

    @classmethod
    def create_from(cls, *args, **kwargs):
        """Alternate constructor from existing collections/GridFS handles."""
        self = cls.__new__(cls)  # skip __init__ call
        self.initialize(*args, **kwargs)
        return self

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        """Insert a run document with status QUEUED and return its _id."""
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "config": flatten(config),
            "meta": meta_info,
            "status": "QUEUED",
        }
        # set ID if given
        if _id is not None:
            self.run_entry["_id"] = _id
        # save sources
        self.run_entry["experiment"]["sources"] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry["_id"]

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        """Create (or take over, when overwriting) the RUNNING run document."""
        if self.overwrite is None:
            self.run_entry = {"_id": _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # TODO sanity checks
            self.run_entry = self.overwrite

        self.run_entry.update(
            {
                "experiment": dict(ex_info),
                "format": self.VERSION,
                "command": command,
                "host": dict(host_info),
                "start_time": start_time,
                "config": flatten(config),
                "meta": meta_info,
                "status": "RUNNING",
                "resources": [],
                "artifacts": [],
                "captured_out": "",
                "info": {},
                "heartbeat": None,
            }
        )

        # save sources
        self.run_entry["experiment"]["sources"] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry["_id"]

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Refresh info, captured output, heartbeat time and result."""
        self.run_entry["info"] = flatten(info)
        self.run_entry["captured_out"] = captured_out
        self.run_entry["heartbeat"] = beat_time
        self.run_entry["result"] = flatten(result)
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and persist it (up to 10 attempts)."""
        self.run_entry["stop_time"] = stop_time
        self.run_entry["result"] = flatten(result)
        self.run_entry["status"] = "COMPLETED"
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Record the interrupt status and persist (up to 3 attempts)."""
        self.run_entry["stop_time"] = interrupt_time
        self.run_entry["status"] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED with its traceback and persist (1 attempt)."""
        self.run_entry["stop_time"] = fail_time
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Record *filename* as a resource, uploading it to GridFS if new.

        NOTE(review): deduplication relies on GridFS exposing an ``md5``
        file attribute; newer pymongo/GridFS versions no longer compute it
        by default -- verify against the pinned pymongo version.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry["resources"]:
                    self.run_entry["resources"].append(resource)
                    self.save()
                return
        with open(filename, "rb") as f:
            file_id = self.fs.put(f, filename=filename)
        md5hash = self.fs.get(file_id).md5
        self.run_entry["resources"].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        """Upload an artifact file to GridFS and reference it on the run."""
        with open(filename, "rb") as f:
            run_id = self.run_entry["_id"]
            # Namespaced GridFS filename so artifacts of different runs and
            # collections cannot collide.
            db_filename = "artifact://{}/{}/{}".format(self.runs.name, run_id, name)
            if content_type is None:
                content_type = self._try_to_detect_content_type(filename)

            file_id = self.fs.put(
                f, filename=db_filename, metadata=metadata, content_type=content_type
            )

        self.run_entry["artifacts"].append({"name": name, "file_id": file_id})
        self.save()

    @staticmethod
    def _try_to_detect_content_type(filename):
        """Guess a MIME type from the bundled table; may return None."""
        mime_type, _ = mimetype_detector.guess_type(filename)
        if mime_type is not None:
            print(
                "Added {} as content-type of artifact {}.".format(mime_type, filename)
            )
        else:
            print(
                "Failed to detect content-type automatically for "
                "artifact {}.".format(filename)
            )
        return mime_type

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save anything there.
            return
        for key in metrics_by_name:
            query = {"run_id": self.run_entry["_id"], "name": key}
            # $each appends the whole batch of steps/values/timestamps in one
            # update instead of one round-trip per measurement.
            push = {
                "steps": {"$each": metrics_by_name[key]["steps"]},
                "values": {"$each": metrics_by_name[key]["values"]},
                "timestamps": {"$each": metrics_by_name[key]["timestamps"]},
            }
            update = {"$push": push}
            result = self.metrics.update_one(query, update, upsert=True)
            if result.upserted_id is not None:
                # This is the first time we are storing this metric
                info.setdefault("metrics", []).append(
                    {"name": key, "id": str(result.upserted_id)}
                )

    def insert(self):
        """Insert the run entry, auto-incrementing _id when none was given."""
        import pymongo.errors

        if self.overwrite:
            return self.save()

        autoinc_key = self.run_entry.get("_id") is None
        while True:
            if autoinc_key:
                # Derive the next _id from the current maximum; a concurrent
                # insert may win the race, which surfaces as a
                # DuplicateKeyError below and triggers another loop pass.
                c = self.runs.find({}, {"_id": 1})
                c = c.sort("_id", pymongo.DESCENDING).limit(1)
                self.run_entry["_id"] = (
                    c.next()["_id"] + 1 if self.runs.count_documents({}, limit=1) else 1
                )
            try:
                self.runs.insert_one(self.run_entry)
                return
            except pymongo.errors.InvalidDocument as e:
                raise ObserverError(
                    "Run contained an unserializable entry."
                    "(most likely in the info)\n{}".format(e)
                )
            except pymongo.errors.DuplicateKeyError:
                if not autoinc_key:
                    raise

    def save(self):
        """Best-effort update of the run document.

        Connection drops are silently skipped (the next save retries);
        unserializable entries raise ObserverError.
        """
        import pymongo.errors

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}
            )
        except pymongo.errors.AutoReconnect:
            pass  # just wait for the next save
        except pymongo.errors.InvalidDocument:
            raise ObserverError(
                "Run contained an unserializable entry." "(most likely in the info)"
            )

    def final_save(self, attempts):
        """Persist the finished run, retrying up to *attempts* times.

        When every attempt fails, the run entry is pickled into
        ``failure_dir`` so no data is silently lost.
        """
        import pymongo.errors

        for i in range(attempts):
            try:
                self.runs.update_one(
                    {"_id": self.run_entry["_id"]},
                    {"$set": self.run_entry},
                    upsert=True,
                )
                return
            except pymongo.errors.AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except pymongo.errors.ConnectionFailure:
                pass
            except pymongo.errors.InvalidDocument:
                # Make the entry BSON-safe and retry on the next iteration.
                self.run_entry = force_bson_encodeable(self.run_entry)
                print(
                    "Warning: Some of the entries of the run were not "
                    "BSON-serializable!\n They have been altered such that "
                    "they can be stored, but you should fix your experiment!"
                    "Most likely it is either the 'info' or the 'result'.",
                    file=sys.stderr,
                )

        # All attempts failed: dump the entry to disk as a last resort.
        os.makedirs(self.failure_dir, exist_ok=True)
        with NamedTemporaryFile(
            suffix=".pickle",
            delete=False,
            prefix="sacred_mongo_fail_{}_".format(self.run_entry["_id"]),
            dir=self.failure_dir,
        ) as f:
            pickle.dump(self.run_entry, f)
            print(
                "Warning: saving to MongoDB failed! "
                "Stored experiment entry in '{}'".format(f.name),
                file=sys.stderr,
            )

    def save_sources(self, ex_info):
        """Upload source files missing from GridFS; return [name, id] pairs."""
        base_dir = ex_info["base_dir"]
        source_info = []
        for source_name, md5 in ex_info["sources"]:
            abs_path = os.path.join(base_dir, source_name)
            # Re-use an already uploaded copy with the same content hash.
            file = self.fs.find_one({"filename": abs_path, "md5": md5})
            if file:
                _id = file._id
            else:
                with open(abs_path, "rb") as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        # Two observers are equal iff they write to the same runs collection.
        # NOTE(review): no __hash__ is defined, so overriding __eq__ makes
        # instances unhashable -- confirm nothing puts observers in sets.
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False
@cli_option("-m", "--mongo_db")
def mongo_db_option(args, run):
    """Add a MongoDB Observer to the experiment.

    The argument value is the database specification.
    Should be in the form:

    `[host:port:]db_name[.collection[:id]][!priority]`
    """
    # Translate the CLI string into MongoObserver keyword arguments and
    # attach the resulting observer to the run.
    kwargs = parse_mongo_db_arg(args)
    mongo = MongoObserver(**kwargs)
    run.observers.append(mongo)
def get_pattern():
    """Return the regex for parsing the ``-m``/``--mongo_db`` CLI value.

    The pattern is an alternation of a bare ``host:port`` form and the full
    ``[host:port:]db_name[.collection[:id]][!priority]`` form, with every
    component exposed as a named capture group.
    """

    def named(group, body):
        # Wrap *body* in a named capture group.
        return r"(?P<{}>{})".format(group, body)

    # One DNS label: alphanumeric ends, up to 63 chars, hyphens inside.
    label = r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"
    # A hostname is one or more dot-separated labels.
    hostname_body = label + r"(?:\." + label + r")*"
    # MongoDB database/collection identifier (leading letter/underscore).
    identifier_body = r"[_A-Za-z]" + r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63}"

    host_only = r"^(?:{host}:{port})$".format(
        host=named("host1", hostname_body), port=named("port1", r"\d{1,5}")
    )
    full = (
        r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?"
        r"(?:!{priority})?$".format(
            host=named("host2", hostname_body),
            port=named("port2", r"\d{1,5}"),
            db=named("db_name", identifier_body),
            collection=named("collection", identifier_body),
            rid=named("overwrite", r"\d{1,12}"),
            priority=named("priority", r"-?\d+") + "?",
        )
    )

    return r"{host_only}|{full}".format(host_only=host_only, full=full)
def parse_mongo_db_arg(mongo_db):
    """Parse a ``--mongo_db`` specification into MongoObserver kwargs.

    Parameters
    ----------
    mongo_db
        String of the form
        ``[host:port:]db_name[.collection[:id]][!priority]``.

    Returns
    -------
    dict
        Keyword arguments (``url``, ``db_name``, ``collection``,
        ``overwrite``, ``priority``) suitable for ``MongoObserver(**kwargs)``.

    Raises
    ------
    ValueError
        If *mongo_db* does not match the expected format.
    """
    match = re.match(get_pattern(), mongo_db)
    # Bug fix: re.match returns None (not a dict) on failure, so the old
    # code's ``re.match(...).groupdict()`` raised AttributeError before its
    # ``if g is None`` check could ever fire. Check the match object itself.
    if match is None:
        raise ValueError(
            'mongo_db argument must have the form "db_name" '
            'or "host:port[:db_name]" but was {}'.format(mongo_db)
        )
    g = match.groupdict()

    kwargs = {}
    # The host/port groups come in two flavours depending on which
    # alternative of the pattern matched (host-only vs. full form).
    if g["host1"]:
        kwargs["url"] = "{}:{}".format(g["host1"], g["port1"])
    elif g["host2"]:
        kwargs["url"] = "{}:{}".format(g["host2"], g["port2"])

    if g["priority"] is not None:
        kwargs["priority"] = int(g["priority"])

    for p in ["db_name", "collection", "overwrite"]:
        if g[p] is not None:
            kwargs[p] = g[p]

    return kwargs
class QueueCompatibleMongoObserver(MongoObserver):
    """MongoObserver variant meant to run behind a QueueObserver.

    Instead of retrying locally, failures are raised as ObserverError so the
    queue wrapper can re-schedule the event.
    """

    def log_metrics(self, metric_name, metrics_values, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save anything there.
            return
        query = {"run_id": self.run_entry["_id"], "name": metric_name}
        # $each appends the whole batch in a single update operation.
        push = {
            "steps": {"$each": metrics_values["steps"]},
            "values": {"$each": metrics_values["values"]},
            "timestamps": {"$each": metrics_values["timestamps"]},
        }
        update = {"$push": push}
        result = self.metrics.update_one(query, update, upsert=True)
        if result.upserted_id is not None:
            # This is the first time we are storing this metric
            info.setdefault("metrics", []).append(
                {"name": metric_name, "id": str(result.upserted_id)}
            )

    def save(self):
        """Update the run document; raise ObserverError if unserializable."""
        import pymongo

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}
            )
        except pymongo.errors.InvalidDocument as exc:
            raise ObserverError(
                "Run contained an unserializable entry. (most likely in the info)"
            ) from exc

    def final_save(self, attempts):
        """Upsert the finished run once; on BSON errors dump a pickle and raise.

        NOTE(review): unlike MongoObserver.final_save, *attempts* is not used
        here -- presumably retrying is delegated to the owning QueueObserver;
        confirm before relying on the parameter.
        """
        import pymongo

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}, upsert=True
            )
            return
        except pymongo.errors.InvalidDocument:
            # Sanitize the entry, warn, and fall through to the pickle dump.
            self.run_entry = force_bson_encodeable(self.run_entry)
            print(
                "Warning: Some of the entries of the run were not "
                "BSON-serializable!\n They have been altered such that "
                "they can be stored, but you should fix your experiment!"
                "Most likely it is either the 'info' or the 'result'.",
                file=sys.stderr,
            )

            with NamedTemporaryFile(
                suffix=".pickle", delete=False, prefix="sacred_mongo_fail_"
            ) as f:
                pickle.dump(self.run_entry, f)
                print(
                    "Warning: saving to MongoDB failed! "
                    "Stored experiment entry in '{}'".format(f.name),
                    file=sys.stderr,
                )

            raise ObserverError("Warning: saving to MongoDB failed!")
class QueuedMongoObserver(QueueObserver):
    """MongoObserver that uses a fault-tolerant background process."""

    @classmethod
    def create(cls, *args, **kwargs):
        """Deprecated factory; call ``QueuedMongoObserver(...)`` instead."""
        warnings.warn(
            "QueuedMongoObserver.create(...) is deprecated. "
            "Please use QueuedMongoObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(*args, **kwargs)

    def __init__(
        self,
        interval: float = 20.0,
        retry_interval: float = 10.0,
        url: Optional[str] = None,
        db_name: str = "sacred",
        collection: str = "runs",
        overwrite: Optional[Union[int, str]] = None,
        priority: int = DEFAULT_MONGO_PRIORITY,
        client: Optional["pymongo.MongoClient"] = None,
        **kwargs,
    ):
        """Initializer for MongoObserver.

        Parameters
        ----------
        interval
            The interval in seconds at which the background thread is woken up to
            process new events.
        retry_interval
            The interval in seconds to wait if an event failed to be processed.
        url
            Mongo URI to connect to.
        db_name
            Database to connect to.
        collection
            Collection to write the runs to. (default: "runs").
        overwrite
            _id of a run that should be overwritten.
        priority
            (default 30)
        client
            Client to connect to. Do not use client and URL together.
        failure_dir
            Directory to save the run of a failed observer to.
            (NOTE(review): only honored if passed via **kwargs -- it is not
            an explicit parameter here; verify callers.)
        """
        # Wrap a queue-compatible observer so that failing events are
        # retried by the QueueObserver's background machinery.
        super().__init__(
            QueueCompatibleMongoObserver(
                url=url,
                db_name=db_name,
                collection=collection,
                overwrite=overwrite,
                priority=priority,
                client=client,
                **kwargs,
            ),
            interval=interval,
            retry_interval=retry_interval,
        )