Coverage for flair/flair/trainers/language_model_trainer.py: 17%


227 statements  

import datetime
import logging
import math
import random
import sys
import time
from pathlib import Path
from typing import Union

import torch
from torch import cuda
from torch.optim import AdamW
from torch.optim.sgd import SGD
from torch.utils.data import Dataset, DataLoader

try:
    from apex import amp
except ImportError:
    amp = None

import flair
from flair.data import Dictionary
from flair.models import LanguageModel

# the wildcard import below also supplies the Optimizer, SGDW and
# ReduceLR(WD)OnPlateau names used further down in this module
from flair.optim import *
from flair.training_utils import add_file_handler

log = logging.getLogger("flair")

class TextDataset(Dataset):
    def __init__(
        self,
        path: Union[str, Path],
        dictionary: Dictionary,
        expand_vocab: bool = False,
        forward: bool = True,
        split_on_char: bool = True,
        random_case_flip: bool = True,
        document_delimiter: str = "\n",
        shuffle: bool = True,
    ):
        if type(path) is str:
            path = Path(path)
        assert path.exists()

        self.files = None
        self.path = path
        self.dictionary = dictionary
        self.split_on_char = split_on_char
        self.forward = forward
        self.random_case_flip = random_case_flip
        self.expand_vocab = expand_vocab
        self.document_delimiter = document_delimiter
        self.shuffle = shuffle

        if path.is_dir():
            self.files = sorted([f for f in path.iterdir() if f.exists()])
        else:
            self.files = [path]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index=0) -> torch.Tensor:
        """Tokenizes a text file on a character basis."""
        if type(self.files[index]) is str:
            self.files[index] = Path(self.files[index])
        assert self.files[index].exists()

        with self.files[index].open("r", encoding="utf-8") as fin:
            lines = (doc + self.document_delimiter for doc in fin.read().split(self.document_delimiter) if doc)
            if self.random_case_flip:
                lines = map(self.random_casechange, lines)
            lines = list(map(list if self.split_on_char else str.split, lines))

        log.info(f"read text file with {len(lines)} lines")

        if self.shuffle:
            random.shuffle(lines)
            log.info("shuffled")

        if self.expand_vocab:
            for chars in lines:
                for char in chars:
                    self.dictionary.add_item(char)

        ids = torch.tensor(
            [self.dictionary.get_idx_for_item(char) for chars in lines for char in chars],
            dtype=torch.long,
        )
        if not self.forward:
            ids = ids.flip(0)
        return ids

    @staticmethod
    def random_casechange(line: str) -> str:
        no = random.randint(0, 99)
        if no == 0:
            line = line.lower()
        if no == 1:
            line = line.upper()
        return line
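# Each TextDataset item corresponds to one text file: __getitem__ returns the file's
# characters (or whitespace tokens, if split_on_char=False) mapped to dictionary
# indices as a flat 1-D torch.long tensor, reversed when forward=False so that the
# same data can feed a backward language model.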

class TextCorpus(object):
    def __init__(
        self,
        path: Union[Path, str],
        dictionary: Dictionary,
        forward: bool = True,
        character_level: bool = True,
        random_case_flip: bool = True,
        document_delimiter: str = "\n",
    ):
        self.dictionary: Dictionary = dictionary
        self.forward = forward
        self.split_on_char = character_level
        self.random_case_flip = random_case_flip
        self.document_delimiter: str = document_delimiter

        if type(path) == str:
            path = Path(path)

        self.train = TextDataset(
            path / "train",
            dictionary,
            expand_vocab=False,
            forward=self.forward,
            split_on_char=self.split_on_char,
            random_case_flip=self.random_case_flip,
            document_delimiter=self.document_delimiter,
            shuffle=True,
        )

        # TextDataset serves one tensor per file; valid and test are single files,
        # so index [0] retrieves their only element
        self.valid = TextDataset(
            path / "valid.txt",
            dictionary,
            expand_vocab=False,
            forward=self.forward,
            split_on_char=self.split_on_char,
            random_case_flip=self.random_case_flip,
            document_delimiter=document_delimiter,
            shuffle=False,
        )[0]
        self.test = TextDataset(
            path / "test.txt",
            dictionary,
            expand_vocab=False,
            forward=self.forward,
            split_on_char=self.split_on_char,
            random_case_flip=self.random_case_flip,
            document_delimiter=document_delimiter,
            shuffle=False,
        )[0]
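# Expected corpus layout for TextCorpus (as read above): the given path must contain
# a "train" directory holding one or more plain-text split files, plus "valid.txt"
# and "test.txt" at the top level, e.g. (file names inside train/ are arbitrary):
#
#   my_corpus/
#       train/
#           split_1.txt
#           split_2.txt
#       valid.txt
#       test.txt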

class LanguageModelTrainer:
    def __init__(
        self,
        model: LanguageModel,
        corpus: TextCorpus,
        optimizer: Optimizer = SGD,
        test_mode: bool = False,
        epoch: int = 0,
        split: int = 0,
        loss: float = 10000,
        optimizer_state: dict = None,
    ):
        self.model: LanguageModel = model
        self.optimizer: Optimizer = optimizer
        self.corpus: TextCorpus = corpus
        self.test_mode: bool = test_mode

        self.loss_function = torch.nn.CrossEntropyLoss()
        self.log_interval = 100
        self.epoch = epoch
        self.split = split
        self.loss = loss
        self.optimizer_state = optimizer_state

    def train(
        self,
        base_path: Union[Path, str],
        sequence_length: int,
        learning_rate: float = 20,
        mini_batch_size: int = 100,
        anneal_factor: float = 0.25,
        patience: int = 10,
        clip=0.25,
        max_epochs: int = 1000,
        checkpoint: bool = False,
        grow_to_sequence_length: int = 0,
        num_workers: int = 2,
        use_amp: bool = False,
        amp_opt_level: str = "O1",
        **kwargs,
    ):

        if use_amp:
            if sys.version_info < (3, 0):
                raise RuntimeError("Apex currently only supports Python 3. Aborting.")
            if amp is None:
                raise RuntimeError(
                    "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                    "to enable mixed-precision training."
                )

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)

        add_file_handler(log, base_path / "training.log")

        number_of_splits: int = len(self.corpus.train)

        val_data = self._batchify(self.corpus.valid, mini_batch_size)

        # error message if the validation dataset is too small
        if val_data.size(0) == 1:
            raise RuntimeError(
                f"ERROR: Your validation dataset is too small. For your mini_batch_size, the data needs to "
                f"consist of at least {mini_batch_size * 2} characters!"
            )

        base_path.mkdir(parents=True, exist_ok=True)
        loss_txt = base_path / "loss.txt"
        savefile = base_path / "best-lm.pt"

        try:
            best_val_loss = self.loss
            optimizer = self.optimizer(
                self.model.parameters(), lr=learning_rate, **kwargs
            )
            if self.optimizer_state is not None:
                optimizer.load_state_dict(self.optimizer_state)

            if isinstance(optimizer, (AdamW, SGDW)):
                scheduler: ReduceLRWDOnPlateau = ReduceLRWDOnPlateau(
                    optimizer, verbose=True, factor=anneal_factor, patience=patience
                )
            else:
                scheduler: ReduceLROnPlateau = ReduceLROnPlateau(
                    optimizer, verbose=True, factor=anneal_factor, patience=patience
                )

            if use_amp:
                self.model, optimizer = amp.initialize(
                    self.model, optimizer, opt_level=amp_opt_level
                )

            training_generator = DataLoader(
                self.corpus.train, shuffle=False, num_workers=num_workers
            )

            for epoch in range(self.epoch, max_epochs):
                epoch_start_time = time.time()
                # shuffle the training files from the second epoch onwards
                # (the first epoch iterates through the corpus serially)
                if epoch > 0:
                    training_generator = DataLoader(
                        self.corpus.train, shuffle=True, num_workers=num_workers
                    )
                    self.model.save_checkpoint(
                        base_path / f"epoch_{epoch}.pt",
                        optimizer,
                        epoch,
                        0,
                        best_val_loss,
                    )

                # iterate through training data, starting at self.split (for checkpointing)
                for curr_split, train_slice in enumerate(
                    training_generator, self.split
                ):

                    if sequence_length < grow_to_sequence_length:
                        sequence_length += 1
                        log.info(f"Sequence length is {sequence_length}")

                    split_start_time = time.time()
                    # off by one for printing
                    curr_split += 1
                    train_data = self._batchify(train_slice.flatten(), mini_batch_size)

                    log.info(
                        "Split %d" % curr_split
                        + "\t - ({:%H:%M:%S})".format(datetime.datetime.now())
                    )

                    for group in optimizer.param_groups:
                        learning_rate = group["lr"]

                    # go into train mode
                    self.model.train()

                    # reset variables
                    hidden = self.model.init_hidden(mini_batch_size)

                    # vocabulary size, used to reshape the decoder output for the loss
                    ntokens = len(self.corpus.dictionary)

                    total_loss = 0
                    start_time = time.time()

                    for batch, i in enumerate(
                        range(0, train_data.size(0) - 1, sequence_length)
                    ):
                        data, targets = self._get_batch(train_data, i, sequence_length)

                        if not data.is_cuda and cuda.is_available():
                            log.info(
                                "Batch %d is not on CUDA, training will be very slow" % batch
                            )
                            raise Exception("data isn't on CUDA")

                        self.model.zero_grad()
                        optimizer.zero_grad()

                        # do the forward pass in the model
                        output, rnn_output, hidden = self.model.forward(data, hidden)

                        # try to predict the targets
                        loss = self.loss_function(output.view(-1, ntokens), targets)
                        # Backward
                        if use_amp:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

                        optimizer.step()

                        total_loss += loss.data

                        # We detach the hidden state from how it was previously produced.
                        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
                        hidden = self._repackage_hidden(hidden)

                        # explicitly remove loss to clear up memory
                        del loss, output, rnn_output

                        if batch % self.log_interval == 0 and batch > 0:
                            cur_loss = total_loss.item() / self.log_interval
                            elapsed = time.time() - start_time
                            log.info(
                                "| split {:3d} /{:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | "
                                "loss {:5.2f} | ppl {:8.2f}".format(
                                    curr_split,
                                    number_of_splits,
                                    batch,
                                    len(train_data) // sequence_length,
                                    elapsed * 1000 / self.log_interval,
                                    cur_loss,
                                    math.exp(cur_loss),
                                )
                            )
                            total_loss = 0
                            start_time = time.time()

                    log.info(
                        "%d seconds for train split %d"
                        % (time.time() - split_start_time, curr_split)
                    )

                    ###############################################################################
                    self.model.eval()

                    val_loss = self.evaluate(val_data, mini_batch_size, sequence_length)
                    scheduler.step(val_loss)

                    log.info("best loss so far {:5.2f}".format(best_val_loss))

                    log.info(self.model.generate_text())

                    if checkpoint:
                        self.model.save_checkpoint(
                            base_path / "checkpoint.pt",
                            optimizer,
                            epoch,
                            curr_split,
                            best_val_loss,
                        )

                    # Save the model if the validation loss is the best we've seen so far.
                    if val_loss < best_val_loss:
                        self.model.best_score = best_val_loss
                        self.model.save(savefile)
                        best_val_loss = val_loss

                    ###############################################################################
                    # print info
                    ###############################################################################
                    log.info("-" * 89)

                    summary = (
                        "| end of split {:3d} /{:3d} | epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | "
                        "valid ppl {:8.2f} | learning rate {:3.4f}".format(
                            curr_split,
                            number_of_splits,
                            epoch + 1,
                            (time.time() - split_start_time),
                            val_loss,
                            math.exp(val_loss),
                            learning_rate,
                        )
                    )

                    with open(loss_txt, "a") as myfile:
                        myfile.write("%s\n" % summary)

                    log.info(summary)
                    log.info("-" * 89)

                log.info("Epoch time: %.2f" % (time.time() - epoch_start_time))

        except KeyboardInterrupt:
            log.info("-" * 89)
            log.info("Exiting from training early")

        ###############################################################################
        # final testing
        ###############################################################################
        test_data = self._batchify(self.corpus.test, mini_batch_size)
        test_loss = self.evaluate(test_data, mini_batch_size, sequence_length)

        summary = "TEST: test loss {:5.2f} | test ppl {:8.2f}".format(
            test_loss, math.exp(test_loss)
        )
        with open(loss_txt, "a") as myfile:
            myfile.write("%s\n" % summary)

        log.info(summary)
        log.info("-" * 89)

    def evaluate(self, data_source, eval_batch_size, sequence_length):
        # Turn on evaluation mode which disables dropout.
        self.model.eval()

        with torch.no_grad():
            total_loss = 0
            ntokens = len(self.corpus.dictionary)

            hidden = self.model.init_hidden(eval_batch_size)

            for i in range(0, data_source.size(0) - 1, sequence_length):
                data, targets = self._get_batch(data_source, i, sequence_length)
                prediction, rnn_output, hidden = self.model.forward(data, hidden)
                output_flat = prediction.view(-1, ntokens)
                total_loss += len(data) * self.loss_function(output_flat, targets).data
                hidden = self._repackage_hidden(hidden)
            return total_loss.item() / len(data_source)

    @staticmethod
    def _batchify(data, batch_size):
        # Work out how cleanly we can divide the dataset into batch_size parts.
        nbatch = data.size(0) // batch_size
        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, nbatch * batch_size)
        # Evenly divide the data across batch_size columns: result has shape (nbatch, batch_size).
        data = data.view(batch_size, -1).t().contiguous()
        return data

    @staticmethod
    def _get_batch(source, i, sequence_length):
        # slice out one window of at most sequence_length steps; the targets are the
        # inputs shifted by one position
        seq_len = min(sequence_length, len(source) - 1 - i)

        data = source[i : i + seq_len].clone().detach()
        target = source[i + 1 : i + 1 + seq_len].view(-1).clone().detach()

        data = data.to(flair.device)
        target = target.to(flair.device)

        return data, target

    @staticmethod
    def _repackage_hidden(h):
        """Wraps hidden states in new tensors, to detach them from their history."""
        return tuple(v.clone().detach() for v in h)

    @staticmethod
    def load_checkpoint(
        checkpoint_file: Union[str, Path], corpus: TextCorpus, optimizer: Optimizer = SGD
    ):
        if type(checkpoint_file) is str:
            checkpoint_file = Path(checkpoint_file)

        checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
        return LanguageModelTrainer(
            checkpoint["model"],
            corpus,
            optimizer,
            epoch=checkpoint["epoch"],
            split=checkpoint["split"],
            loss=checkpoint["loss"],
            optimizer_state=checkpoint["optimizer_state_dict"],
        )
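

# ---------------------------------------------------------------------------
# Minimal usage sketch for the classes above: trains a small forward character
# language model. The corpus path, output path, hidden size and hyperparameters
# are illustrative placeholders, not values prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # load flair's default character dictionary
    char_dictionary: Dictionary = Dictionary.load("chars")

    # read the corpus (expects train/, valid.txt and test.txt under this path)
    text_corpus = TextCorpus(
        "resources/corpora/my_corpus", char_dictionary, forward=True, character_level=True
    )

    # instantiate a small language model
    language_model = LanguageModel(
        char_dictionary, is_forward_lm=True, hidden_size=128, nlayers=1
    )

    # train it, writing best-lm.pt, loss.txt and training.log to the output folder
    trainer = LanguageModelTrainer(language_model, text_corpus)
    trainer.train(
        "resources/language_models/my_lm",
        sequence_length=250,
        mini_batch_size=100,
        max_epochs=10,
    )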