Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/file_utils.py: 49%
1"""
2Utilities for working with the local dataset cache. Copied from AllenNLP
3"""
4from pathlib import Path
5from typing import Tuple, Union, Optional, Sequence, cast
6import os
7import base64
8import logging
9import shutil
10import tempfile
11import re
12import functools
13from urllib.parse import urlparse
15import mmap
16import requests
17import zipfile
18import io
20# from allennlp.common.tqdm import Tqdm
21import flair
23logger = logging.getLogger("flair")


def load_big_file(f: str) -> mmap.mmap:
    """
    Workaround for loading a big pickle file. Files over 2GB cause pickle errors on certain Mac and Windows distributions.
    :param f: path to the file to load
    :return: read-only memory map of the file contents
    """
    logger.info(f"loading file {f}")
    with open(f, "rb") as f_in:
        # mmap seems to be much more memory efficient, and stays valid after the file handle is closed
        bf = mmap.mmap(f_in.fileno(), 0, access=mmap.ACCESS_READ)
    return bf
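
# Illustrative usage sketch ("model.pickle" is a hypothetical path): the
# returned mmap is bytes-like, so it can be handed to pickle without first
# reading the whole file into memory.
#
#   import pickle
#   buf = load_big_file("model.pickle")
#   obj = pickle.loads(buf)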


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Converts a url into a filename in a reversible way.
    If `etag` is specified, add it on the end, separated by a period
    (which necessarily won't appear in the base64-encoded filename).
    Get rid of the quotes in the etag, since Windows doesn't like them.
    """
    url_bytes = url.encode("utf-8")
    b64_bytes = base64.b64encode(url_bytes)
    decoded = b64_bytes.decode("utf-8")

    if etag:
        # Remove quotes from etag
        etag = etag.replace('"', "")
        return f"{decoded}.{etag}"
    else:
        return decoded


def filename_to_url(filename: str) -> Tuple[str, Optional[str]]:
    """
    Recovers the url from the encoded filename. Returns it and the ETag
    (which may be ``None``).
    """
    try:
        # If there is an etag, it's everything after the first period
        decoded, etag = filename.split(".", 1)
    except ValueError:
        # Otherwise, use None
        decoded, etag = filename, None

    filename_bytes = decoded.encode("utf-8")
    url_bytes = base64.b64decode(filename_bytes)
    return url_bytes.decode("utf-8"), etag
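
# Illustrative round trip (hypothetical URL and ETag): the period separator is
# safe because standard base64 output never contains one.
#
#   name = url_to_filename("https://example.com/model.bin", etag='"abc123"')
#   url, etag = filename_to_url(name)
#   # -> ("https://example.com/model.bin", "abc123")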


def cached_path(url_or_filename: str, cache_dir: Union[str, Path]) -> Path:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if isinstance(cache_dir, str):
        cache_dir = Path(cache_dir)
    dataset_cache = flair.cache_root / cache_dir

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ("http", "https"):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, dataset_cache)
    elif parsed.scheme == "" and Path(url_or_filename).exists():
        # File, and it exists.
        return Path(url_or_filename)
    elif parsed.scheme == "":
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError(
            "unable to parse {} as a URL or as a local path".format(url_or_filename)
        )
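
# Illustrative usage sketch (hypothetical URL): the first call downloads the
# file into flair.cache_root / "datasets"; later calls return the cached path.
#
#   path = cached_path("https://example.com/corpus.zip", cache_dir="datasets")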


def unzip_file(file: Union[str, Path], unzip_to: Union[str, Path]):
    from zipfile import ZipFile

    with ZipFile(Path(file), "r") as zipObj:
        # Extract all the contents of the zip file to the target directory
        zipObj.extractall(Path(unzip_to))


def unpack_file(file: Path, unpack_to: Path, mode: Optional[str] = None, keep: bool = True):
    """
    Unpacks a file to the given location.

    :param file: Archive file to unpack
    :param unpack_to: Destination where to store the output
    :param mode: Type of the archive (zip, tar, gz, targz, rar)
    :param keep: Indicates whether to keep the archive after extraction or delete it
    """
    if mode == "zip" or (mode is None and str(file).endswith("zip")):
        from zipfile import ZipFile

        with ZipFile(file, "r") as zipObj:
            # Extract all the contents of the zip file to the target directory
            zipObj.extractall(unpack_to)

    elif mode == "targz" or (
        mode is None and (str(file).endswith("tar.gz") or str(file).endswith("tgz"))
    ):
        import tarfile

        with tarfile.open(file, "r:gz") as tarObj:
            tarObj.extractall(unpack_to)

    elif mode == "tar" or (mode is None and str(file).endswith("tar")):
        import tarfile

        with tarfile.open(file, "r") as tarObj:
            tarObj.extractall(unpack_to)

    elif mode == "gz" or (mode is None and str(file).endswith("gz")):
        import gzip

        with gzip.open(str(file), "rb") as f_in:
            with open(str(unpack_to), "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)

    elif mode == "rar" or (mode is None and str(file).endswith("rar")):
        import patoolib

        patoolib.extract_archive(str(file), outdir=str(unpack_to), interactive=False)

    else:
        if mode is None:
            raise AssertionError(f"Can't infer archive type from {file}")
        else:
            raise AssertionError(f"Unsupported mode {mode}")

    if not keep:
        os.remove(str(file))
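
# Illustrative usage sketch (hypothetical archive path): with mode=None the
# archive type is inferred from the file suffix, here "tar.gz".
#
#   unpack_file(Path("corpus.tar.gz"), Path("data/corpus"), keep=False)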


def download_file(url: str, cache_dir: Union[str, Path]):
    if isinstance(cache_dir, str):
        cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = re.sub(r".+/", "", url)
    # get cache path to put the file
    cache_path = cache_dir / filename
    logger.info(str(cache_path))

    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    fd, temp_filename = tempfile.mkstemp()
    logger.info("%s not found in cache, downloading to %s", url, temp_filename)

    # GET file object
    req = requests.get(url, stream=True)
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = Tqdm.tqdm(unit="B", total=total)
    with open(temp_filename, "wb") as temp_file:
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)

    progress.close()

    logger.info("copying %s to cache at %s", temp_filename, cache_path)
    shutil.copyfile(temp_filename, str(cache_path))
    logger.info("removing temp file %s", temp_filename)
    os.close(fd)
    os.remove(temp_filename)


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: Path) -> Path:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = re.sub(r".+/", "", url)
    # get cache path to put the file
    cache_path = cache_dir / filename
    if cache_path.exists():
        return cache_path

    # make HEAD request to check ETag
    response = requests.head(url, headers={"User-Agent": "Flair"}, allow_redirects=True)
    if response.status_code != 200:
        raise IOError(
            f"HEAD request failed for url {url} with status code {response.status_code}."
        )

    # add ETag to filename if it exists
    # etag = response.headers.get("ETag")

    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    fd, temp_filename = tempfile.mkstemp()
    logger.info("%s not found in cache, downloading to %s", url, temp_filename)

    # GET file object
    req = requests.get(url, stream=True, headers={"User-Agent": "Flair"})
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = Tqdm.tqdm(unit="B", total=total)
    with open(temp_filename, "wb") as temp_file:
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)

    progress.close()

    logger.info("copying %s to cache at %s", temp_filename, cache_path)
    shutil.copyfile(temp_filename, str(cache_path))
    logger.info("removing temp file %s", temp_filename)
    os.close(fd)
    os.remove(temp_filename)

    return cache_path


def open_inside_zip(
    archive_path: str,
    cache_dir: Union[str, Path],
    member_path: Optional[str] = None,
    encoding: str = "utf8",
) -> io.TextIOWrapper:
    cached_archive_path = cached_path(archive_path, cache_dir=Path(cache_dir))
    archive = zipfile.ZipFile(cached_archive_path, "r")
    if member_path is None:
        members_list = archive.namelist()
        member_path = get_the_only_file_in_the_archive(members_list, archive_path)
    member_path = cast(str, member_path)
    member_file = archive.open(member_path, "r")
    return io.TextIOWrapper(member_file, encoding=encoding)
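
# Illustrative usage sketch (hypothetical archive URL and member name):
# reads a single text file out of a cached zip without unpacking it.
#
#   with open_inside_zip(
#       "https://example.com/embeddings.zip", "embeddings", member_path="emb.txt"
#   ) as f:
#       first_line = next(f)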


def get_the_only_file_in_the_archive(
    members_list: Sequence[str], archive_path: str
) -> str:
    if len(members_list) > 1:
        raise ValueError(
            "The archive %s contains multiple files, so you must select "
            "one of the files inside by providing a URI of the form: %s"
            % (
                archive_path,
                format_embeddings_file_uri(
                    "path_or_url_to_archive", "path_inside_archive"
                ),
            )
        )
    return members_list[0]


def format_embeddings_file_uri(
    main_file_path_or_url: str, path_inside_archive: Optional[str] = None
) -> str:
    if path_inside_archive:
        return "({})#{}".format(main_file_path_or_url, path_inside_archive)
    return main_file_path_or_url
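
# Illustrative sketch of the URI scheme this helper produces (and that
# get_the_only_file_in_the_archive's error message refers to); the paths are
# hypothetical:
#
#   format_embeddings_file_uri("https://example.com/emb.zip", "emb.txt")
#   # -> "(https://example.com/emb.zip)#emb.txt"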


from tqdm import tqdm as _tqdm


class Tqdm:
    # These defaults are the same as the argument defaults in tqdm.
    default_mininterval: float = 0.1

    @staticmethod
    def set_default_mininterval(value: float) -> None:
        Tqdm.default_mininterval = value

    @staticmethod
    def set_slower_interval(use_slower_interval: bool) -> None:
        """
        If ``use_slower_interval`` is ``True``, we will dramatically slow down ``tqdm``'s default
        output rate. ``tqdm``'s default output rate is great for interactively watching progress,
        but it is not great for log files. You might want to set this if you are primarily going
        to be looking at output through log files, not the terminal.
        """
        if use_slower_interval:
            Tqdm.default_mininterval = 10.0
        else:
            Tqdm.default_mininterval = 0.1

    @staticmethod
    def tqdm(*args, **kwargs):
        new_kwargs = {"mininterval": Tqdm.default_mininterval, **kwargs}

        return _tqdm(*args, **new_kwargs)
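
# Illustrative usage sketch: when output goes to a log file rather than a
# terminal, slow the refresh rate so the log is not flooded with updates.
#
#   Tqdm.set_slower_interval(True)  # mininterval becomes 10 seconds
#   for _ in Tqdm.tqdm(range(1_000_000)):
#       pass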


def instance_lru_cache(*cache_args, **cache_kwargs):
    """Like functools.lru_cache, but with one cache per instance instead of one
    cache shared across all instances of the class."""
    def decorator(func):
        @functools.wraps(func)
        def create_cache(self, *args, **kwargs):
            # On first call, build an lru_cache-wrapped version of the function,
            # bind it to this instance, and install it on the instance so that
            # subsequent calls bypass this wrapper entirely.
            instance_cache = functools.lru_cache(*cache_args, **cache_kwargs)(func)
            instance_cache = instance_cache.__get__(self, self.__class__)
            setattr(self, func.__name__, instance_cache)
            return instance_cache(*args, **kwargs)

        return create_cache

    return decorator
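
# Illustrative usage sketch (Embedder and expensive_lookup are hypothetical):
# each instance gets its own LRU cache, so hits and evictions on one object
# do not affect another, and the cache dies with the instance.
#
#   class Embedder:
#       @instance_lru_cache(maxsize=128)
#       def embed(self, token: str):
#           return expensive_lookup(token)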