Coverage for sacred/sacred/observers/gcs_observer.py: 20%
import json
import os
import os.path
import re
from typing import Optional

from sacred.commandline_options import cli_option
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.serializer import flatten
from sacred.utils import PathType

DEFAULT_GCS_PRIORITY = 20


def _is_valid_bucket(bucket_name: str):
    """Validates correctness of bucket naming.

    Reference: https://cloud.google.com/storage/docs/naming
    """
    if bucket_name.startswith("gs://"):
        return False

    if len(bucket_name) < 3 or len(bucket_name) > 63:
        return False

    # IP address
    if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", bucket_name):
        return False

    if not re.fullmatch(r"([^A-Z]|-|_|[.]|)+", bucket_name):
        return False

    if ".." in bucket_name:
        return False

    if "goog" in bucket_name or "g00g" in bucket_name:
        return False

    return True
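
# A minimal sketch (not part of the gcs_observer.py source) of how the
# validator above treats a few sample names; the bucket names are made-up
# placeholders:
#
#     _is_valid_bucket("my-experiment-results")  # True
#     _is_valid_bucket("gs://my-bucket")         # False: 'gs://' prefix is rejected
#     _is_valid_bucket("192.168.5.4")            # False: bare IP addresses are rejected
#     _is_valid_bucket("ab")                     # False: shorter than 3 characters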

def gcs_join(*args):
    return "/".join(args)

class GoogleCloudStorageObserver(RunObserver):
    VERSION = "GoogleCloudStorageObserver-0.1.0"

    def __init__(
        self,
        bucket: str,
        basedir: PathType,
        resource_dir: Optional[PathType] = None,
        source_dir: Optional[PathType] = None,
        priority: Optional[int] = DEFAULT_GCS_PRIORITY,
    ):
        """Constructor for a GoogleCloudStorageObserver object.

        Run when the object is first created,
        before it's used within an experiment.

        Parameters
        ----------
        bucket
            The name of the bucket you want to store results in.
            Needs to be a valid bucket name without 'gs://'
        basedir
            The relative path inside your bucket where you want this
            experiment to store results
        resource_dir
            Where to store resources for this experiment. By
            default, will be <basedir>/_resources
        source_dir
            Where to store code sources for this experiment. By
            default, will be <basedir>/_sources
        priority
            The priority to assign to this observer if
            multiple observers are present
        """
        if not _is_valid_bucket(bucket):
            raise ValueError(
                "Your chosen bucket name doesn't follow Google Cloud Storage bucket naming rules"
            )
        resource_dir = resource_dir or "/".join([basedir, "_resources"])
        source_dir = source_dir or "/".join([basedir, "_sources"])
        self.basedir = basedir
        self.bucket_id = bucket

        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.priority = priority
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""
        self.cout_write_cursor = 0
        self.saved_metrics = {}

        from google.cloud import storage
        import google.auth.exceptions

        try:
            client = storage.Client()
        except google.auth.exceptions.DefaultCredentialsError:
            raise ConnectionError(
                "Could not create Google Cloud Storage observer, are you "
                "sure that you have set environment variable GOOGLE_APPLICATION_CREDENTIALS?"
            )

        self.bucket = client.bucket(bucket)

    def _objects_exist_in_dir(self, prefix):
        # This should be run after you've confirmed the bucket
        # exists, and will error out if it does not exist
        all_blobs = [blob for blob in self.bucket.list_blobs(prefix=prefix)]
        return len(all_blobs) > 0

    def _list_gcs_subdirs(self, prefix=None):
        if prefix is None:
            prefix = self.basedir

        iterator = self.bucket.list_blobs(prefix=prefix, delimiter="/")
        prefixes = set()
        for page in iterator.pages:
            prefixes.update(page.prefixes)

        return list(prefixes)

    def _determine_run_dir(self, _id):
        if _id is None:
            basepath = os.path.join(self.basedir, "")
            bucket_path_subdirs = self._list_gcs_subdirs(prefix=basepath)

            if not bucket_path_subdirs:
                max_run_id = 0
            else:
                relative_paths = [
                    path.replace(self.basedir, "").strip("/")
                    for path in bucket_path_subdirs
                ]
                integer_directories = [int(d) for d in relative_paths if d.isdigit()]
                if not integer_directories:
                    max_run_id = 0
                else:
                    # If there are directories under basedir that aren't
                    # numeric run directories, ignore those
                    max_run_id = max(integer_directories)

            _id = max_run_id + 1

        self.dir = gcs_join(self.basedir, str(_id))
        if self._objects_exist_in_dir(self.dir):
            raise FileExistsError("GCS dir at {} already exists".format(self.dir))
        return _id
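
    # A minimal worked example (not part of the gcs_observer.py source) of the
    # auto-numbering above: with existing prefixes "<basedir>/3/", "<basedir>/7/"
    # and "<basedir>/notes/" in the bucket, the non-numeric entry is ignored,
    # max_run_id becomes 7, the new run gets _id = 8, and self.dir is set to
    # "<basedir>/8".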

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "meta": meta_info,
            "status": "QUEUED",
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")

        for s, m in ex_info["sources"]:
            self.save_file(s)

        return _id

    def save_sources(self, ex_info):
        base_dir = ex_info["base_dir"]
        source_info = []
        for s, m in ex_info["sources"]:
            abspath = os.path.join(base_dir, s)
            store_path, md5sum = self.find_or_save(abspath, self.source_dir)
            source_info.append([s, os.path.relpath(store_path, self.basedir)])
        return source_info

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        ex_info["sources"] = self.save_sources(ex_info)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "start_time": start_time.isoformat(),
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "heartbeat": None,
        }
        self.config = config
        self.info = {}
        self.cout = ""
        self.cout_write_cursor = 0

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")
        self.save_cout()

        return _id

    def find_or_save(self, filename, store_dir):
        source_name, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        store_name = source_name + "_" + md5sum + ext
        store_path = gcs_join(store_dir, store_name)
        if len(self._list_gcs_subdirs(prefix=store_path)) == 0:
            self.save_file_to_base(filename, store_path)
        return store_path, md5sum

    def put_data(self, key, binary_data):
        blob = self.bucket.blob(key)
        blob.upload_from_file(binary_data)

    def save_json(self, obj, filename):
        key = gcs_join(self.dir, filename)
        blob = self.bucket.blob(key)
        blob.upload_from_string(
            json.dumps(flatten(obj), sort_keys=True, indent=2), content_type="text/json"
        )

    def save_file(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        key = gcs_join(self.dir, target_name)
        self.put_data(key, open(filename, "rb"))

    def save_file_to_base(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        self.put_data(target_name, open(filename, "rb"))

    def save_directory(self, source_dir, target_name):
        target_name = target_name or os.path.basename(source_dir)
        all_files = []
        for root, dirs, files in os.walk(source_dir):
            all_files += [os.path.join(root, f) for f in files]

        for filename in all_files:
            file_location = gcs_join(
                self.dir, target_name, os.path.relpath(filename, source_dir)
            )
            self.put_data(file_location, open(filename, "rb"))

    def save_cout(self):
        binary_data = self.cout[self.cout_write_cursor :].encode("utf-8")
        key = gcs_join(self.dir, "cout.txt")
        blob = self.bucket.blob(key)
        blob.upload_from_string(binary_data, content_type="text/plain")
        self.cout_write_cursor = len(self.cout)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, "run.json")
        if self.info:
            self.save_json(self.info, "info.json")

    def completed_event(self, stop_time, result):
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"

        self.save_json(self.run_entry, "run.json")

    def interrupted_event(self, interrupt_time, status):
        self.run_entry["stop_time"] = interrupt_time.isoformat()
        self.run_entry["status"] = status
        self.save_json(self.run_entry, "run.json")

    def failed_event(self, fail_time, fail_trace):
        self.run_entry["stop_time"] = fail_time.isoformat()
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save_json(self.run_entry, "run.json")

    def resource_event(self, filename):
        store_path, md5sum = self.find_or_save(filename, self.resource_dir)
        self.run_entry["resources"].append([filename, store_path])
        self.save_json(self.run_entry, "run.json")

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        self.save_file(filename, name)
        self.run_entry["artifacts"].append(name)
        self.save_json(self.run_entry, "run.json")

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json."""
        for metric_name, metric_ptr in metrics_by_name.items():

            if metric_name not in self.saved_metrics:
                self.saved_metrics[metric_name] = {
                    "values": [],
                    "steps": [],
                    "timestamps": [],
                }

            self.saved_metrics[metric_name]["values"] += metric_ptr["values"]
            self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]]
            self.saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(self.saved_metrics, "metrics.json")
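
    # A minimal sketch (not part of the gcs_observer.py source) of the shape
    # log_metrics() expects for metrics_by_name, inferred from the keys read
    # above; the metric name and values are made-up placeholders:
    #
    #     {"loss": {"values": [0.71, 0.64],
    #               "steps": [10, 20],
    #               "timestamps": [datetime(...), datetime(...)]}}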

    def __eq__(self, other):
        if isinstance(other, GoogleCloudStorageObserver):
            return self.bucket_id == other.bucket_id and self.basedir == other.basedir
        else:
            return False

@cli_option("-G", "--gcs")
def gcs_option(args, run):
    """Add a Google Cloud Storage File observer to the experiment.

    The argument value should be `gs://<bucket>/path/to/exp`.
    """
    match_obj = re.match(r"gs:\/\/([^\/]*)\/(.*)", args)
    if match_obj is None or len(match_obj.groups()) != 2:
        raise ValueError(
            "Valid bucket specification not found. "
            "Enter bucket and directory path like: "
            "gs://<bucket>/path/to/exp"
        )
    bucket, basedir = match_obj.groups()
    run.observers.append(GoogleCloudStorageObserver(bucket=bucket, basedir=basedir))
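
# Illustrative usage sketch (not part of the gcs_observer.py source): attaching
# the observer to a Sacred experiment in code, assuming the package exports the
# class from sacred.observers as usual. The bucket and path are made-up
# placeholders.
#
#     from sacred import Experiment
#     from sacred.observers import GoogleCloudStorageObserver
#
#     ex = Experiment("my_experiment")
#     ex.observers.append(
#         GoogleCloudStorageObserver(bucket="my-results-bucket",
#                                    basedir="experiments/demo")
#     )
#
# The same observer can also be attached from the command line through the
# -G/--gcs option defined above, with an argument of the documented form,
# e.g. something like:
#
#     python my_experiment.py -G gs://my-results-bucket/experiments/demo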