# sacred/observers/s3_observer.py

import json
import os
import os.path
import re
import socket

from sacred.commandline_options import cli_option
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.serializer import flatten

DEFAULT_S3_PRIORITY = 20


def _is_valid_bucket(bucket_name):
    # See https://docs.aws.amazon.com/awscloudtrail/latest/userguide/
    # cloudtrail-s3-bucket-naming-requirements.html
    if len(bucket_name) < 3 or len(bucket_name) > 63:
        return False

    # A bucket name consists of "labels" separated by periods
    labels = bucket_name.split(".")
    for label in labels:
        if len(label) == 0 or label[0] == "-" or label[-1] == "-":
            # Labels must be of nonzero length,
            # and cannot begin or end with a hyphen
            return False
        for char in label:
            # Labels can only contain digits, lowercase letters, or hyphens.
            # Anything else fails here
            if not (char.isdigit() or char.islower() or char == "-"):
                return False

    try:
        socket.inet_aton(bucket_name)
        # A name that parses as a valid IP address cannot be a bucket name
        return False
    except socket.error:
        return True

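
# A quick illustration (not part of the original module) of how
# _is_valid_bucket behaves under the AWS rules above:
#
#   _is_valid_bucket("my-experiments")  -> True
#   _is_valid_bucket("Bad_Bucket")      -> False  (uppercase and underscore)
#   _is_valid_bucket("ab")              -> False  (shorter than 3 characters)
#   _is_valid_bucket("192.168.0.1")     -> False  (parses as an IP address)
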

def s3_join(*args):
    """Join path components with "/", the separator used in S3 keys."""
    return "/".join(args)


class S3Observer(RunObserver):
    VERSION = "S3Observer-0.1.0"

    def __init__(
        self,
        bucket,
        basedir,
        resource_dir=None,
        source_dir=None,
        priority=DEFAULT_S3_PRIORITY,
        region=None,
    ):
        """Constructor for an S3Observer object.

        Run when the object is first created,
        before it's used within an experiment.

        Parameters
        ----------
        bucket
            The name of the bucket you want to store results in.
            Doesn't need to contain `s3://`, but must be a valid bucket name.
        basedir
            The relative path inside your bucket where you want this
            experiment to store results.
        resource_dir
            Where to store resources for this experiment. By
            default, will be <basedir>/_resources.
        source_dir
            Where to store code sources for this experiment. By
            default, will be <basedir>/_sources.
        priority
            The priority to assign to this observer if
            multiple observers are present.
        region
            The AWS region in which you want to create and access
            buckets. Needs to be either set here or configured in your
            AWS config file.
        """
        import boto3

        if not _is_valid_bucket(bucket):
            raise ValueError(
                "Your chosen bucket name doesn't follow AWS bucket naming rules"
            )
        resource_dir = resource_dir or "/".join([basedir, "_resources"])
        source_dir = source_dir or "/".join([basedir, "_sources"])

        self.basedir = basedir
        self.bucket = bucket
        # Keeping the convention of referring to locations in S3 as `dir`
        # because that is a useful mental model and there isn't a better word
        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.priority = priority
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""
        self.cout_write_cursor = 0
        self.saved_metrics = {}
        if region is not None:
            self.region = region
            self.s3 = boto3.resource("s3", region_name=region)
        else:
            session = boto3.session.Session()
            if session.region_name is not None:
                self.region = session.region_name
                self.s3 = boto3.resource("s3")
            else:
                raise ValueError(
                    "You must either pass in an AWS region name, or have a "
                    "region name specified in your AWS config file"
                )

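    # Illustration (not from the original module): a typical way to attach
    # this observer to an experiment. The experiment name, bucket name,
    # basedir, and region below are hypothetical.
    #
    #   from sacred import Experiment
    #
    #   ex = Experiment("my_experiment")
    #   ex.observers.append(
    #       S3Observer(bucket="my-results-bucket",
    #                  basedir="runs",
    #                  region="us-west-2")
    #   )
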
    def _objects_exist_in_dir(self, prefix):
        # This should be run after you've confirmed the bucket exists;
        # it will error out if the bucket does not exist
        bucket = self.s3.Bucket(self.bucket)
        all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)]
        return len(all_keys) > 0

    def _bucket_exists(self):
        from botocore.errorfactory import ClientError

        try:
            self.s3.meta.client.head_bucket(Bucket=self.bucket)
        except ClientError as er:
            if er.response["Error"]["Code"] == "404":
                return False
        return True

    def _list_s3_subdirs(self, prefix=None):
        if prefix is None:
            prefix = self.basedir
        bucket = self.s3.Bucket(self.bucket)
        all_keys = [obj.key for obj in bucket.objects.filter(Prefix=prefix)]
        subdir_match = r"{prefix}\/(.*)\/".format(prefix=prefix)
        subdirs = []
        for key in all_keys:
            match_obj = re.match(subdir_match, key)
            if match_obj is not None:
                subdirs.append(match_obj.groups()[0])
        return list(set(subdirs))

    def _create_bucket(self):
        bucket_response = self.s3.create_bucket(
            Bucket=self.bucket,
            CreateBucketConfiguration={"LocationConstraint": self.region},
        )
        return bucket_response

    def _determine_run_dir(self, _id):
        if _id is None:
            bucket_exists = self._bucket_exists()
            if not bucket_exists:
                self._create_bucket()
                bucket_path_subdirs = []
            else:
                bucket_path_subdirs = self._list_s3_subdirs()

            if not bucket_path_subdirs:
                max_run_id = 0
            else:
                # If there are directories under basedir that aren't
                # numeric run directories, ignore those
                integer_directories = [
                    int(d) for d in bucket_path_subdirs if d.isdigit()
                ]
                if not integer_directories:
                    max_run_id = 0
                else:
                    max_run_id = max(integer_directories)

            _id = max_run_id + 1

        self.dir = s3_join(self.basedir, str(_id))
        if self._objects_exist_in_dir(self.dir):
            raise FileExistsError("S3 dir at {} already exists".format(self.dir))
        return _id

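    # Illustration (assumption, not in the source): with basedir "runs" and
    # existing keys "runs/1/run.json" and "runs/2/run.json", the subdirs
    # found are {"1", "2"}, so the next run is assigned _id = 3 and stored
    # under "runs/3/".
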
    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "meta": meta_info,
            "status": "QUEUED",
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")

        for source, _ in ex_info["sources"]:
            self.save_file(source)

        return _id

    def save_sources(self, ex_info):
        base_dir = ex_info["base_dir"]
        source_info = []
        for source, _ in ex_info["sources"]:
            abspath = os.path.join(base_dir, source)
            store_path, md5sum = self.find_or_save(abspath, self.source_dir)
            source_info.append([source, os.path.relpath(store_path, self.basedir)])
        return source_info

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        ex_info["sources"] = self.save_sources(ex_info)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "start_time": start_time.isoformat(),
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "heartbeat": None,
        }
        self.config = config
        self.info = {}
        self.cout = ""
        self.cout_write_cursor = 0

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")
        self.save_cout()

        return _id

    def find_or_save(self, filename, store_dir):
        source_name, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        store_name = source_name + "_" + md5sum + ext
        store_path = s3_join(store_dir, store_name)
        # Check for the object itself rather than for subdirectories beneath
        # it, so an already-uploaded file is found and not re-uploaded
        if not self._objects_exist_in_dir(store_path):
            self.save_file(filename, store_path)
        return store_path, md5sum

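    # Illustration (hypothetical names): find_or_save("train.py",
    # "runs/_sources") computes the file's MD5 digest, say "9a0364...", and
    # stores the file as "runs/_sources/train_9a0364....py"; identical files
    # map to the same key and are not uploaded again.
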
    def put_data(self, key, binary_data):
        self.s3.Object(self.bucket, key).put(Body=binary_data)

    def save_json(self, obj, filename):
        key = s3_join(self.dir, filename)
        self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2))

    def save_file(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        key = s3_join(self.dir, target_name)
        # Use a context manager so the file handle is closed after the upload
        with open(filename, "rb") as f:
            self.put_data(key, f)

    def save_directory(self, source_dir, target_name):
        # Adapted from:
        # https://github.com/boto/boto3/issues/358#issuecomment-346093506
        target_name = target_name or os.path.basename(source_dir)
        all_files = []
        for root, dirs, files in os.walk(source_dir):
            all_files += [os.path.join(root, f) for f in files]

        for filename in all_files:
            file_location = s3_join(
                self.dir, target_name, os.path.relpath(filename, source_dir)
            )
            # Reuse self.s3 so the region configured in __init__ applies,
            # and close each file handle after its upload
            with open(filename, "rb") as f:
                self.s3.Object(self.bucket, file_location).put(Body=f)

    def save_cout(self):
        # S3 objects can't be appended to, so each put must contain the full
        # captured output rather than only the part written since the last
        # upload; otherwise cout.txt would hold just the newest chunk
        binary_data = self.cout.encode("utf-8")
        key = s3_join(self.dir, "cout.txt")
        self.put_data(key, binary_data)
        self.cout_write_cursor = len(self.cout)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, "run.json")
        if self.info:
            self.save_json(self.info, "info.json")

    def completed_event(self, stop_time, result):
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"
        self.save_json(self.run_entry, "run.json")

    def interrupted_event(self, interrupt_time, status):
        self.run_entry["stop_time"] = interrupt_time.isoformat()
        self.run_entry["status"] = status
        self.save_json(self.run_entry, "run.json")

    def failed_event(self, fail_time, fail_trace):
        self.run_entry["stop_time"] = fail_time.isoformat()
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save_json(self.run_entry, "run.json")

    def resource_event(self, filename):
        store_path, md5sum = self.find_or_save(filename, self.resource_dir)
        self.run_entry["resources"].append([filename, store_path])
        self.save_json(self.run_entry, "run.json")

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        self.save_file(filename, name)
        self.run_entry["artifacts"].append(name)
        self.save_json(self.run_entry, "run.json")

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json."""
        for metric_name, metric_ptr in metrics_by_name.items():
            if metric_name not in self.saved_metrics:
                self.saved_metrics[metric_name] = {
                    "values": [],
                    "steps": [],
                    "timestamps": [],
                }

            self.saved_metrics[metric_name]["values"] += metric_ptr["values"]
            self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]]
            self.saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(self.saved_metrics, "metrics.json")

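    # Illustration (assumption): after logging the metric "loss" at steps 0
    # and 1, metrics.json would contain something like
    #   {"loss": {"steps": [0, 1],
    #             "timestamps": ["2024-01-01T00:00:00", "2024-01-01T00:00:10"],
    #             "values": [0.9, 0.7]}}
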
    def __eq__(self, other):
        if isinstance(other, S3Observer):
            return self.bucket == other.bucket and self.basedir == other.basedir
        return False


@cli_option("-S", "--s3")
def s3_option(args, run):
    """Add an S3 file observer to the experiment.

    The argument value should be `s3://<bucket>/path/to/exp`.
    """
    match_obj = re.match(r"s3:\/\/([^\/]*)\/(.*)", args)
    if match_obj is None or len(match_obj.groups()) != 2:
        raise ValueError(
            "Valid bucket specification not found. "
            "Enter bucket and directory path like: "
            "s3://<bucket>/path/to/exp"
        )
    bucket, basedir = match_obj.groups()
    run.observers.append(S3Observer(bucket=bucket, basedir=basedir))
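
# Illustration (hypothetical invocation): running an experiment with
#   python my_experiment.py -S s3://my-results-bucket/runs
# parses the flag value into bucket="my-results-bucket" and basedir="runs",
# then attaches an S3Observer for that location.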