Coverage for flair/flair/nn/dropout.py: 73%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

33 statements  

1import torch 

2 

3 

class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Randomly drops out entire parameters in embedding space.

    A single Bernoulli mask is sampled per forward pass and shared across one
    axis of the input (dim 1 when ``batch_first`` is True, dim 0 otherwise), so
    the same features are zeroed at every position along that axis. Kept values
    are rescaled by ``1 / (1 - dropout_rate)``.
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        """
        :param dropout_rate: probability of zeroing a feature; ``bernoulli_`` requires
            ``1 - dropout_rate`` to lie in [0, 1].
        :param batch_first: if True the mask is shared across dim 1, otherwise across dim 0.
        :param inplace: stored and shown by ``extra_repr`` only; ``forward`` never
            modifies its input in place.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace

    def forward(self, x):
        """
        Apply locked dropout to ``x``.

        Returns ``x`` unchanged in eval mode or when ``dropout_rate`` is 0.

        :param x: input tensor; presumably (batch, seq, features) when ``batch_first``
            else (seq, batch, features) — TODO confirm against callers. Any number of
            trailing dimensions is accepted.
        :return: tensor with the same shape as ``x``.
        """
        if not self.training or not self.dropout_rate:
            return x

        keep_prob = 1 - self.dropout_rate

        # Collapse the shared axis to size 1 so one mask row broadcasts over it.
        mask_shape = list(x.size())
        mask_shape[1 if self.batch_first else 0] = 1

        # new_empty matches x's dtype/device and does not track gradients; it replaces
        # the deprecated x.data.new(...) + torch.autograd.Variable pattern.
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)

        # Rescale the kept units. When dropout_rate == 1 the mask is already all zeros,
        # and dividing by keep_prob == 0 would turn it into NaN (0 / 0).
        if keep_prob > 0:
            mask = mask / keep_prob

        return mask * x

    def extra_repr(self):
        """Return the summary shown in the module's ``repr``."""
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)

31 

32 

class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or characters) in embedding space.

    One Bernoulli draw is made per position in the first two dimensions of the
    input and broadcast over every remaining dimension, so a dropped position is
    zeroed across its whole feature vector. Kept values are not rescaled.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        """
        :param dropout_rate: probability of zeroing a position; ``bernoulli_`` requires
            ``1 - dropout_rate`` to lie in [0, 1].
        :param inplace: stored and shown by ``extra_repr`` only; ``forward`` never
            modifies its input in place.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        """
        Apply word dropout to ``x``.

        Returns ``x`` unchanged in eval mode or when ``dropout_rate`` is 0.

        :param x: tensor with at least 2 dimensions; presumably (batch, seq, features)
            — TODO confirm against callers. Further trailing dimensions are allowed.
        :return: tensor with the same shape as ``x``.
        """
        if not self.training or not self.dropout_rate:
            return x

        # One draw per (dim 0, dim 1) position. The trailing singleton dims make the
        # mask broadcast over all remaining axes, whatever the rank of x.
        mask_shape = (x.size(0), x.size(1)) + (1,) * (x.dim() - 2)

        # new_empty matches x's dtype/device and does not track gradients; it replaces
        # the deprecated x.data.new(...) + torch.autograd.Variable pattern.
        mask = x.new_empty(mask_shape).bernoulli_(1 - self.dropout_rate)
        return mask * x

    def extra_repr(self):
        """Return the summary shown in the module's ``repr``."""
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)