Coverage for flair/flair/models/language_model.py: 70%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import math
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

import flair
from flair.data import Dictionary
class LanguageModel(nn.Module):
    """Character-level LSTM language model.

    Container module with an embedding encoder, an LSTM, an optional linear
    projection (when ``nout`` is given) and a linear decoder over the
    character dictionary.
    """

    def __init__(
        self,
        dictionary: Dictionary,
        is_forward_lm: bool,
        hidden_size: int,
        nlayers: int,
        embedding_size: int = 100,
        nout: Optional[int] = None,
        document_delimiter: str = '\n',
        dropout: float = 0.1,
    ):
        """Build the model and move it to ``flair.device``.

        :param dictionary: character dictionary; its length sets the size of the
            embedding table and of the decoder output
        :param is_forward_lm: if False, strings are reversed before being fed to
            the model (see ``get_representation`` / ``calculate_perplexity``)
        :param hidden_size: LSTM hidden state size
        :param nlayers: number of stacked LSTM layers
        :param embedding_size: character embedding size
        :param nout: if not None, size of a linear projection applied to the LSTM
            output before the decoder
        :param document_delimiter: string stored with the model and persisted on save
        :param dropout: dropout applied to embeddings and outputs (and between
            LSTM layers when ``nlayers > 1``)
        """
        super(LanguageModel, self).__init__()

        self.dictionary = dictionary
        self.document_delimiter = document_delimiter
        self.is_forward_lm: bool = is_forward_lm

        self.dropout = dropout
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(len(dictionary), embedding_size)

        # nn.LSTM's inter-layer dropout only has an effect with more than one layer
        if nlayers == 1:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers)
        else:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers, dropout=dropout)

        self.hidden = None

        self.nout = nout
        if nout is not None:
            self.proj = nn.Linear(hidden_size, nout)
            self.initialize(self.proj.weight)
            self.decoder = nn.Linear(nout, len(dictionary))
        else:
            self.proj = None
            self.decoder = nn.Linear(hidden_size, len(dictionary))

        self.init_weights()

        # auto-spawn on GPU if available
        self.to(flair.device)

    def init_weights(self):
        """Uniformly initialise encoder and decoder weights in [-0.1, 0.1] and zero the decoder bias."""
        initrange = 0.1
        self.encoder.weight.detach().uniform_(-initrange, initrange)
        self.decoder.bias.detach().fill_(0)
        self.decoder.weight.detach().uniform_(-initrange, initrange)

    def set_hidden(self, hidden):
        """Store ``hidden`` on the model as ``self.hidden``."""
        self.hidden = hidden

    def forward(self, input, hidden, ordered_sequence_lengths=None):
        """Run a batch of character indices through the model.

        :param input: LongTensor of character indices; the LSTM is built without
            ``batch_first``, so this is expected as (seq_len, batch)
        :param hidden: (h, c) LSTM state, e.g. from ``init_hidden``
        :param ordered_sequence_lengths: unused; kept for interface compatibility
        :return: tuple of (decoder logits of shape (seq_len, batch, len(dictionary)),
            RNN output after optional projection and dropout, new hidden state)
        """
        encoded = self.encoder(input)
        emb = self.drop(encoded)

        # resets the LSTM parameter data pointers so the fast (cuDNN) path can be used
        self.rnn.flatten_parameters()

        output, hidden = self.rnn(emb, hidden)

        if self.proj is not None:
            output = self.proj(output)

        output = self.drop(output)

        # flatten (seq_len, batch) so the decoder sees a 2-D input, then restore the shape
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )

        return (
            decoded.view(output.size(0), output.size(1), decoded.size(1)),
            output,
            hidden,
        )

    def init_hidden(self, bsz):
        """Return a zero (h, c) state of shape (nlayers, bsz, hidden_size) on the model's device/dtype."""
        weight = next(self.parameters()).detach()
        return (
            weight.new_zeros(self.nlayers, bsz, self.hidden_size),
            weight.new_zeros(self.nlayers, bsz, self.hidden_size),
        )

    def get_representation(
        self,
        strings: List[str],
        start_marker: str,
        end_marker: str,
        chars_per_chunk: int = 512,
    ):
        """Return the per-character RNN output for a batch of strings.

        Each string is reversed for backward LMs, wrapped in ``start_marker`` /
        ``end_marker``, and fed through the model in chunks of at most
        ``chars_per_chunk`` characters. The hidden state is carried across
        chunks, and shorter strings are right-padded with the index of " ".

        :raises ValueError: if ``strings`` is empty
        :return: tensor of shape (longest padded length, len(strings), hidden_size or nout)
        """
        if not strings:
            raise ValueError("get_representation needs at least one string")

        len_longest_str: int = len(max(strings, key=len))

        # wrap every string in the markers (reversed first for backward LMs)
        padded_strings: List[str] = [
            f"{start_marker}{string if self.is_forward_lm else string[::-1]}{end_marker}"
            for string in strings
        ]

        # cut up the input into chunks of max charlength = chars_per_chunk
        longest_padded_str: int = len_longest_str + len(start_marker) + len(end_marker)
        chunks: List[List[str]] = []
        splice_begin = 0
        for splice_end in range(chars_per_chunk, longest_padded_str, chars_per_chunk):
            chunks.append([text[splice_begin:splice_end] for text in padded_strings])
            splice_begin = splice_end
        chunks.append([text[splice_begin:longest_padded_str] for text in padded_strings])

        hidden = self.init_hidden(len(chunks[0]))
        padding_char_index = self.dictionary.get_idx_for_item(" ")

        # push each chunk through the RNN language model, carrying the hidden state
        output_parts = []
        for chunk in chunks:
            len_longest_chunk: int = len(max(chunk, key=len))
            sequences_as_char_indices: List[List[int]] = []
            for string in chunk:
                char_indices = self.dictionary.get_idx_for_items(list(string))
                char_indices += [padding_char_index] * (len_longest_chunk - len(string))
                sequences_as_char_indices.append(char_indices)

            batch = torch.tensor(sequences_as_char_indices, dtype=torch.long).to(
                device=flair.device, non_blocking=True
            )
            # (batch, chars) -> (chars, batch): the LSTM is sequence-first
            _, rnn_output, hidden = self.forward(batch.transpose(0, 1), hidden)
            output_parts.append(rnn_output)

        # concatenate all chunks along the sequence dimension to make final output
        return torch.cat(output_parts)

    def get_output(self, text: str):
        """Run ``text`` through the model from a fresh state and return the final, detached hidden state."""
        char_indices = [self.dictionary.get_idx_for_item(char) for char in text]
        # (1, len) -> (len, 1); must live on the same device as the model
        input_vector = torch.LongTensor([char_indices]).transpose(0, 1).to(flair.device)

        hidden = self.init_hidden(1)
        _, _, hidden = self.forward(input_vector, hidden)

        return self.repackage_hidden(hidden)

    def repackage_hidden(self, h):
        """Detach hidden states from their history (recursively for tuples of tensors)."""
        if isinstance(h, torch.Tensor):
            return h.clone().detach()
        return tuple(self.repackage_hidden(v) for v in h)

    @staticmethod
    def initialize(matrix):
        """In-place uniform init in [-sqrt(3 / (rows + cols)), +sqrt(3 / (rows + cols))] for a 2-D weight."""
        in_, out_ = matrix.size()
        stdv = math.sqrt(3.0 / (in_ + out_))
        matrix.detach().uniform_(-stdv, stdv)

    @classmethod
    def _init_from_state(cls, state: dict) -> "LanguageModel":
        """Rebuild a model from a saved state dict, load its weights, and put it in eval mode on flair.device."""
        model = cls(
            dictionary=state["dictionary"],
            is_forward_lm=state["is_forward_lm"],
            hidden_size=state["hidden_size"],
            nlayers=state["nlayers"],
            embedding_size=state["embedding_size"],
            nout=state["nout"],
            # states saved by older versions may lack the delimiter
            document_delimiter=state.get("document_delimiter", '\n'),
            dropout=state["dropout"],
        )
        model.load_state_dict(state["state_dict"])
        model.eval()
        model.to(flair.device)
        return model

    @classmethod
    def load_language_model(cls, model_file: Union[Path, str]):
        """Load a model saved with ``save`` (or ``save_checkpoint``) and return it in eval mode.

        NOTE: torch.load unpickles arbitrary objects - only load trusted files.
        """
        state = torch.load(str(model_file), map_location=flair.device)
        return cls._init_from_state(state)

    @classmethod
    def load_checkpoint(cls, model_file: Union[Path, str]):
        """Load a checkpoint written by ``save_checkpoint``.

        NOTE: torch.load unpickles arbitrary objects - only load trusted files.

        :return: dict with keys "model", "epoch", "split", "loss" and
            "optimizer_state_dict"; entries missing from the file are None
        """
        state = torch.load(str(model_file), map_location=flair.device)
        model = cls._init_from_state(state)

        return {
            "model": model,
            "epoch": state.get("epoch"),
            "split": state.get("split"),
            "loss": state.get("loss"),
            "optimizer_state_dict": state.get("optimizer_state_dict"),
        }

    def _model_state(self) -> dict:
        """Return the weights plus the constructor arguments needed to rebuild this model."""
        return {
            "state_dict": self.state_dict(),
            "dictionary": self.dictionary,
            "is_forward_lm": self.is_forward_lm,
            "hidden_size": self.hidden_size,
            "nlayers": self.nlayers,
            "embedding_size": self.embedding_size,
            "nout": self.nout,
            "document_delimiter": self.document_delimiter,
            "dropout": self.dropout,
        }

    def save_checkpoint(
        self, file: Union[Path, str], optimizer: Optimizer, epoch: int, split: int, loss: float
    ):
        """Save the model together with optimizer state and training progress."""
        model_state = {
            **self._model_state(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "split": split,
            "loss": loss,
        }

        torch.save(model_state, str(file), pickle_protocol=4)

    def save(self, file: Union[Path, str]):
        """Save the model weights and constructor arguments (loadable with ``load_language_model``)."""
        torch.save(self._model_state(), str(file), pickle_protocol=4)

    def generate_text(
        self,
        prefix: str = "\n",
        number_of_characters: int = 1000,
        temperature: float = 1.0,
        break_on_suffix: Optional[str] = None,
    ) -> Tuple[str, float]:
        """Sample text from the model character by character.

        :param prefix: text to condition on; an empty prefix is replaced by a newline
        :param number_of_characters: maximum number of characters to sample
        :param temperature: logits are divided by this before sampling (higher = flatter distribution)
        :param break_on_suffix: if given, stop once the generated characters end with it
        :return: (prefix + generated text, reversed for backward LMs; average log-probability
            the model assigned to each sampled character, 0.0 if nothing was sampled)
        """
        if not prefix:
            prefix = "\n"

        with torch.no_grad():
            characters: List[str] = []
            idx2item = self.dictionary.idx2item

            # initial hidden state
            hidden = self.init_hidden(1)

            # warm up the hidden state on all prefix characters but the last
            if len(prefix) > 1:
                prefix_indices = [self.dictionary.get_idx_for_item(c) for c in prefix[:-1]]
                prefix_tensor = torch.tensor(prefix_indices).unsqueeze(1).to(flair.device)
                _, _, hidden = self.forward(prefix_tensor, hidden)

            # the last prefix character is the first input of the sampling loop, shape (1, 1)
            input_tensor = torch.tensor([[self.dictionary.get_idx_for_item(prefix[-1])]])

            log_prob = 0.0

            for _ in range(number_of_characters):
                input_tensor = input_tensor.to(flair.device)

                # get predicted logits over the dictionary, shape (len(dictionary),)
                prediction, _, hidden = self.forward(input_tensor, hidden)
                logits = prediction.squeeze().detach()

                # divide by temperature, then subtract the largest value so the maximum
                # becomes 0 - prevents overflow in exp() for small temperatures
                scaled = logits.div(temperature)
                scaled = scaled - torch.max(scaled)
                word_weights = scaled.exp().cpu()

                # sample the next character; torch.multinomial raises RuntimeError on an
                # invalid distribution (e.g. NaN weights), in which case fall back to index 0
                try:
                    word_idx = torch.multinomial(word_weights, 1)[0]
                except RuntimeError:
                    word_idx = torch.tensor(0)

                sampled = int(word_idx)
                # model log-probability (untempered) of the sampled character
                log_prob += torch.nn.functional.log_softmax(logits, dim=-1)[sampled].item()

                input_tensor = word_idx.detach().unsqueeze(0).unsqueeze(0)
                characters.append(idx2item[sampled].decode("UTF-8"))

                if break_on_suffix is not None:
                    if "".join(characters).endswith(break_on_suffix):
                        break

            text = prefix + "".join(characters)

            # average per sampled character; nothing sampled -> 0.0 instead of dividing by zero
            if characters:
                log_prob /= len(characters)

            if not self.is_forward_lm:
                text = text[::-1]

            return text, log_prob

    def calculate_perplexity(self, text: str) -> float:
        """Return the model's perplexity on ``text`` (predicting each char from the ones before it).

        :raises ValueError: if ``text`` has fewer than two characters (nothing to predict)
        """
        if len(text) < 2:
            raise ValueError("calculate_perplexity needs a text of at least two characters")

        if not self.is_forward_lm:
            text = text[::-1]

        with torch.no_grad():
            # input ids, shape (len(text) - 1, 1)
            input_ids = torch.tensor(
                [self.dictionary.get_idx_for_item(char) for char in text[:-1]]
            ).unsqueeze(1).to(flair.device)

            # push list of character IDs through model
            hidden = self.init_hidden(1)
            prediction, _, _ = self.forward(input_ids, hidden)

            # the target is always the next character
            targets = torch.tensor(
                [self.dictionary.get_idx_for_item(char) for char in text[1:]]
            ).to(flair.device)

            # use cross entropy loss to compare output of forward pass with targets
            cross_entropy_loss = nn.CrossEntropyLoss()
            loss = cross_entropy_loss(
                prediction.view(-1, len(self.dictionary)), targets
            ).item()

        # exponentiate cross-entropy loss to calculate perplexity
        return math.exp(loss)

    def __getstate__(self):
        """Serialize only the weights and the constructor arguments."""
        return self._model_state()

    def __setstate__(self, d):
        """Rebuild from a ``__getstate__`` dict; otherwise defer to nn.Module."""
        # special handling for deserializing language models
        if "state_dict" in d:
            # re-initialize language model with constructor arguments and copy its state to self
            language_model = LanguageModel._init_from_state(d)
            self.__dict__.update(language_model.__dict__)

            # set the language model to eval() by default (this is necessary since FlairEmbeddings "protect" the LM
            # in their "self.train()" method)
            self.eval()
        else:
            super().__setstate__(d)

    def _apply(self, fn, *args, **kwargs):
        """Patch RNN modules pickled with old torch, then apply ``fn`` via nn.Module and return self.

        Returning ``super()._apply(...)`` (i.e. self) keeps ``.to()``, ``.cuda()`` and
        ``.float()`` chainable; extra arguments (e.g. ``recurse``) are forwarded.
        """
        # models that were serialized using torch versions older than 1.4.0 lack the _flat_weights_names attribute
        # check if this is the case and if so, set it
        for child_module in self.children():
            if isinstance(child_module, nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"):
                num_directions = 2 if child_module.__dict__["bidirectional"] else 1
                flat_weights_names = []
                for layer in range(child_module.__dict__["num_layers"]):
                    for direction in range(num_directions):
                        suffix = "_reverse" if direction == 1 else ""
                        param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
                        if child_module.__dict__["bias"]:
                            param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
                        flat_weights_names.extend(name.format(layer, suffix) for name in param_names)

                setattr(child_module, "_flat_weights_names", flat_weights_names)

        # nn.Module._apply recurses into every child (and this module's own parameters/buffers)
        return super()._apply(fn, *args, **kwargs)