Coverage for flair/flair/inference_utils.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

176 statements  

1import logging 

2import pickle 

3import re 

4import shutil 

5import sqlite3 

6from pathlib import Path 

7import numpy as np 

8import torch 

9from tqdm import tqdm 

10 

11import flair 

12from flair.embeddings import WordEmbeddings 

13 

# this is the default init size of a lmdb database for embeddings
# (100 GiB of virtual address space; LMDB only commits pages actually used)
DEFAULT_MAP_SIZE = 100 * 1024 * 1024 * 1024

# module-level logger shared with the rest of the flair package
logger = logging.getLogger("flair")

18 

19 

class WordEmbeddingsStore:
    """
    class to simulate a WordEmbeddings class from flair.

    Run this to generate a headless (without word embeddings) model as well as stored word embeddings:

    >>> from flair.inference_utils import WordEmbeddingsStore
    >>> from flair.models import SequenceTagger
    >>> import pickle
    >>> tagger = SequenceTagger.load("multi-ner-fast")
    >>> WordEmbeddingsStore.create_stores(tagger)
    >>> pickle.dump(tagger, open("multi-ner-fast-headless.pickle", "wb"))

    The same but using LMDB as memory database:

    >>> from flair.inference_utils import WordEmbeddingsStore
    >>> from flair.models import SequenceTagger
    >>> import pickle
    >>> tagger = SequenceTagger.load("multi-ner-fast")
    >>> WordEmbeddingsStore.create_stores(tagger, backend='lmdb')
    >>> pickle.dump(tagger, open("multi-ner-fast-headless.pickle", "wb"))

    Then this can be used as follows:

    >>> from flair.data import Sentence
    >>> tagger = pickle.load(open("multi-ner-fast-headless.pickle", "rb"))
    >>> WordEmbeddingsStore.load_stores(tagger)
    >>> text = "Schade um den Ameisenbären. Lukas Bärfuss veröffentlicht Erzählungen aus zwanzig Jahren."
    >>> sentence = Sentence(text)
    >>> tagger.predict(sentence)
    >>> print(sentence.get_spans('ner'))

    The same but using LMDB as memory database:

    >>> from flair.data import Sentence
    >>> tagger = pickle.load(open("multi-ner-fast-headless.pickle", "rb"))
    >>> WordEmbeddingsStore.load_stores(tagger, backend='lmdb')
    >>> text = "Schade um den Ameisenbären. Lukas Bärfuss veröffentlicht Erzählungen aus zwanzig Jahren."
    >>> sentence = Sentence(text)
    >>> tagger.predict(sentence)
    >>> print(sentence.get_spans('ner'))
    """

    def __init__(self, embedding: WordEmbeddings, backend='sqlite', verbose=True):
        """
        :param embedding: Flair WordEmbeddings instance.
        :param backend: cache database backend name e.g ``'sqlite'``, ``'lmdb'``.
            Default value is ``'sqlite'``.
        :param verbose: If `True` print information on standard output
        :raises ValueError: if ``backend`` is neither ``'sqlite'`` nor ``'lmdb'``
        """
        # some non-used parameters to allow print
        self._modules = dict()
        self.items = ""

        # get db filename from embedding name
        self.name = embedding.name
        self.store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)
        if verbose:
            logger.info(f"store filename: {str(self.store_path)}")

        if backend == 'sqlite':
            self.backend = SqliteWordEmbeddingsStoreBackend(embedding, verbose)
        elif backend == 'lmdb':
            self.backend = LmdbWordEmbeddingsStoreBackend(embedding, verbose)
        else:
            raise ValueError(
                f'The given backend "{backend}" is not available.'
            )
        # In case initialization of cached version failed, just fallback to the original WordEmbeddings
        # NOTE(review): assumes embedding.embeddings holds the identifier accepted by the
        # WordEmbeddings constructor — confirm against the installed flair version
        if not self.backend.is_ok:
            self.backend = WordEmbeddings(embedding.embeddings)

    def _get_vector(self, word="house"):
        # delegate the lookup to the active backend (sqlite/lmdb or plain WordEmbeddings)
        return self.backend._get_vector(word)

    def embed(self, sentences):
        """Attach a stored embedding vector to every token of every sentence."""
        for sentence in sentences:
            for token in sentence:
                # lookups are lowercase, matching how vectors were stored
                t = torch.tensor(self._get_vector(word=token.text.lower()))
                token.set_embedding(self.name, t)

    def get_names(self):
        """Return the embedding names this store provides (a single name)."""
        return [self.name]

    @staticmethod
    def _get_store_path(embedding, backend='sqlite'):
        """
        get the filename of the store
        """
        cache_dir = flair.cache_root
        # keep only the "embeddings/..." suffix of the embedding's full path-like name
        embedding_filename = re.findall("/(embeddings/.*)", embedding.name)[0]
        store_path = cache_dir / (embedding_filename + "." + backend)
        return store_path

    @staticmethod
    def _word_embeddings(model):
        """Return the (mutable) list of embeddings held by a supported model."""
        # SequenceTagger
        if hasattr(model, 'embeddings'):
            embeds = model.embeddings.embeddings
        # TextClassifier
        elif hasattr(model, 'document_embeddings') and hasattr(model.document_embeddings, 'embeddings'):
            embeds = model.document_embeddings.embeddings.embeddings
        else:
            embeds = []
        return embeds

    @staticmethod
    def create_stores(model, backend='sqlite'):
        """
        creates database versions of all word embeddings in the model and
        deletes the original vectors to save memory
        """
        for embedding in WordEmbeddingsStore._word_embeddings(model):
            # exact-type check on purpose: only plain WordEmbeddings are cacheable
            if type(embedding) == WordEmbeddings:
                WordEmbeddingsStore(embedding, backend)
                del embedding.precomputed_word_embeddings

    @staticmethod
    def load_stores(model, backend='sqlite'):
        """
        loads the db versions of all word embeddings in the model
        """
        embeds = WordEmbeddingsStore._word_embeddings(model)
        for i, embedding in enumerate(embeds):
            if type(embedding) == WordEmbeddings:
                embeds[i] = WordEmbeddingsStore(embedding, backend)

    @staticmethod
    def delete_stores(model, backend='sqlite'):
        """
        deletes the db versions of all word embeddings
        """
        for embedding in WordEmbeddingsStore._word_embeddings(model):
            # bug fix: pass the backend through so the correct store is located
            # (previously the sqlite default was always used, so '.lmdb'
            # directories were never found/deleted when backend='lmdb')
            store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)
            logger.info(f"delete store: {str(store_path)}")
            if store_path.is_file():
                store_path.unlink()
            elif store_path.is_dir():
                shutil.rmtree(store_path, ignore_errors=False, onerror=None)

159 

160 

class WordEmbeddingsStoreBackend:
    """Common base for the on-disk word-embedding store backends (sqlite, lmdb)."""

    def __init__(self, embedding, backend, verbose=True):
        # get db filename from embedding name
        self.name = embedding.name
        self.store_path: Path = WordEmbeddingsStore._get_store_path(embedding, backend)

    @property
    def is_ok(self):
        # subclasses set self.k (the vector dimensionality) only when their
        # store opened or was built successfully, so its presence doubles as
        # a health flag for the fallback logic in WordEmbeddingsStore.__init__
        return hasattr(self, 'k')

    def _get_vector(self, word="house"):
        # overridden by concrete backends; the base implementation returns None
        pass

173 

174 

class SqliteWordEmbeddingsStoreBackend(WordEmbeddingsStoreBackend):
    """sqlite3-backed embedding store.

    Vectors live in a single ``embedding`` table with one float column per
    dimension plus the word itself; every lookup opens its own short-lived
    connection so the backend keeps working after being pickled/unpickled.
    """

    def __init__(self, embedding, verbose):
        super().__init__(embedding, 'sqlite', verbose)
        # if embedding database already exists, reuse it and infer k
        if self.store_path.exists() and self.store_path.is_file():
            try:
                db = sqlite3.connect(str(self.store_path))
                cursor = db.cursor()
                cursor.execute("SELECT * FROM embedding LIMIT 1;")
                result = list(cursor)
                # number of vector columns = total columns minus the word column
                self.k = len(result[0]) - 1
                # close the probe connection; _get_vector opens its own per call
                db.close()
                return
            except sqlite3.Error as err:
                logger.exception(f"Fail to open sqlite database {str(self.store_path)}: {str(err)}")
        # otherwise, push embedding to database
        if hasattr(embedding, 'precomputed_word_embeddings'):
            self.db = sqlite3.connect(str(self.store_path))
            pwe = embedding.precomputed_word_embeddings
            self.k = pwe.vector_size
            self.db.execute("DROP TABLE IF EXISTS embedding;")
            self.db.execute(
                f"CREATE TABLE embedding(word text,{','.join('v' + str(i) + ' float' for i in range(self.k))});"
            )
            # lazily yield [word, v0, v1, ...] rows to keep memory flat
            vectors_it = (
                [word] + pwe.get_vector(word).tolist() for word in pwe.vocab.keys()
            )
            if verbose:
                logger.info("load vectors to store")
            self.db.executemany(
                f"INSERT INTO embedding(word,{','.join('v' + str(i) for i in range(self.k))}) \
                values ({','.join(['?'] * (1 + self.k))})",
                tqdm(vectors_it),
            )
            # index on word makes per-token lookups O(log n)
            self.db.execute("DROP INDEX IF EXISTS embedding_index;")
            self.db.execute("CREATE INDEX embedding_index ON embedding(word);")
            self.db.commit()
            self.db.close()

    def _get_vector(self, word="house"):
        """Return the stored vector for *word*, or a zero vector if absent."""
        db = sqlite3.connect(str(self.store_path))
        cursor = db.cursor()
        # bug fix: parameterized query instead of f-string interpolation —
        # no SQL injection, and words containing quote characters (which the
        # old code stripped before lookup, guaranteeing a miss) now resolve
        cursor.execute("SELECT * FROM embedding WHERE word = ?;", (word,))
        result = list(cursor)
        db.close()
        if not result:
            return [0.0] * self.k
        return result[0][1:]

223 

224 

class LmdbWordEmbeddingsStoreBackend(WordEmbeddingsStoreBackend):
    """LMDB-backed embedding store: one pickled numpy vector per UTF-8 word key."""

    def __init__(self, embedding, verbose):
        super().__init__(embedding, 'lmdb', verbose)
        try:
            import lmdb
        except ModuleNotFoundError:
            self._warn_lmdb_missing()
            return
        # if embedding database already exists
        if self.store_path.exists() and self.store_path.is_dir():
            # open the database in read mode
            try:
                self.env = lmdb.open(str(self.store_path), readonly=True, max_readers=2048, max_spare_txns=4)
                if self.env:
                    # we need to set self.k — read the first stored vector's length
                    with self.env.begin() as txn:
                        cursor = txn.cursor()
                        for key, value in cursor:
                            vector = pickle.loads(value)
                            self.k = vector.shape[0]
                            break
                        cursor.close()
                    return
            except lmdb.Error as err:
                logger.exception(f"Fail to open lmdb database {str(self.store_path)}: {str(err)}")
        # create and load the database in write mode
        if hasattr(embedding, 'precomputed_word_embeddings'):
            pwe = embedding.precomputed_word_embeddings
            self.k = pwe.vector_size
            self.store_path.mkdir(parents=True, exist_ok=True)
            self.env = lmdb.open(str(self.store_path), map_size=DEFAULT_MAP_SIZE)
            if verbose:
                logger.info("load vectors to store")

            txn = self.env.begin(write=True)
            for word in tqdm(pwe.vocab.keys()):
                vector = pwe.get_vector(word)
                # LMDB keys have a hard size limit; silently skip oversized words
                if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
                    txn.put(word.encode(encoding='UTF-8'), pickle.dumps(vector))
            txn.commit()
            return

    @staticmethod
    def _warn_lmdb_missing():
        # single place for the "lmdb not installed" banner
        logger.warning("-" * 100)
        logger.warning('ATTENTION! The library "lmdb" is not installed!')
        logger.warning(
            'To use LMDB, please first install with "pip install lmdb"'
        )
        logger.warning("-" * 100)

    def _get_vector(self, word="house"):
        """Return the stored vector for *word*, or a zero vector if absent."""
        # bug fix: import lmdb first, outside the lookup try-block. Previously
        # a failed import reached `except lmdb.Error:` with `lmdb` unbound,
        # raising NameError instead of the intended warning path.
        try:
            import lmdb
        except ModuleNotFoundError:
            self._warn_lmdb_missing()
            return np.zeros((self.k,), dtype=np.float32)
        try:
            with self.env.begin() as txn:
                vector = txn.get(word.encode(encoding='UTF-8'))
                if vector:
                    word_vector = pickle.loads(vector)
                    vector = None
                else:
                    word_vector = np.zeros((self.k,), dtype=np.float32)
        except lmdb.Error:
            # no idea why, but we need to close and reopen the environment to avoid
            # mdb_txn_begin: MDB_BAD_RSLOT: Invalid reuse of reader locktable slot
            # when opening new transaction !
            self.env.close()
            # bug fix: lmdb.open expects a str path (every other call site
            # passes str(...)); a pathlib.Path was passed here
            self.env = lmdb.open(str(self.store_path), readonly=True, max_readers=2048, max_spare_txns=2, lock=False)
            return self._get_vector(word)
        return word_vector