Coverage for flair/flair/models/text_regression_model.py: 34%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset

import flair
import flair.embeddings
from flair.data import Sentence, Label, DataPoint
from flair.datasets import DataLoader, SentenceDataset
from flair.training_utils import MetricRegression, Result, store_embeddings
15log = logging.getLogger("flair")
class TextRegressor(flair.nn.Model):
    """Regression model for text.

    A document embedding is projected by a single linear layer to one
    continuous score; training minimizes mean squared error against the
    float value of each sentence's labels.
    """

    def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'):
        """
        :param document_embeddings: embeddings producing one vector per sentence
        :param label_name: name under which predicted labels are stored on sentences
        """
        super().__init__()

        log.info("Using REGRESSION - experimental")

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
        self.label_name = label_name

        # single linear layer projecting the document embedding to one scalar
        self.decoder = nn.Linear(self.document_embeddings.embedding_length, 1)

        nn.init.xavier_uniform_(self.decoder.weight)

        self.loss_function = nn.MSELoss()

        # auto-spawn on GPU if available
        self.to(flair.device)

    @property
    def label_type(self):
        # FIX: declared as a property so that `self.label_type` — read as an
        # attribute in predict() and _get_state_dict() below — evaluates to
        # the label-name string instead of a bound-method object.
        return self.label_name

    def forward(self, sentences) -> torch.Tensor:
        """Embed the sentences and return a (batch_size, 1) tensor of scores."""
        self.document_embeddings.embed(sentences)

        embedding_names = self.document_embeddings.get_names()

        text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences]
        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

        label_scores = self.decoder(text_embedding_tensor)

        return label_scores

    def forward_loss(
        self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.tensor:
        """Forward pass plus MSE loss for a batch of sentences."""
        scores = self.forward(data_points)

        return self._calculate_loss(scores, data_points)

    def _labels_to_indices(self, sentences: List[Sentence]) -> torch.Tensor:
        """Concatenate the float values of all sentence labels into one 1-d tensor."""
        indices = [
            torch.tensor(
                [float(label.value) for label in sentence.labels], dtype=torch.float
            )
            for sentence in sentences
        ]

        return torch.cat(indices, 0).to(flair.device)

    def predict(
        self,
        sentences: Union[Sentence, List[Sentence]],
        label_name: Optional[str] = None,
        mini_batch_size: int = 32,
        embedding_storage_mode="none",
    ) -> List[Sentence]:
        """Predict a regression value for each sentence and attach it as a label.

        :param sentences: one sentence or a list of sentences
        :param label_name: name for the predicted label (defaults to the model's label type)
        :param mini_batch_size: batch size used during inference
        :param embedding_storage_mode: 'none' (default), 'cpu' or 'gpu'
        :return: the input sentences, annotated in place
        """
        if label_name is None:  # FIX: identity check instead of `== None`
            label_name = self.label_type if self.label_type is not None else 'label'

        with torch.no_grad():
            # FIX: isinstance instead of exact type check, so Sentence
            # subclasses are also wrapped into a list
            if isinstance(sentences, Sentence):
                sentences = [sentences]

            filtered_sentences = self._filter_empty_sentences(sentences)

            # remove previous embeddings
            store_embeddings(filtered_sentences, "none")

            batches = [
                filtered_sentences[x: x + mini_batch_size]
                for x in range(0, len(filtered_sentences), mini_batch_size)
            ]

            for batch in batches:
                scores = self.forward(batch)

                for (sentence, score) in zip(batch, scores.tolist()):
                    sentence.set_label(label_name, value=str(score[0]))

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        return sentences

    def _calculate_loss(
        self, scores: torch.tensor, sentences: List[Sentence]
    ) -> torch.tensor:
        """
        Calculates the loss.
        :param scores: the prediction scores from the model
        :param sentences: list of sentences
        :return: loss value
        """
        return self.loss_function(scores.squeeze(1), self._labels_to_indices(sentences))

    def forward_labels_and_loss(
        self, sentences: Union[Sentence, List[Sentence]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return both the predicted score tensor and the loss for the sentences."""
        scores = self.forward(sentences)
        loss = self._calculate_loss(scores, sentences)
        return scores, loss

    def evaluate(
        self,
        sentences: Union[List[DataPoint], Dataset],
        out_path: Union[str, Path] = None,
        embedding_storage_mode: str = "none",
        mini_batch_size: int = 32,
        num_workers: int = 8,
        **kwargs
    ) -> Result:
        # FIX: return annotation corrected from `(Result, float)` — the
        # function returns only the Result (the loss is carried inside it).
        """Evaluate the model on a dataset.

        :param sentences: list of data points or a Dataset to evaluate on
        :param out_path: optional file to write per-sentence predictions to
        :param embedding_storage_mode: 'none' (default), 'cpu' or 'gpu'
        :param mini_batch_size: batch size used during evaluation
        :param num_workers: number of DataLoader workers
        :return: Result with MSE/MAE/Pearson/Spearman metrics and the average loss
        """
        # read Dataset into data loader (if list of sentences passed, make Dataset first)
        if not isinstance(sentences, Dataset):
            sentences = SentenceDataset(sentences)
        data_loader = DataLoader(sentences, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():
            eval_loss = 0

            metric = MetricRegression("Evaluation")

            lines: List[str] = []
            total_count = 0
            for batch_nr, batch in enumerate(data_loader):

                if isinstance(batch, Sentence):
                    batch = [batch]

                scores, loss = self.forward_labels_and_loss(batch)

                true_values = []
                for sentence in batch:
                    total_count += 1
                    for label in sentence.labels:
                        true_values.append(float(label.value))

                results = []
                for score in scores:
                    # defensive: scores are normally tensor rows, but keep
                    # the Label branch in case predictions are passed through
                    if type(score[0]) is Label:
                        results.append(float(score[0].score))
                    else:
                        results.append(float(score[0]))

                eval_loss += loss

                metric.true.extend(true_values)
                metric.pred.extend(results)

                for sentence, prediction, true_value in zip(
                    batch, results, true_values
                ):
                    eval_line = "{}\t{}\t{}\n".format(
                        sentence.to_original_text(), true_value, prediction
                    )
                    lines.append(eval_line)

                store_embeddings(batch, embedding_storage_mode)

            # FIX: guard against ZeroDivisionError on an empty dataset
            if total_count > 0:
                eval_loss /= total_count

            ##TODO: not saving lines yet
            if out_path is not None:
                with open(out_path, "w", encoding="utf-8") as outfile:
                    outfile.write("".join(lines))

            log_line = f"{metric.mean_squared_error()}\t{metric.spearmanr()}\t{metric.pearsonr()}"
            log_header = "MSE\tSPEARMAN\tPEARSON"

            detailed_result = (
                f"AVG: mse: {metric.mean_squared_error():.4f} - "
                f"mae: {metric.mean_absolute_error():.4f} - "
                f"pearson: {metric.pearsonr():.4f} - "
                f"spearman: {metric.spearmanr():.4f}"
            )

            result: Result = Result(
                main_score=metric.pearsonr(),
                loss=eval_loss,
                log_header=log_header,
                log_line=log_line,
                detailed_results=detailed_result,
            )

            return result

    def _get_state_dict(self):
        """Serialize the model weights together with the embeddings and label name."""
        model_state = {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            # label_type is a property, so this now stores the string
            "label_name": self.label_type,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """Re-create a TextRegressor from a state dict produced by _get_state_dict."""
        # FIX: dict.get instead of `x in state.keys()` lookup-then-index
        label_name = state.get("label_name")

        model = TextRegressor(document_embeddings=state["document_embeddings"], label_name=label_name)

        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        """Drop sentences without tokens, logging how many were ignored."""
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                "Ignore {} sentence(s) with no tokens.".format(
                    len(sentences) - len(filtered_sentences)
                )
            )
        return filtered_sentences