# Coverage for flair/flair/models/pairwise_classification_model.py: 23%

from typing import Union, List

import torch

import flair.embeddings
import flair.nn
from flair.data import Label, DataPoint, Sentence, DataPair


class TextPairClassifier(flair.nn.DefaultClassifier):
    """
    Text Pair Classification Model for tasks such as Recognizing Textual Entailment, built upon TextClassifier.
    The model takes document embeddings and puts the resulting text representation(s) into a linear layer to get the
    actual class label. There are two ways to embed the DataPairs: either embed both DataPoints separately and
    concatenate the resulting vectors ("embed_separately=True"), or concatenate the DataPoints and embed the
    resulting text together ("embed_separately=False").
    """

    def __init__(
        self,
        document_embeddings: flair.embeddings.DocumentEmbeddings,
        label_type: str,
        embed_separately: bool = False,
        **classifierargs,
    ):
26 """
27 Initializes a TextClassifier
28 :param document_embeddings: embeddings used to embed each data point
29 :param label_dictionary: dictionary of labels you want to predict
30 :param multi_label: auto-detected by default, but you can set this to True to force multi-label prediction
31 or False to force single-label prediction
32 :param multi_label_threshold: If multi-label you can set the threshold to make predictions
33 :param loss_weights: Dictionary of weights for labels for the loss function
34 (if any label's weight is unspecified it will default to 1.0)
35 """
        super().__init__(**classifierargs)

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings

        self._label_type = label_type

        self.embed_separately = embed_separately

        # if embed_separately == True the linear layer needs twice the length of the embeddings as input size
        # since we concatenate the embeddings of the two DataPoints in the DataPairs
        if self.embed_separately:
            self.decoder = torch.nn.Linear(
                2 * self.document_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)

            torch.nn.init.xavier_uniform_(self.decoder.weight)

        else:
            # representation for both sentences
            self.decoder = torch.nn.Linear(
                self.document_embeddings.embedding_length, len(self.label_dictionary)
            )

            # set separator to concatenate the two sentences
            self.sep = " "
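            # for transformer embeddings, use the tokenizer's separator token (if it
            # defines one) to mark the boundary between the two texts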
            if isinstance(self.document_embeddings, flair.embeddings.document.TransformerDocumentEmbeddings):
                if self.document_embeddings.tokenizer.sep_token:
                    self.sep = " " + str(self.document_embeddings.tokenizer.sep_token) + " "
                else:
                    self.sep = " [SEP] "

            torch.nn.init.xavier_uniform_(self.decoder.weight)

        # auto-spawn on GPU if available
        self.to(flair.device)

    @property
    def label_type(self):
        return self._label_type

    def forward_pass(
        self,
        datapairs: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):

        if isinstance(datapairs, DataPair):
            datapairs = [datapairs]

        embedding_names = self.document_embeddings.get_names()

        if self.embed_separately:  # embed both sentences separately, concatenate the resulting vectors
            first_elements = [pair.first for pair in datapairs]
            second_elements = [pair.second for pair in datapairs]

            self.document_embeddings.embed(first_elements)
            self.document_embeddings.embed(second_elements)

            text_embedding_list = [
                torch.cat([a.get_embedding(embedding_names), b.get_embedding(embedding_names)], 0).unsqueeze(0)
                for (a, b) in zip(first_elements, second_elements)
            ]
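            # each entry has shape (1, 2 * embedding_length): the two document
            # vectors are concatenated along the feature dimension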

        else:  # concatenate the sentences and embed them together
            concatenated_sentences = [
                Sentence(
                    pair.first.to_tokenized_string() + self.sep + pair.second.to_tokenized_string(),
                    use_tokenizer=False,
                )
                for pair in datapairs
            ]

            self.document_embeddings.embed(concatenated_sentences)

            text_embedding_list = [
                sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in concatenated_sentences
            ]
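            # each entry has shape (1, embedding_length)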

        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

        # linear layer
        scores = self.decoder(text_embedding_tensor)
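        # scores has shape (number of pairs, number of classes in the label dictionary)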

        labels = []
        for pair in datapairs:
            labels.append([label.value for label in pair.get_labels(self.label_type)])

        # minimal return is scores and labels
        return_tuple = (scores, labels)

        if return_label_candidates:
            label_candidates = [Label(value=None) for _ in datapairs]
            return_tuple += (datapairs, label_candidates)

        return return_tuple

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "multi_label": self.multi_label,
            "multi_label_threshold": self.multi_label_threshold,
            "loss_weights": self.loss_weights,
            "embed_separately": self.embed_separately,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        model = TextPairClassifier(
            document_embeddings=state["document_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            multi_label=state["multi_label"],
            multi_label_threshold=state.get("multi_label_threshold", 0.5),
            loss_weights=state["loss_weights"],
            embed_separately=state["embed_separately"],
        )
        model.load_state_dict(state["state_dict"])
        return model
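

if __name__ == "__main__":
    # Usage sketch (not part of the original module): a minimal, hypothetical
    # example of constructing the classifier and predicting on one pair. The
    # embedding model and label names are illustrative choices, and passing
    # label_dictionary through **classifierargs to flair.nn.DefaultClassifier
    # is assumed here, as the state-dict round-trip above suggests.
    from flair.data import Dictionary
    from flair.embeddings import TransformerDocumentEmbeddings

    # label dictionary holding the classes to predict
    label_dict = Dictionary(add_unk=False)
    label_dict.add_item("entailment")
    label_dict.add_item("not_entailment")

    classifier = TextPairClassifier(
        document_embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased"),
        label_type="entailment",
        label_dictionary=label_dict,
    )

    # classify a single premise/hypothesis pair
    pair = DataPair(Sentence("A man is eating food."), Sentence("Somebody is eating."))
    classifier.predict(pair)
    print(pair.get_labels("entailment"))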