Coverage for flair/flair/file_utils.py: 49%


175 statements  

1""" 

2Utilities for working with the local dataset cache. Copied from AllenNLP 

3""" 

4from pathlib import Path 

5from typing import Tuple, Union, Optional, Sequence, cast 

6import os 

7import base64 

8import logging 

9import shutil 

10import tempfile 

11import re 

12import functools 

13from urllib.parse import urlparse 

14 

15import mmap 

16import requests 

17import zipfile 

18import io 

19 

20# from allennlp.common.tqdm import Tqdm 

21import flair 

22 

23logger = logging.getLogger("flair") 

24 

25 

def load_big_file(f: str) -> mmap.mmap:
    """
    Workaround for loading a big pickle file. Files over 2GB cause pickle errors on certain Mac and Windows distributions.
    :param f: path to the file to load
    :return: read-only memory map of the file contents
    """
    logger.info(f"loading file {f}")
    with open(f, "rb") as f_in:
        # mmap seems to be much more memory efficient
        bf = mmap.mmap(f_in.fileno(), 0, access=mmap.ACCESS_READ)
    return bf

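# Usage sketch (hypothetical path): the returned mmap is a bytes-like object,
# so it can be passed straight to pickle for deserialization.
#
#   import pickle
#   buffer = load_big_file("cache/big-model.pickle")
#   obj = pickle.loads(buffer)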

def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Converts a url into a filename in a reversible way.
    If `etag` is specified, add it on the end, separated by a period
    (which necessarily won't appear in the base64-encoded filename).
    Get rid of the quotes in the etag, since Windows doesn't like them.
    """
    url_bytes = url.encode("utf-8")
    b64_bytes = base64.b64encode(url_bytes)
    decoded = b64_bytes.decode("utf-8")

    if etag:
        # Remove quotes from etag
        etag = etag.replace('"', "")
        return f"{decoded}.{etag}"
    else:
        return decoded


def filename_to_url(filename: str) -> Tuple[str, Optional[str]]:
    """
    Recovers the url from the encoded filename. Returns it and the ETag
    (which may be ``None``).
    """
    try:
        # If there is an etag, it's everything after the first period
        decoded, etag = filename.split(".", 1)
    except ValueError:
        # Otherwise, use None
        decoded, etag = filename, None

    filename_bytes = decoded.encode("utf-8")
    url_bytes = base64.b64decode(filename_bytes)
    return url_bytes.decode("utf-8"), etag

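# Round-trip sketch: the encoding is plain base64 (whose alphabet contains no
# periods), so the URL is always recoverable.
#
#   name = url_to_filename("https://example.com/data.txt", etag='"abc123"')
#   url, etag = filename_to_url(name)
#   assert url == "https://example.com/data.txt" and etag == "abc123"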

def cached_path(url_or_filename: str, cache_dir: Union[str, Path]) -> Path:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if isinstance(cache_dir, str):
        cache_dir = Path(cache_dir)
    dataset_cache = flair.cache_root / cache_dir

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ("http", "https"):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, dataset_cache)
    elif parsed.scheme == "" and Path(url_or_filename).exists():
        # File, and it exists.
        return Path(url_or_filename)
    elif parsed.scheme == "":
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError(
            "unable to parse {} as a URL or as a local path".format(url_or_filename)
        )

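# Usage sketch (hypothetical URL): downloads land under
# flair.cache_root / cache_dir, keyed by the last path component of the URL.
#
#   path = cached_path("https://example.com/corpus.zip", cache_dir=Path("datasets"))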

def unzip_file(file: Union[str, Path], unzip_to: Union[str, Path]):
    from zipfile import ZipFile

    with ZipFile(Path(file), "r") as zipObj:
        # Extract all the contents of the zip file into the target directory
        zipObj.extractall(Path(unzip_to))


def unpack_file(file: Path, unpack_to: Path, mode: Optional[str] = None, keep: bool = True):
    """
    Unpacks a file to the given location.

    :param file: Archive file to unpack
    :param unpack_to: Destination where to store the output
    :param mode: Type of the archive (zip, tar, gz, targz, rar)
    :param keep: Indicates whether to keep the archive after extraction or delete it
    """
    if mode == "zip" or (mode is None and str(file).endswith("zip")):
        from zipfile import ZipFile

        with ZipFile(file, "r") as zipObj:
            # Extract all the contents of the zip file into the target directory
            zipObj.extractall(unpack_to)

    elif mode == "targz" or (
        mode is None and (str(file).endswith("tar.gz") or str(file).endswith("tgz"))
    ):
        import tarfile

        with tarfile.open(file, "r:gz") as tarObj:
            tarObj.extractall(unpack_to)

    elif mode == "tar" or (mode is None and str(file).endswith("tar")):
        import tarfile

        with tarfile.open(file, "r") as tarObj:
            tarObj.extractall(unpack_to)

    elif mode == "gz" or (mode is None and str(file).endswith("gz")):
        import gzip

        with gzip.open(str(file), "rb") as f_in:
            with open(str(unpack_to), "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)

    elif mode == "rar" or (mode is None and str(file).endswith("rar")):
        import patoolib

        patoolib.extract_archive(str(file), outdir=unpack_to, interactive=False)

    else:
        if mode is None:
            raise AssertionError(f"Can't infer archive type from {file}")
        else:
            raise AssertionError(f"Unsupported mode {mode}")

    if not keep:
        os.remove(str(file))

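# Usage sketch (hypothetical paths): the archive type is inferred from the
# file extension when mode is None; pass keep=False to delete the archive
# after extraction.
#
#   unpack_file(Path("data/corpus.tar.gz"), Path("data/corpus"), keep=False)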

def download_file(url: str, cache_dir: Union[str, Path]):
    if isinstance(cache_dir, str):
        cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = re.sub(r".+/", "", url)
    # get cache path to put the file
    cache_path = cache_dir / filename
    print(cache_path)

    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    fd, temp_filename = tempfile.mkstemp()
    logger.info("%s not found in cache, downloading to %s", url, temp_filename)

    # GET file object
    req = requests.get(url, stream=True)
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = Tqdm.tqdm(unit="B", total=total)
    with open(temp_filename, "wb") as temp_file:
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)

    progress.close()

    logger.info("copying %s to cache at %s", temp_filename, cache_path)
    shutil.copyfile(temp_filename, str(cache_path))
    logger.info("removing temp file %s", temp_filename)
    os.close(fd)
    os.remove(temp_filename)


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: Path) -> Path:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    filename = re.sub(r".+/", "", url)
    # get cache path to put the file
    cache_path = cache_dir / filename
    if cache_path.exists():
        return cache_path

    # make HEAD request to check ETag
    response = requests.head(url, headers={"User-Agent": "Flair"}, allow_redirects=True)
    if response.status_code != 200:
        raise IOError(
            f"HEAD request failed for url {url} with status code {response.status_code}."
        )

    # add ETag to filename if it exists
    # etag = response.headers.get("ETag")

    if not cache_path.exists():
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        fd, temp_filename = tempfile.mkstemp()
        logger.info("%s not found in cache, downloading to %s", url, temp_filename)

        # GET file object
        req = requests.get(url, stream=True, headers={"User-Agent": "Flair"})
        content_length = req.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        progress = Tqdm.tqdm(unit="B", total=total)
        with open(temp_filename, "wb") as temp_file:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)

        progress.close()

        logger.info("copying %s to cache at %s", temp_filename, cache_path)
        shutil.copyfile(temp_filename, str(cache_path))
        logger.info("removing temp file %s", temp_filename)
        os.close(fd)
        os.remove(temp_filename)

    return cache_path

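# Note: the cache key is just the final path component of the URL, so two
# different URLs that end in the same filename will collide in the same
# cache_dir. Direct-call sketch (hypothetical URL):
#
#   path = get_from_cache("https://example.com/glove.txt", flair.cache_root / "embeddings")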

def open_inside_zip(
    archive_path: str,
    cache_dir: Union[str, Path],
    member_path: Optional[str] = None,
    encoding: str = "utf8",
) -> io.TextIOWrapper:
    cached_archive_path = cached_path(archive_path, cache_dir=Path(cache_dir))
    archive = zipfile.ZipFile(cached_archive_path, "r")
    if member_path is None:
        members_list = archive.namelist()
        member_path = get_the_only_file_in_the_archive(members_list, archive_path)
    member_path = cast(str, member_path)
    member_file = archive.open(member_path, "r")
    return io.TextIOWrapper(member_file, encoding=encoding)


def get_the_only_file_in_the_archive(
    members_list: Sequence[str], archive_path: str
) -> str:
    if len(members_list) > 1:
        raise ValueError(
            "The archive %s contains multiple files, so you must select "
            "one of the files inside by providing a URI of the form: %s"
            % (
                archive_path,
                format_embeddings_file_uri(
                    "path_or_url_to_archive", "path_inside_archive"
                ),
            )
        )
    return members_list[0]


def format_embeddings_file_uri(
    main_file_path_or_url: str, path_inside_archive: Optional[str] = None
) -> str:
    if path_inside_archive:
        return "({})#{}".format(main_file_path_or_url, path_inside_archive)
    return main_file_path_or_url

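# Sketch of the embeddings-archive helpers (hypothetical names): an archive
# plus a member path is rendered as "(archive)#member", and open_inside_zip
# returns a text stream over a single member.
#
#   uri = format_embeddings_file_uri("embeddings.zip", "vectors.txt")
#   # uri == "(embeddings.zip)#vectors.txt"
#   f = open_inside_zip("https://example.com/embeddings.zip", "embeddings",
#                       member_path="vectors.txt")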

from tqdm import tqdm as _tqdm


class Tqdm:
    # These defaults are the same as the argument defaults in tqdm.
    default_mininterval: float = 0.1

    @staticmethod
    def set_default_mininterval(value: float) -> None:
        Tqdm.default_mininterval = value

    @staticmethod
    def set_slower_interval(use_slower_interval: bool) -> None:
        """
        If ``use_slower_interval`` is ``True``, we will dramatically slow down ``tqdm's`` default
        output rate. ``tqdm's`` default output rate is great for interactively watching progress,
        but it is not great for log files. You might want to set this if you are primarily going
        to be looking at output through log files, not the terminal.
        """
        if use_slower_interval:
            Tqdm.default_mininterval = 10.0
        else:
            Tqdm.default_mininterval = 0.1

    @staticmethod
    def tqdm(*args, **kwargs):
        new_kwargs = {"mininterval": Tqdm.default_mininterval, **kwargs}

        return _tqdm(*args, **new_kwargs)

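# Usage sketch: when output goes to a log file rather than a terminal, widen
# the refresh interval so progress lines don't flood the log.
#
#   Tqdm.set_slower_interval(True)
#   for _ in Tqdm.tqdm(range(1000)):
#       pass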

def instance_lru_cache(*cache_args, **cache_kwargs):
    """
    Like ``functools.lru_cache``, but the cache is created per instance rather
    than shared across all instances of the class.
    """
    def decorator(func):
        @functools.wraps(func)
        def create_cache(self, *args, **kwargs):
            # On first call, build an lru_cache-wrapped bound method for this
            # instance and shadow the class attribute with it, so subsequent
            # calls go straight to the per-instance cache.
            instance_cache = functools.lru_cache(*cache_args, **cache_kwargs)(func)
            instance_cache = instance_cache.__get__(self, self.__class__)
            setattr(self, func.__name__, instance_cache)
            return instance_cache(*args, **kwargs)

        return create_cache

    return decorator
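
# Usage sketch (hypothetical class and helper): each instance gets its own
# cache, so cached results are neither shared nor kept alive across instances.
#
#   class Embedder:
#       @instance_lru_cache(maxsize=128)
#       def embed(self, token: str):
#           return expensive_lookup(token)  # expensive_lookup is hypothetical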