Coverage for flair/flair/embeddings/base.py: 71%
from abc import abstractmethod
from typing import Union, List, Dict
from torch.nn import ParameterList, Parameter

import torch
import logging

import flair
from flair.data import Sentence, Image

log = logging.getLogger("flair")
class Embeddings(torch.nn.Module):
    """Abstract base class for all embeddings. Every new type of embedding must implement these methods."""

    def __init__(self):
        """Set some attributes that would otherwise result in errors. Overwrite these in your embedding class."""
        if not hasattr(self, "name"):
            self.name: str = "unnamed_embedding"
        if not hasattr(self, "static_embeddings"):
            # if the embeddings for a sentence are the same in each epoch, set this to True for improved efficiency
            self.static_embeddings = False
        super().__init__()

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        raise NotImplementedError

    @property
    @abstractmethod
    def embedding_type(self) -> str:
        raise NotImplementedError

    def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
        """Add embeddings to all words in a list of sentences. If embeddings are already added,
        they are updated only if they are non-static."""

        # if only one sentence (or image) is passed, convert it to a list of sentences
        if (type(sentences) is Sentence) or (type(sentences) is Image):
            sentences = [sentences]

        everything_embedded: bool = True

        if self.embedding_type == "word-level":
            for sentence in sentences:
                for token in sentence.tokens:
                    if self.name not in token._embeddings.keys():
                        everything_embedded = False
                        break
        else:
            for sentence in sentences:
                if self.name not in sentence._embeddings.keys():
                    everything_embedded = False
                    break

        if not everything_embedded or not self.static_embeddings:
            self._add_embeddings_internal(sentences)

        return sentences

    @abstractmethod
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Private method for adding embeddings to all words in a list of sentences."""
        pass

    def get_names(self) -> List[str]:
        """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
        this embedding. But in some cases, the embedding is made up of different embeddings (StackedEmbedding).
        Then, the list contains the names of all embeddings in the stack."""
        return [self.name]

    def get_named_embeddings_dict(self) -> Dict:
        return {self.name: self}
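

# Example sketch (not part of the original module): a hypothetical minimal subclass showing
# how the abstract Embeddings API above is meant to be filled in. The class name
# ConstantEmbeddings and its constant-vector behaviour are illustrative assumptions only;
# the sketch assumes Token.set_embedding() from flair.data, which stores a vector under
# self.name in token._embeddings, the attribute that embed() checks above.
class ConstantEmbeddings(Embeddings):
    """Toy embedding that attaches the same zero vector to every token."""

    def __init__(self, length: int = 10):
        # set name and static_embeddings before calling super().__init__()
        self.name: str = "constant"
        self.static_embeddings = True  # vectors never change, so embed() can skip re-embedding
        self.__length = length
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__length

    @property
    def embedding_type(self) -> str:
        return "word-level"

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        # attach a constant vector to every token; a real embedding would compute these
        for sentence in sentences:
            for token in sentence.tokens:
                token.set_embedding(self.name, torch.zeros(self.__length, device=flair.device))
        return sentences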
class ScalarMix(torch.nn.Module):
    """
    Computes a parameterised scalar mixture of N tensors.
    This method was proposed by Liu et al. (2019) in the paper:
    "Linguistic Knowledge and Transferability of Contextual Representations" (https://arxiv.org/abs/1903.08855)

    The implementation is copied and slightly modified from the allennlp repository and is licensed under Apache 2.0.
    It can be found under:
    https://github.com/allenai/allennlp/blob/master/allennlp/modules/scalar_mix.py.
    """

    def __init__(self, mixture_size: int, trainable: bool = False) -> None:
        """
        Inits scalar mix implementation.
        ``mixture = gamma * sum(s_k * tensor_k)`` where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters.
        :param mixture_size: size of mixtures (usually the number of layers)
        :param trainable: whether the scalar weights and gamma are trained (default: False)
        """
        super(ScalarMix, self).__init__()
        self.mixture_size = mixture_size

        initial_scalar_parameters = [0.0] * mixture_size

        self.scalar_parameters = ParameterList(
            [
                Parameter(
                    torch.tensor(
                        [initial_scalar_parameters[i]],
                        dtype=torch.float,
                        device=flair.device,
                    ),
                    requires_grad=trainable,
                )
                for i in range(mixture_size)
            ]
        )
        self.gamma = Parameter(
            torch.tensor(
                [1.0],
                dtype=torch.float,
                device=flair.device,
            ),
            requires_grad=trainable,
        )
    def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
        """
        Computes a weighted average of the ``tensors``. The input tensors can be any shape
        with at least two dimensions, but must all be the same shape.
        :param tensors: list of input tensors
        :return: computed weighted average of input tensors
        """
        if len(tensors) != self.mixture_size:
            log.error(
                "{} tensors were passed, but the module was initialized to mix {} tensors.".format(
                    len(tensors), self.mixture_size
                )
            )

        normed_weights = torch.nn.functional.softmax(
            torch.cat([parameter for parameter in self.scalar_parameters]), dim=0
        )
        normed_weights = torch.split(normed_weights, split_size_or_sections=1)

        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * tensor)
        return self.gamma * sum(pieces)
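

# Usage sketch (not part of the original module): mixing three same-shaped layer outputs with
# ScalarMix, e.g. hidden states from different layers of a language model. The layer count and
# tensor shapes below are made-up illustration values.
if __name__ == "__main__":
    layer_outputs = [torch.rand(2, 5, 8, device=flair.device) for _ in range(3)]  # (batch, tokens, dim) per layer
    scalar_mix = ScalarMix(mixture_size=3, trainable=False)
    mixed = scalar_mix(layer_outputs)  # softmax over the 3 scalar weights, then gamma-scaled weighted sum
    print(mixed.shape)  # torch.Size([2, 5, 8]) -- same shape as each input tensor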