Coverage for flair/flair/nn/dropout.py: 73%
import torch


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Samples one dropout mask
    per sequence and reuses it at every time step, so the same embedding
    dimensions are dropped throughout the sequence.
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        # `inplace` is stored and reported in extra_repr, but forward always
        # returns a new tensor
        self.inplace = inplace

    def forward(self, x):
        # Dropout is a no-op at inference time or when the rate is zero
        if not self.training or not self.dropout_rate:
            return x

        # Sample one Bernoulli keep-mask per sequence: the singleton axis is
        # the sequence dimension, so the mask broadcasts over all time steps
        if not self.batch_first:
            m = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout_rate)
        else:
            m = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate)

        # Inverted dropout: rescale kept activations by 1 / (1 - p)
        mask = (m / (1 - self.dropout_rate)).expand_as(x)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
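
# A minimal, hypothetical usage sketch (the `_demo_locked_dropout` helper is
# not part of flair): it shows that LockedDropout zeroes the same embedding
# dimensions at every time step of a sequence.
def _demo_locked_dropout():
    torch.manual_seed(0)
    dropout = LockedDropout(dropout_rate=0.5)
    dropout.train()  # masks are only sampled in training mode
    x = torch.ones(2, 4, 8)  # (batch, seq_len, embedding_dim), batch_first
    y = dropout(x)
    # Within each sequence, every time step has zeros in the same embedding
    # dimensions; kept dimensions are rescaled to 1 / (1 - 0.5) = 2.0
    print(y[0])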


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or
    characters), zeroing the full embedding vector at each dropped position.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        # As in LockedDropout, `inplace` is only used for extra_repr
        self.inplace = inplace

    def forward(self, x):
        # Dropout is a no-op at inference time or when the rate is zero
        if not self.training or not self.dropout_rate:
            return x

        # One Bernoulli draw per position: the singleton embedding axis
        # broadcasts the keep/drop decision across the whole word vector
        m = x.new_empty(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)

        # Unlike LockedDropout, surviving embeddings are not rescaled
        return m * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
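
# A minimal, hypothetical usage sketch (the `_demo_word_dropout` helper is not
# part of flair): WordDropout zeroes entire word vectors, so dropped positions
# show up as full rows of zeros, and the kept rows are left unscaled.
def _demo_word_dropout():
    torch.manual_seed(0)
    dropout = WordDropout(dropout_rate=0.5)  # unusually high rate, for visibility
    dropout.train()
    x = torch.ones(2, 4, 8)  # (batch, seq_len, embedding_dim)
    y = dropout(x)
    # Each row of y[0] is either all zeros (dropped word) or all ones (kept)
    print(y[0])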