Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/trainers/language_model_trainer.py: 16%
import time
import datetime
import random
import sys
from pathlib import Path
from typing import Union

from torch import cuda
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.sgd import SGD

try:
    from apex import amp
except ImportError:
    amp = None

import flair
from flair.data import Dictionary
from flair.models import LanguageModel
# NOTE: names such as logging, math, torch, Optimizer, SGDW, ReduceLROnPlateau and
# ReduceLRWDOnPlateau used further down are expected to be supplied by this wildcard import.
from flair.optim import *
from flair.training_utils import add_file_handler

log = logging.getLogger("flair")


class TextDataset(Dataset):
    def __init__(
        self,
        path: Union[str, Path],
        dictionary: Dictionary,
        expand_vocab: bool = False,
        forward: bool = True,
        split_on_char: bool = True,
        random_case_flip: bool = True,
        document_delimiter: str = '\n',
        shuffle: bool = True,
    ):
        if type(path) is str:
            path = Path(path)
        assert path.exists()

        self.files = None
        self.path = path
        self.dictionary = dictionary
        self.split_on_char = split_on_char
        self.forward = forward
        self.random_case_flip = random_case_flip
        self.expand_vocab = expand_vocab
        self.document_delimiter = document_delimiter
        self.shuffle = shuffle

        if path.is_dir():
            self.files = sorted([f for f in path.iterdir() if f.exists()])
        else:
            self.files = [path]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index=0) -> torch.Tensor:
        """Tokenizes a text file on character basis."""
        if type(self.files[index]) is str:
            self.files[index] = Path(self.files[index])
        assert self.files[index].exists()

        with self.files[index].open("r", encoding="utf-8") as fin:
            lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc)
            if self.random_case_flip:
                lines = map(self.random_casechange, lines)
            lines = list(map(list if self.split_on_char else str.split, lines))

        log.info(f"read text file with {len(lines)} lines")

        if self.shuffle:
            random.shuffle(lines)
            log.info("shuffled")

        if self.expand_vocab:
            for chars in lines:
                for char in chars:
                    self.dictionary.add_item(char)

        ids = torch.tensor(
            [self.dictionary.get_idx_for_item(char) for chars in lines for char in chars],
            dtype=torch.long
        )
        if not self.forward:
            ids = ids.flip(0)
        return ids

    @staticmethod
    def random_casechange(line: str) -> str:
        # with 1% probability lowercase the whole line, with 1% probability uppercase it
        no = random.randint(0, 99)
        if no == 0:
            line = line.lower()
        if no == 1:
            line = line.upper()
        return line
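
# --- usage sketch (added; not part of the original module) -------------------
# A minimal illustration of what TextDataset produces, assuming a hypothetical
# single file "corpus/train/split_1.txt" and flair's built-in character dictionary:
#
#     char_dictionary = Dictionary.load("chars")
#     dataset = TextDataset("corpus/train/split_1.txt", char_dictionary, forward=True)
#     ids = dataset[0]   # 1-D torch.long tensor, one id per character of the file
#
# For a backward language model (forward=False), the same tensor is returned
# reversed via ids.flip(0).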


class TextCorpus(object):
    def __init__(
        self,
        path: Union[Path, str],
        dictionary: Dictionary,
        forward: bool = True,
        character_level: bool = True,
        random_case_flip: bool = True,
        document_delimiter: str = '\n',
    ):
        self.dictionary: Dictionary = dictionary
        self.forward = forward
        self.split_on_char = character_level
        self.random_case_flip = random_case_flip
        self.document_delimiter: str = document_delimiter

        if type(path) == str:
            path = Path(path)

        self.train = TextDataset(
            path / "train",
            dictionary,
            False,
            self.forward,
            self.split_on_char,
            self.random_case_flip,
            document_delimiter=self.document_delimiter,
            shuffle=True,
        )

        # valid and test are single files, so index item 0 to get the id tensor directly
        self.valid = TextDataset(
            path / "valid.txt",
            dictionary,
            False,
            self.forward,
            self.split_on_char,
            self.random_case_flip,
            document_delimiter=document_delimiter,
            shuffle=False,
        )[0]
        self.test = TextDataset(
            path / "test.txt",
            dictionary,
            False,
            self.forward,
            self.split_on_char,
            self.random_case_flip,
            document_delimiter=document_delimiter,
            shuffle=False,
        )[0]
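
    # Expected corpus layout (inferred from the paths above; added comment):
    #
    #     <corpus_root>/
    #         train/        one or more plain-text split files
    #         valid.txt     single validation file
    #         test.txt      single test file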


class LanguageModelTrainer:
    def __init__(
        self,
        model: LanguageModel,
        corpus: TextCorpus,
        optimizer: Optimizer = SGD,
        test_mode: bool = False,
        epoch: int = 0,
        split: int = 0,
        loss: float = 10000,
        optimizer_state: dict = None,
    ):
        self.model: LanguageModel = model
        self.optimizer: Optimizer = optimizer
        self.corpus: TextCorpus = corpus
        self.test_mode: bool = test_mode

        self.loss_function = torch.nn.CrossEntropyLoss()
        self.log_interval = 100
        self.epoch = epoch
        self.split = split
        self.loss = loss
        self.optimizer_state = optimizer_state

    def train(
        self,
        base_path: Union[Path, str],
        sequence_length: int,
        learning_rate: float = 20,
        mini_batch_size: int = 100,
        anneal_factor: float = 0.25,
        patience: int = 10,
        clip=0.25,
        max_epochs: int = 1000,
        checkpoint: bool = False,
        grow_to_sequence_length: int = 0,
        num_workers: int = 2,
        use_amp: bool = False,
        amp_opt_level: str = "O1",
        **kwargs,
    ):

        if use_amp:
            if sys.version_info < (3, 0):
                raise RuntimeError("Apex currently only supports Python 3. Aborting.")
            if amp is None:
                raise RuntimeError(
                    "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                    "to enable mixed-precision training."
                )

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)

        add_file_handler(log, base_path / "training.log")

        number_of_splits: int = len(self.corpus.train)

        val_data = self._batchify(self.corpus.valid, mini_batch_size)

        # error message if the validation dataset is too small
        if val_data.size(0) == 1:
            raise RuntimeError(
                f"ERROR: Your validation dataset is too small. For your mini_batch_size, the data needs to "
                f"consist of at least {mini_batch_size * 2} characters!"
            )

        base_path.mkdir(parents=True, exist_ok=True)
        loss_txt = base_path / "loss.txt"
        savefile = base_path / "best-lm.pt"

        try:
            best_val_loss = self.loss
            optimizer = self.optimizer(
                self.model.parameters(), lr=learning_rate, **kwargs
            )
            if self.optimizer_state is not None:
                optimizer.load_state_dict(self.optimizer_state)

            if isinstance(optimizer, (AdamW, SGDW)):
                scheduler: ReduceLRWDOnPlateau = ReduceLRWDOnPlateau(
                    optimizer, verbose=True, factor=anneal_factor, patience=patience
                )
            else:
                scheduler: ReduceLROnPlateau = ReduceLROnPlateau(
                    optimizer, verbose=True, factor=anneal_factor, patience=patience
                )

            if use_amp:
                self.model, optimizer = amp.initialize(
                    self.model, optimizer, opt_level=amp_opt_level
                )

            training_generator = DataLoader(
                self.corpus.train, shuffle=False, num_workers=num_workers
            )

            for epoch in range(self.epoch, max_epochs):
                epoch_start_time = time.time()
                # shuffle the training files after serially iterating through the corpus once
                if epoch > 0:
                    training_generator = DataLoader(
                        self.corpus.train, shuffle=True, num_workers=num_workers
                    )
                    self.model.save_checkpoint(
                        base_path / f"epoch_{epoch}.pt",
                        optimizer,
                        epoch,
                        0,
                        best_val_loss,
                    )

                # iterate through training data, starting at self.split (for checkpointing)
                for curr_split, train_slice in enumerate(
                    training_generator, self.split
                ):

                    if sequence_length < grow_to_sequence_length:
                        sequence_length += 1
                    log.info(f"Sequence length is {sequence_length}")

                    split_start_time = time.time()
                    # off by one for printing
                    curr_split += 1
                    train_data = self._batchify(train_slice.flatten(), mini_batch_size)

                    log.info(
                        "Split %d" % curr_split
                        + "\t - ({:%H:%M:%S})".format(datetime.datetime.now())
                    )

                    for group in optimizer.param_groups:
                        learning_rate = group["lr"]

                    # go into train mode
                    self.model.train()

                    # reset variables
                    hidden = self.model.init_hidden(mini_batch_size)

                    # vocabulary size (number of items in the dictionary)
                    ntokens = len(self.corpus.dictionary)

                    total_loss = 0
                    start_time = time.time()

                    for batch, i in enumerate(
                        range(0, train_data.size(0) - 1, sequence_length)
                    ):
                        data, targets = self._get_batch(train_data, i, sequence_length)

                        if not data.is_cuda and cuda.is_available():
                            log.info(
                                "Batch %d is not on CUDA, training will be very slow"
                                % (batch)
                            )
                            raise Exception("data isn't on cuda")

                        self.model.zero_grad()
                        optimizer.zero_grad()

                        # do the forward pass in the model
                        output, rnn_output, hidden = self.model.forward(data, hidden)

                        # try to predict the targets
                        loss = self.loss_function(output.view(-1, ntokens), targets)
                        # Backward
                        if use_amp:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

                        optimizer.step()

                        total_loss += loss.data

                        # We detach the hidden state from how it was previously produced.
                        # If we didn't, the model would try backpropagating all the way to start of the dataset.
                        hidden = self._repackage_hidden(hidden)

                        # explicitly remove loss to clear up memory
                        del loss, output, rnn_output

                        if batch % self.log_interval == 0 and batch > 0:
                            cur_loss = total_loss.item() / self.log_interval
                            elapsed = time.time() - start_time
                            log.info(
                                "| split {:3d} /{:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | "
                                "loss {:5.2f} | ppl {:8.2f}".format(
                                    curr_split,
                                    number_of_splits,
                                    batch,
                                    len(train_data) // sequence_length,
                                    elapsed * 1000 / self.log_interval,
                                    cur_loss,
                                    math.exp(cur_loss),
                                )
                            )
                            total_loss = 0
                            start_time = time.time()

                    log.info(
                        "%d seconds for train split %d"
                        % (time.time() - split_start_time, curr_split)
                    )

                    ###############################################################################
                    self.model.eval()

                    val_loss = self.evaluate(val_data, mini_batch_size, sequence_length)
                    scheduler.step(val_loss)

                    log.info("best loss so far {:5.2f}".format(best_val_loss))

                    log.info(self.model.generate_text())

                    if checkpoint:
                        self.model.save_checkpoint(
                            base_path / "checkpoint.pt",
                            optimizer,
                            epoch,
                            curr_split,
                            best_val_loss,
                        )

                    # Save the model if the validation loss is the best we've seen so far.
                    if val_loss < best_val_loss:
                        self.model.best_score = best_val_loss
                        self.model.save(savefile)
                        best_val_loss = val_loss

                    ###############################################################################
                    # print info
                    ###############################################################################
                    log.info("-" * 89)

                    summary = (
                        "| end of split {:3d} /{:3d} | epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
                        "valid ppl {:8.2f} | learning rate {:3.4f}".format(
                            curr_split,
                            number_of_splits,
                            epoch + 1,
                            (time.time() - split_start_time),
                            val_loss,
                            math.exp(val_loss),
                            learning_rate,
                        )
                    )

                    with open(loss_txt, "a") as myfile:
                        myfile.write("%s\n" % summary)

                    log.info(summary)
                    log.info("-" * 89)

                log.info("Epoch time: %.2f" % (time.time() - epoch_start_time))

        except KeyboardInterrupt:
            log.info("-" * 89)
            log.info("Exiting from training early")

        ###############################################################################
        # final testing
        ###############################################################################
        test_data = self._batchify(self.corpus.test, mini_batch_size)
        test_loss = self.evaluate(test_data, mini_batch_size, sequence_length)

        summary = "TEST: valid loss {:5.2f} | valid ppl {:8.2f}".format(
            test_loss, math.exp(test_loss)
        )
        with open(loss_txt, "a") as myfile:
            myfile.write("%s\n" % summary)

        log.info(summary)
        log.info("-" * 89)
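
    # Added note: each epoch above iterates over the training splits; after every
    # split the model is validated, the scheduler anneals the learning rate on
    # plateaus, the best model so far is written to best-lm.pt, and a final
    # test-set evaluation is appended to loss.txt.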

    def evaluate(self, data_source, eval_batch_size, sequence_length):
        # Turn on evaluation mode which disables dropout.
        self.model.eval()

        with torch.no_grad():
            total_loss = 0
            ntokens = len(self.corpus.dictionary)

            hidden = self.model.init_hidden(eval_batch_size)

            for i in range(0, data_source.size(0) - 1, sequence_length):
                data, targets = self._get_batch(data_source, i, sequence_length)
                prediction, rnn_output, hidden = self.model.forward(data, hidden)
                output_flat = prediction.view(-1, ntokens)
                total_loss += len(data) * self.loss_function(output_flat, targets).data
                hidden = self._repackage_hidden(hidden)
            return total_loss.item() / len(data_source)

    @staticmethod
    def _batchify(data, batch_size):
        # Work out how cleanly we can divide the dataset into batch_size parts.
        nbatch = data.size(0) // batch_size
        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, nbatch * batch_size)
        # Evenly divide the data across the batch_size columns.
        data = data.view(batch_size, -1).t().contiguous()
        return data
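
    # Example (added): with 26 character ids and batch_size=4, _batchify trims the
    # data to 24 ids and returns a 6 x 4 tensor; column j then holds a contiguous
    # stream of 6 characters, so the columns form independent training streams.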

    @staticmethod
    def _get_batch(source, i, sequence_length):
        seq_len = min(sequence_length, len(source) - 1 - i)

        data = source[i : i + seq_len].clone().detach()
        target = source[i + 1 : i + 1 + seq_len].view(-1).clone().detach()

        data = data.to(flair.device)
        target = target.to(flair.device)

        return data, target
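
    # Example (added): for the batchified tensor above and sequence_length=3,
    # _get_batch(source, 0, 3) returns rows 0-2 as the input and rows 1-3
    # (flattened) as the targets, i.e. each position is trained to predict the
    # next character in its column.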

    @staticmethod
    def _repackage_hidden(h):
        """Wraps hidden states in new tensors, to detach them from their history."""
        return tuple(v.clone().detach() for v in h)

    @staticmethod
    def load_checkpoint(
        checkpoint_file: Union[str, Path], corpus: TextCorpus, optimizer: Optimizer = SGD
    ):
        if type(checkpoint_file) is str:
            checkpoint_file = Path(checkpoint_file)

        checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
        return LanguageModelTrainer(
            checkpoint["model"],
            corpus,
            optimizer,
            epoch=checkpoint["epoch"],
            split=checkpoint["split"],
            loss=checkpoint["loss"],
            optimizer_state=checkpoint["optimizer_state_dict"],
        )
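

# --- usage sketch (added; not part of the original module) -------------------
# A minimal end-to-end example of training a character-level language model with
# the classes above. The corpus path and model sizes are hypothetical
# placeholders; the corpus folder is assumed to contain train/, valid.txt and
# test.txt as described in TextCorpus.
if __name__ == "__main__":
    char_dictionary = Dictionary.load("chars")  # flair's default character dictionary

    corpus = TextCorpus(
        "resources/corpus",      # hypothetical path
        char_dictionary,
        forward=True,            # use False (and is_forward_lm=False) for a backward LM
        character_level=True,
    )

    language_model = LanguageModel(
        char_dictionary, is_forward_lm=True, hidden_size=128, nlayers=1
    )

    trainer = LanguageModelTrainer(language_model, corpus)
    trainer.train(
        "resources/language_model",  # hypothetical output directory
        sequence_length=250,
        mini_batch_size=100,
        max_epochs=10,
    )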