Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/inference_utils.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2import pickle
3import re
4import shutil
5import sqlite3
6from pathlib import Path
7import numpy as np
8import torch
9from tqdm import tqdm
11import flair
12from flair.embeddings import WordEmbeddings
14# this is the default init size of a lmdb database for embeddings
15DEFAULT_MAP_SIZE = 100 * 1024 * 1024 * 1024
17logger = logging.getLogger("flair")
class WordEmbeddingsStore:
    """
    class to simulate a WordEmbeddings class from flair.

    Run this to generate a headless (without word embeddings) model as well a stored word embeddings:

    >>> from flair.inference_utils import WordEmbeddingsStore
    >>> from flair.models import SequenceTagger
    >>> import pickle
    >>> tagger = SequenceTagger.load("multi-ner-fast")
    >>> WordEmbeddingsStore.create_stores(tagger)
    >>> pickle.dump(tagger, open("multi-ner-fast-headless.pickle", "wb"))

    The same but using LMDB as memory database:

    >>> from flair.inference_utils import WordEmbeddingsStore
    >>> from flair.models import SequenceTagger
    >>> import pickle
    >>> tagger = SequenceTagger.load("multi-ner-fast")
    >>> WordEmbeddingsStore.create_stores(tagger, backend='lmdb')
    >>> pickle.dump(tagger, open("multi-ner-fast-headless.pickle", "wb"))

    Then this can be used as follows:

    >>> from flair.data import Sentence
    >>> tagger = pickle.load(open("multi-ner-fast-headless.pickle", "rb"))
    >>> WordEmbeddingsStore.load_stores(tagger)
    >>> text = "Schade um den Ameisenbären. Lukas Bärfuss veröffentlicht Erzählungen aus zwanzig Jahren."
    >>> sentence = Sentence(text)
    >>> tagger.predict(sentence)
    >>> print(sentence.get_spans('ner'))

    The same but using LMDB as memory database:

    >>> from flair.data import Sentence
    >>> tagger = pickle.load(open("multi-ner-fast-headless.pickle", "rb"))
    >>> WordEmbeddingsStore.load_stores(tagger, backend='lmdb')
    >>> text = "Schade um den Ameisenbären. Lukas Bärfuss veröffentlicht Erzählungen aus zwanzig Jahren."
    >>> sentence = Sentence(text)
    >>> tagger.predict(sentence)
    >>> print(sentence.get_spans('ner'))
    """

    def __init__(self, embedding: WordEmbeddings, backend='sqlite', verbose=True):
        """
        :param embedding: Flair WordEmbeddings instance.
        :param backend: cache database backend name e.g ``'sqlite'``, ``'lmdb'``.
            Default value is ``'sqlite'``.
        :param verbose: If `True` print information on standard output
        :raises ValueError: if ``backend`` is neither ``'sqlite'`` nor ``'lmdb'``.
        """
        # some non-used parameter to allow print
        self._modules = dict()
        self.items = ""

        # get db filename from embedding name
        self.name = embedding.name
        self.store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)
        if verbose:
            logger.info(f"store filename: {str(self.store_path)}")

        if backend == 'sqlite':
            self.backend = SqliteWordEmbeddingsStoreBackend(embedding, verbose)
        elif backend == 'lmdb':
            self.backend = LmdbWordEmbeddingsStoreBackend(embedding, verbose)
        else:
            raise ValueError(
                f'The given backend "{backend}" is not available.'
            )
        # In case initialization of cached version failed, just fallback to the original WordEmbeddings
        if not self.backend.is_ok:
            self.backend = WordEmbeddings(embedding.embeddings)

    def _get_vector(self, word="house"):
        # delegate the lookup to the active backend (database or fallback)
        return self.backend._get_vector(word)

    def embed(self, sentences):
        """Attach a stored vector to every token of every sentence (lower-cased lookup)."""
        for sentence in sentences:
            for token in sentence:
                t = torch.tensor(self._get_vector(word=token.text.lower()))
                token.set_embedding(self.name, t)

    def get_names(self):
        # mimic the flair embeddings interface: a store has exactly one name
        return [self.name]

    @staticmethod
    def _get_store_path(embedding, backend='sqlite'):
        """
        get the filename of the store
        """
        cache_dir = flair.cache_root
        # keep the 'embeddings/...' suffix of the embedding name as the cache-relative path
        embedding_filename = re.findall("/(embeddings/.*)", embedding.name)[0]
        store_path = cache_dir / (embedding_filename + "." + backend)
        return store_path

    @staticmethod
    def _word_embeddings(model):
        """Return the (mutable) list of embeddings held by a flair model, or [] if none found."""
        # SequenceTagger
        if hasattr(model, 'embeddings'):
            embeds = model.embeddings.embeddings
        # TextClassifier
        elif hasattr(model, 'document_embeddings') and hasattr(model.document_embeddings, 'embeddings'):
            embeds = model.document_embeddings.embeddings.embeddings
        else:
            embeds = []
        return embeds

    @staticmethod
    def create_stores(model, backend='sqlite'):
        """
        creates database versions of all word embeddings in the model and
        deletes the original vectors to save memory
        """
        for embedding in WordEmbeddingsStore._word_embeddings(model):
            if type(embedding) == WordEmbeddings:
                WordEmbeddingsStore(embedding, backend)
                del embedding.precomputed_word_embeddings

    @staticmethod
    def load_stores(model, backend='sqlite'):
        """
        loads the db versions of all word embeddings in the model
        """
        embeds = WordEmbeddingsStore._word_embeddings(model)
        for i, embedding in enumerate(embeds):
            if type(embedding) == WordEmbeddings:
                embeds[i] = WordEmbeddingsStore(embedding, backend)

    @staticmethod
    def delete_stores(model, backend='sqlite'):
        """
        deletes the db versions of all word embeddings
        """
        for embedding in WordEmbeddingsStore._word_embeddings(model):
            # BUG FIX: forward the requested backend to _get_store_path; the
            # original omitted it, so e.g. delete_stores(model, 'lmdb') looked
            # for '.sqlite' paths and never removed lmdb store directories.
            store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)
            logger.info(f"delete store: {str(store_path)}")
            if store_path.is_file():
                store_path.unlink()
            elif store_path.is_dir():
                shutil.rmtree(store_path, ignore_errors=False, onerror=None)
class WordEmbeddingsStoreBackend:
    """Common base class for the database backends used by WordEmbeddingsStore.

    A concrete backend is considered successfully initialised once it has set
    the vector-dimensionality attribute ``k``; ``is_ok`` reports exactly that.
    """

    def __init__(self, embedding, backend, verbose=True):
        """Remember the embedding name and resolve the on-disk store location.

        :param embedding: Flair WordEmbeddings instance the store mirrors.
        :param backend: backend identifier, e.g. ``'sqlite'`` or ``'lmdb'``.
        :param verbose: accepted for signature symmetry with subclasses; unused here.
        """
        # derive the database file/directory name from the embedding name
        self.store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)
        self.name = embedding.name

    @property
    def is_ok(self):
        # 'k' (the vector size) is only assigned by a successful subclass
        # initialisation, so its presence doubles as a health flag
        return hasattr(self, 'k')

    def _get_vector(self, word="house"):
        # concrete backends override this; the base class stores nothing
        pass
class SqliteWordEmbeddingsStoreBackend(WordEmbeddingsStoreBackend):
    """WordEmbeddingsStore backend that persists word vectors in SQLite.

    The ``embedding`` table holds one row per word: the word itself followed
    by ``k`` float columns ``v0 .. v{k-1}``.
    """

    def __init__(self, embedding, verbose):
        """Open an existing store if present, otherwise build it from the embedding.

        :param embedding: Flair WordEmbeddings instance (must expose
            ``precomputed_word_embeddings`` to build a fresh store).
        :param verbose: if True, log progress while loading vectors.
        """
        super().__init__(embedding, 'sqlite', verbose)
        # if embedding database already exists, reuse it
        if self.store_path.exists() and self.store_path.is_file():
            try:
                self.db = sqlite3.connect(str(self.store_path))
                cursor = self.db.cursor()
                cursor.execute("SELECT * FROM embedding LIMIT 1;")
                result = list(cursor)
                # vector size = column count minus the leading word column
                self.k = len(result[0]) - 1
                return
            except sqlite3.Error as err:
                logger.exception(f"Fail to open sqlite database {str(self.store_path)}: {str(err)}")
        # otherwise, push embedding to database
        if hasattr(embedding, 'precomputed_word_embeddings'):
            self.db = sqlite3.connect(str(self.store_path))
            pwe = embedding.precomputed_word_embeddings
            self.k = pwe.vector_size
            self.db.execute("DROP TABLE IF EXISTS embedding;")
            # column names are generated, never user input, so interpolation is safe here
            self.db.execute(
                f"CREATE TABLE embedding(word text,{','.join('v' + str(i) + ' float' for i in range(self.k))});"
            )
            # lazily yield [word, v0, ..., vk-1] rows to keep memory flat
            vectors_it = (
                [word] + pwe.get_vector(word).tolist() for word in pwe.vocab.keys()
            )
            if verbose:
                logger.info("load vectors to store")
            self.db.executemany(
                f"INSERT INTO embedding(word,{','.join('v' + str(i) for i in range(self.k))}) \
                values ({','.join(['?'] * (1 + self.k))})",
                tqdm(vectors_it),
            )
            self.db.execute("DROP INDEX IF EXISTS embedding_index;")
            self.db.execute("CREATE INDEX embedding_index ON embedding(word);")
            self.db.commit()
            self.db.close()

    def _get_vector(self, word="house"):
        """Return the stored vector for ``word`` as a sequence of floats.

        Unknown words yield a zero vector of length ``k``.
        """
        db = sqlite3.connect(str(self.store_path))
        cursor = db.cursor()
        # BUG FIX: use a parameterized query instead of interpolating the word
        # into the SQL text. The old code had to strip double quotes from the
        # word (so quoted words could never be found) and was injection-prone.
        cursor.execute("SELECT * FROM embedding WHERE word = ?;", (word,))
        result = list(cursor)
        db.close()
        if not result:
            return [0.0] * self.k
        # drop the leading word column, keep only the k float values
        return result[0][1:]
class LmdbWordEmbeddingsStoreBackend(WordEmbeddingsStoreBackend):
    """WordEmbeddingsStore backend that persists word vectors in LMDB.

    Keys are UTF-8 encoded words; values are pickled numpy vectors.
    """

    def __init__(self, embedding, verbose):
        """Open an existing LMDB store if present, otherwise build it.

        :param embedding: Flair WordEmbeddings instance (must expose
            ``precomputed_word_embeddings`` to build a fresh store).
        :param verbose: if True, log progress while loading vectors.
        """
        super().__init__(embedding, 'lmdb', verbose)
        try:
            import lmdb
            # if embedding database already exists, reuse it
            if self.store_path.exists() and self.store_path.is_dir():
                # open the database in read mode
                try:
                    self.env = lmdb.open(str(self.store_path), readonly=True, max_readers=2048, max_spare_txns=4)
                    if self.env:
                        # we need to set self.k: read it off the first stored vector
                        with self.env.begin() as txn:
                            cursor = txn.cursor()
                            for key, value in cursor:
                                vector = pickle.loads(value)
                                self.k = vector.shape[0]
                                break
                            cursor.close()
                        return
                except lmdb.Error as err:
                    logger.exception(f"Fail to open lmdb database {str(self.store_path)}: {str(err)}")
            # create and load the database in write mode
            if hasattr(embedding, 'precomputed_word_embeddings'):
                pwe = embedding.precomputed_word_embeddings
                self.k = pwe.vector_size
                self.store_path.mkdir(parents=True, exist_ok=True)
                self.env = lmdb.open(str(self.store_path), map_size=DEFAULT_MAP_SIZE)
                if verbose:
                    logger.info("load vectors to store")

                txn = self.env.begin(write=True)
                for word in tqdm(pwe.vocab.keys()):
                    vector = pwe.get_vector(word)
                    # LMDB enforces a maximum key length; skip oversized words
                    if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
                        txn.put(word.encode(encoding='UTF-8'), pickle.dumps(vector))
                txn.commit()
                return
        except ModuleNotFoundError:
            logger.warning("-" * 100)
            logger.warning('ATTENTION! The library "lmdb" is not installed!')
            logger.warning(
                'To use LMDB, please first install with "pip install lmdb"'
            )
            logger.warning("-" * 100)

    def _get_vector(self, word="house"):
        """Return the stored numpy vector for ``word``; zeros if unknown."""
        try:
            import lmdb
            with self.env.begin() as txn:
                vector = txn.get(word.encode(encoding='UTF-8'))
                if vector:
                    word_vector = pickle.loads(vector)
                    vector = None
                else:
                    word_vector = np.zeros((self.k,), dtype=np.float32)
        # BUG FIX: handle ModuleNotFoundError BEFORE lmdb.Error. If the import
        # failed, the local name `lmdb` is unbound, so evaluating `lmdb.Error`
        # in the except clause raised a NameError and the warning below was
        # unreachable.
        except ModuleNotFoundError:
            logger.warning("-" * 100)
            logger.warning('ATTENTION! The library "lmdb" is not installed!')
            logger.warning(
                'To use LMDB, please first install with "pip install lmdb"'
            )
            logger.warning("-" * 100)
            word_vector = np.zeros((self.k,), dtype=np.float32)
        except lmdb.Error:
            # no idea why, but we need to close and reopen the environment to avoid
            # mdb_txn_begin: MDB_BAD_RSLOT: Invalid reuse of reader locktable slot
            # when opening new transaction !
            self.env.close()
            # BUG FIX: pass str(self.store_path) like every other lmdb.open call
            # in this class; a bare Path object is rejected by older lmdb builds.
            self.env = lmdb.open(str(self.store_path), readonly=True, max_readers=2048, max_spare_txns=2, lock=False)
            return self._get_vector(word)
        return word_vector