Coverage for flair/flair/embeddings/base.py: 57%

63 statements  

from abc import abstractmethod
from typing import Union, List, Dict
from torch.nn import ParameterList, Parameter

import torch
import logging

import flair
from flair.data import Sentence, Image

log = logging.getLogger("flair")


class Embeddings(torch.nn.Module):
    """Abstract base class for all embeddings. Every new type of embedding must implement these methods."""

    def __init__(self):
        """Set some attributes that would otherwise result in errors. Overwrite these in your embedding class."""
        if not hasattr(self, "name"):
            self.name: str = "unnamed_embedding"
        if not hasattr(self, "static_embeddings"):
            # if the embeddings for a sentence are the same in each epoch, set this to True for improved efficiency
            self.static_embeddings = False
        super().__init__()

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        raise NotImplementedError

    @property
    @abstractmethod
    def embedding_type(self) -> str:
        raise NotImplementedError

    def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
        """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if
        embeddings are non-static."""

        # if only one sentence is passed, convert to a list of sentences
        if (type(sentences) is Sentence) or (type(sentences) is Image):
            sentences = [sentences]

        everything_embedded: bool = True

        if self.embedding_type == "word-level":
            for sentence in sentences:
                for token in sentence.tokens:
                    if self.name not in token._embeddings.keys():
                        everything_embedded = False
                        break
        else:
            for sentence in sentences:
                if self.name not in sentence._embeddings.keys():
                    everything_embedded = False
                    break

        if not everything_embedded or not self.static_embeddings:
            self._add_embeddings_internal(sentences)

        return sentences

    @abstractmethod
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Private method for adding embeddings to all words in a list of sentences."""
        pass

    def get_names(self) -> List[str]:
        """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
        this embedding. But in some cases, the embedding is made up of several embeddings (StackedEmbedding).
        Then, the list contains the names of all embeddings in the stack."""
        return [self.name]

    def get_named_embeddings_dict(self) -> Dict:
        return {self.name: self}
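

# Example (illustrative sketch, not part of flair): a minimal word-level embedding that shows
# how a concrete class could subclass Embeddings. The class name and the default dimensionality
# are made up for demonstration; it assumes Token.set_embedding(name, tensor) from flair.data,
# which is how embedding classes attach vectors to tokens.
class ExampleRandomEmbeddings(Embeddings):
    def __init__(self, embedding_length: int = 8):
        self.name: str = "example_random"
        # vectors do not change between epochs, so they can be treated as static
        self.static_embeddings = True
        self.__embedding_length = embedding_length
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    @property
    def embedding_type(self) -> str:
        return "word-level"

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        for sentence in sentences:
            for token in sentence.tokens:
                token.set_embedding(
                    self.name, torch.rand(self.__embedding_length, device=flair.device)
                )
        return sentences


# e.g. ExampleRandomEmbeddings().embed(Sentence("hello world")) would attach an 8-dimensional
# vector to each token; because static_embeddings is True, a second embed() call on the same
# sentence would skip re-embedding.
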

class ScalarMix(torch.nn.Module):
    """
    Computes a parameterised scalar mixture of N tensors.
    This method was proposed by Liu et al. (2019) in the paper:
    "Linguistic Knowledge and Transferability of Contextual Representations" (https://arxiv.org/abs/1903.08855)

    The implementation is copied and slightly modified from the allennlp repository and is licensed under Apache 2.0.
    It can be found under:
    https://github.com/allenai/allennlp/blob/master/allennlp/modules/scalar_mix.py.
    """

    def __init__(self, mixture_size: int, trainable: bool = False) -> None:
        """
        Inits scalar mix implementation.
        ``mixture = gamma * sum(s_k * tensor_k)`` where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters.
        :param mixture_size: size of mixtures (usually the number of layers)
        """
        super(ScalarMix, self).__init__()
        self.mixture_size = mixture_size

        initial_scalar_parameters = [0.0] * mixture_size

        self.scalar_parameters = ParameterList(
            [
                Parameter(
                    torch.tensor(
                        [initial_scalar_parameters[i]],
                        dtype=torch.float,
                        device=flair.device,
                    ),
                    requires_grad=trainable,
                )
                for i in range(mixture_size)
            ]
        )
        self.gamma = Parameter(
            torch.tensor(
                [1.0],
                dtype=torch.float,
                device=flair.device,
            ),
            requires_grad=trainable,
        )

123 def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor: 

124 """ 

125 Computes a weighted average of the ``tensors``. The input tensors an be any shape 

126 with at least two dimensions, but must all be the same shape. 

127 :param tensors: list of input tensors 

128 :return: computed weighted average of input tensors 

129 """ 

130 if len(tensors) != self.mixture_size: 

131 log.error( 

132 "{} tensors were passed, but the module was initialized to mix {} tensors.".format( 

133 len(tensors), self.mixture_size 

134 ) 

135 ) 

136 

137 normed_weights = torch.nn.functional.softmax( 

138 torch.cat([parameter for parameter in self.scalar_parameters]), dim=0 

139 ) 

140 normed_weights = torch.split(normed_weights, split_size_or_sections=1) 

141 

142 pieces = [] 

143 for weight, tensor in zip(normed_weights, tensors): 

144 pieces.append(weight * tensor) 

145 return self.gamma * sum(pieces)