Coverage for flair/flair/models/language_model.py: 50%

214 statements  

import math
from pathlib import Path
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

import flair
from flair.data import Dictionary


class LanguageModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(
        self,
        dictionary: Dictionary,
        is_forward_lm: bool,
        hidden_size: int,
        nlayers: int,
        embedding_size: int = 100,
        nout=None,
        document_delimiter: str = "\n",
        dropout=0.1,
    ):

        super(LanguageModel, self).__init__()

        self.dictionary = dictionary
        self.document_delimiter = document_delimiter
        self.is_forward_lm: bool = is_forward_lm

        self.dropout = dropout
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(len(dictionary), embedding_size)

        # a single-layer LSTM must not receive a dropout argument (PyTorch warns otherwise)
        if nlayers == 1:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers)
        else:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers, dropout=dropout)

        self.hidden = None

        # optional projection layer between the RNN output and the decoder
        self.nout = nout
        if nout is not None:
            self.proj = nn.Linear(hidden_size, nout)
            self.initialize(self.proj.weight)
            self.decoder = nn.Linear(nout, len(dictionary))
        else:
            self.proj = None
            self.decoder = nn.Linear(hidden_size, len(dictionary))

        self.init_weights()

        # auto-spawn on GPU if available
        self.to(flair.device)

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.detach().uniform_(-initrange, initrange)
        self.decoder.bias.detach().fill_(0)
        self.decoder.weight.detach().uniform_(-initrange, initrange)

    def set_hidden(self, hidden):
        self.hidden = hidden

    def forward(self, input, hidden, ordered_sequence_lengths=None):
        encoded = self.encoder(input)
        emb = self.drop(encoded)

        self.rnn.flatten_parameters()

        output, hidden = self.rnn(emb, hidden)

        if self.proj is not None:
            output = self.proj(output)

        output = self.drop(output)

        # decode all timesteps at once: flatten (seq_len, batch) into one dimension
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )

        return (
            decoded.view(output.size(0), output.size(1), decoded.size(1)),
            output,
            hidden,
        )

    def init_hidden(self, bsz):
        weight = next(self.parameters()).detach()
        return (
            weight.new(self.nlayers, bsz, self.hidden_size).zero_().clone().detach(),
            weight.new(self.nlayers, bsz, self.hidden_size).zero_().clone().detach(),
        )

    def get_representation(
        self,
        strings: List[str],
        start_marker: str,
        end_marker: str,
        chars_per_chunk: int = 512,
    ):

        len_longest_str: int = len(max(strings, key=len))

        # reverse each string for backward LMs and wrap it in start/end markers
        padded_strings: List[str] = []

        for string in strings:
            if not self.is_forward_lm:
                string = string[::-1]

            padded = f"{start_marker}{string}{end_marker}"
            padded_strings.append(padded)

        # cut up the input into chunks of at most chars_per_chunk characters
        chunks = []
        splice_begin = 0
        longest_padded_str: int = len_longest_str + len(start_marker) + len(end_marker)
        for splice_end in range(chars_per_chunk, longest_padded_str, chars_per_chunk):
            chunks.append([text[splice_begin:splice_end] for text in padded_strings])
            splice_begin = splice_end

        chunks.append(
            [text[splice_begin:longest_padded_str] for text in padded_strings]
        )
        hidden = self.init_hidden(len(chunks[0]))

        padding_char_index = self.dictionary.get_idx_for_item(" ")

        batches: List[torch.Tensor] = []
        # push each chunk through the RNN language model
        for chunk in chunks:
            len_longest_chunk: int = len(max(chunk, key=len))
            sequences_as_char_indices: List[List[int]] = []
            for string in chunk:
                # pad every string in the chunk with whitespace to the longest one
                char_indices = self.dictionary.get_idx_for_items(list(string))
                char_indices += [padding_char_index] * (len_longest_chunk - len(string))

                sequences_as_char_indices.append(char_indices)
            t = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(
                device=flair.device, non_blocking=True
            )
            batches.append(t)

        output_parts = []
        for batch in batches:
            batch = batch.transpose(0, 1)
            _, rnn_output, hidden = self.forward(batch, hidden)
            output_parts.append(rnn_output)

        # concatenate all chunks along the time dimension to make the final output
        output = torch.cat(output_parts)

        return output

    def get_output(self, text: str):
        char_indices = [self.dictionary.get_idx_for_item(char) for char in text]
        # shape (len(text), 1); move to the model's device before the forward pass
        input_vector = torch.LongTensor([char_indices]).transpose(0, 1).to(flair.device)

        hidden = self.init_hidden(1)
        prediction, rnn_output, hidden = self.forward(input_vector, hidden)

        return self.repackage_hidden(hidden)

    def repackage_hidden(self, h):
        """Wraps hidden states in new tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.clone().detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    @staticmethod
    def initialize(matrix):
        in_, out_ = matrix.size()
        stdv = math.sqrt(3.0 / (in_ + out_))
        matrix.detach().uniform_(-stdv, stdv)

    @classmethod
    def load_language_model(cls, model_file: Union[Path, str]):

        state = torch.load(str(model_file), map_location=flair.device)

        document_delimiter = state.get("document_delimiter", "\n")

        model = LanguageModel(
            dictionary=state["dictionary"],
            is_forward_lm=state["is_forward_lm"],
            hidden_size=state["hidden_size"],
            nlayers=state["nlayers"],
            embedding_size=state["embedding_size"],
            nout=state["nout"],
            document_delimiter=document_delimiter,
            dropout=state["dropout"],
        )
        model.load_state_dict(state["state_dict"])
        model.eval()
        model.to(flair.device)

        return model

    @classmethod
    def load_checkpoint(cls, model_file: Union[Path, str]):
        state = torch.load(str(model_file), map_location=flair.device)

        epoch = state.get("epoch")
        split = state.get("split")
        loss = state.get("loss")
        document_delimiter = state.get("document_delimiter", "\n")
        optimizer_state_dict = state.get("optimizer_state_dict")

        model = LanguageModel(
            dictionary=state["dictionary"],
            is_forward_lm=state["is_forward_lm"],
            hidden_size=state["hidden_size"],
            nlayers=state["nlayers"],
            embedding_size=state["embedding_size"],
            nout=state["nout"],
            document_delimiter=document_delimiter,
            dropout=state["dropout"],
        )
        model.load_state_dict(state["state_dict"])
        model.eval()
        model.to(flair.device)

        return {
            "model": model,
            "epoch": epoch,
            "split": split,
            "loss": loss,
            "optimizer_state_dict": optimizer_state_dict,
        }

    def save_checkpoint(
        self, file: Union[Path, str], optimizer: Optimizer, epoch: int, split: int, loss: float
    ):
        model_state = {
            "state_dict": self.state_dict(),
            "dictionary": self.dictionary,
            "is_forward_lm": self.is_forward_lm,
            "hidden_size": self.hidden_size,
            "nlayers": self.nlayers,
            "embedding_size": self.embedding_size,
            "nout": self.nout,
            "document_delimiter": self.document_delimiter,
            "dropout": self.dropout,
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "split": split,
            "loss": loss,
        }

        torch.save(model_state, str(file), pickle_protocol=4)

    def save(self, file: Union[Path, str]):
        model_state = {
            "state_dict": self.state_dict(),
            "dictionary": self.dictionary,
            "is_forward_lm": self.is_forward_lm,
            "hidden_size": self.hidden_size,
            "nlayers": self.nlayers,
            "embedding_size": self.embedding_size,
            "nout": self.nout,
            "document_delimiter": self.document_delimiter,
            "dropout": self.dropout,
        }

        torch.save(model_state, str(file), pickle_protocol=4)

    def generate_text(
        self,
        prefix: str = "\n",
        number_of_characters: int = 1000,
        temperature: float = 1.0,
        break_on_suffix=None,
    ) -> Tuple[str, float]:

        if prefix == "":
            prefix = "\n"

        with torch.no_grad():
            characters = []

            idx2item = self.dictionary.idx2item

            # initial hidden state
            hidden = self.init_hidden(1)

            # feed all but the last prefix character through the model to warm up the hidden state
            if len(prefix) > 1:

                char_tensors = []
                for character in prefix[:-1]:
                    char_tensors.append(
                        torch.tensor(self.dictionary.get_idx_for_item(character))
                        .unsqueeze(0)
                        .unsqueeze(0)
                    )

                input = torch.cat(char_tensors).to(flair.device)

                prediction, _, hidden = self.forward(input, hidden)

            input = (
                torch.tensor(self.dictionary.get_idx_for_item(prefix[-1]))
                .unsqueeze(0)
                .unsqueeze(0)
            )

            log_prob = 0.0

            for i in range(number_of_characters):

                input = input.to(flair.device)

                # get predicted weights
                prediction, _, hidden = self.forward(input, hidden)
                prediction = prediction.squeeze().detach()
                decoder_output = prediction

                # divide by temperature
                prediction = prediction.div(temperature)

                # to prevent overflow problems with small temperature values, subtract the largest
                # value from all others; this makes a vector in which the largest value is 0
                max_value = torch.max(prediction)
                prediction -= max_value

                # compute word weights with exponential function
                word_weights = prediction.exp().cpu()

                # try sampling from the multinomial distribution for the next character;
                # fall back to index 0 if the weights are degenerate
                try:
                    word_idx = torch.multinomial(word_weights, 1)[0]
                except RuntimeError:
                    word_idx = torch.tensor(0)

                prob = decoder_output[word_idx]
                log_prob += prob

                input = word_idx.detach().unsqueeze(0).unsqueeze(0)
                word = idx2item[word_idx].decode("UTF-8")
                characters.append(word)

                if break_on_suffix is not None:
                    if "".join(characters).endswith(break_on_suffix):
                        break

            text = prefix + "".join(characters)

            log_prob = log_prob.item()
            log_prob /= len(characters)

            if not self.is_forward_lm:
                text = text[::-1]

            return text, log_prob

    def calculate_perplexity(self, text: str) -> float:

        if not self.is_forward_lm:
            text = text[::-1]

        # input ids
        input = torch.tensor(
            [self.dictionary.get_idx_for_item(char) for char in text[:-1]]
        ).unsqueeze(1)
        input = input.to(flair.device)

        # push list of character IDs through model
        hidden = self.init_hidden(1)
        prediction, _, hidden = self.forward(input, hidden)

        # the target is always the next character
        targets = torch.tensor(
            [self.dictionary.get_idx_for_item(char) for char in text[1:]]
        )
        targets = targets.to(flair.device)

        # use cross entropy loss to compare output of forward pass with targets
        cross_entropy_loss = torch.nn.CrossEntropyLoss()
        loss = cross_entropy_loss(
            prediction.view(-1, len(self.dictionary)), targets
        ).item()

        # exponentiate cross-entropy loss to calculate perplexity
        perplexity = math.exp(loss)

        return perplexity

    def __getstate__(self):

        # serialize the language models and the constructor arguments (but nothing else)
        model_state = {
            "state_dict": self.state_dict(),
            "dictionary": self.dictionary,
            "is_forward_lm": self.is_forward_lm,
            "hidden_size": self.hidden_size,
            "nlayers": self.nlayers,
            "embedding_size": self.embedding_size,
            "nout": self.nout,
            "document_delimiter": self.document_delimiter,
            "dropout": self.dropout,
        }

        return model_state

    def __setstate__(self, d):

        # special handling for deserializing language models
        if "state_dict" in d:

            # re-initialize language model with constructor arguments
            language_model = LanguageModel(
                dictionary=d["dictionary"],
                is_forward_lm=d["is_forward_lm"],
                hidden_size=d["hidden_size"],
                nlayers=d["nlayers"],
                embedding_size=d["embedding_size"],
                nout=d["nout"],
                document_delimiter=d["document_delimiter"],
                dropout=d["dropout"],
            )

            language_model.load_state_dict(d["state_dict"])

            # copy over state dictionary to self
            for key in language_model.__dict__.keys():
                self.__dict__[key] = language_model.__dict__[key]

            # set the language model to eval() by default (this is necessary since FlairEmbeddings
            # "protect" the LM in their "self.train()" method)
            self.eval()

        else:
            super().__setstate__(d)

    def _apply(self, fn):

        # models that were serialized using torch versions older than 1.4.0 lack the
        # _flat_weights_names attribute; check if this is the case and, if so, set it
        for child_module in self.children():
            if isinstance(child_module, torch.nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"):
                _flat_weights_names = []

                num_direction = 2 if child_module.__dict__["bidirectional"] else 1
                for layer in range(child_module.__dict__["num_layers"]):
                    for direction in range(num_direction):
                        suffix = "_reverse" if direction == 1 else ""
                        param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
                        if child_module.__dict__["bias"]:
                            param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
                        param_names = [x.format(layer, suffix) for x in param_names]
                        _flat_weights_names.extend(param_names)

                setattr(child_module, "_flat_weights_names", _flat_weights_names)

            child_module._apply(fn)

        # keep the nn.Module contract: _apply returns self so that .to() chains correctly
        return self
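
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the public API above fits together —
# loading a saved model, sampling text, scoring a string, and extracting
# representations. The checkpoint path "best-lm.pt" and the marker characters
# are placeholders, not values prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # load a trained character-level LM from disk (see load_language_model above)
    language_model = LanguageModel.load_language_model("best-lm.pt")

    # sample 200 characters from the model, starting at the newline delimiter
    text, log_prob = language_model.generate_text(
        prefix="\n", number_of_characters=200, temperature=0.8
    )
    print(text)

    # score a string: lower perplexity means the model finds it more plausible
    print(language_model.calculate_perplexity("The quick brown fox jumps over the lazy dog."))

    # character-level representations for a batch of strings, as consumed by
    # FlairEmbeddings; shape: (longest padded sequence, batch size, hidden_size
    # or nout when a projection layer is configured)
    representation = language_model.get_representation(
        ["hello world", "goodbye"], start_marker="\n", end_marker=" "
    )
    print(representation.size())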