Coverage for flair/flair/models/text_regression_model.py: 0% (113 statements)

import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset

import flair
import flair.embeddings
from flair.data import Sentence, Label, DataPoint
from flair.datasets import DataLoader, SentenceDataset
from flair.training_utils import MetricRegression, Result, store_embeddings

log = logging.getLogger("flair")


class TextRegressor(flair.nn.Model):
    """Regression model: a document embedding followed by a single linear
    output unit, trained with mean squared error. Experimental."""

    def __init__(self, document_embeddings: flair.embeddings.DocumentEmbeddings, label_name: str = 'label'):

        super().__init__()
        log.info("Using REGRESSION - experimental")

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings
        self.label_name = label_name

        # one linear output unit on top of the document embedding
        self.decoder = nn.Linear(self.document_embeddings.embedding_length, 1)

        nn.init.xavier_uniform_(self.decoder.weight)

        self.loss_function = nn.MSELoss()

        # auto-spawn on GPU if available
        self.to(flair.device)

    @property
    def label_type(self):
        return self.label_name

    def forward(self, sentences):

        self.document_embeddings.embed(sentences)

        embedding_names = self.document_embeddings.get_names()

        # stack all document embeddings into one batch tensor
        text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences]
        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

        label_scores = self.decoder(text_embedding_tensor)

        return label_scores

    def forward_loss(
        self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.Tensor:

        scores = self.forward(data_points)

        return self._calculate_loss(scores, data_points)

    def _labels_to_indices(self, sentences: List[Sentence]):
        # collect the gold regression values of all sentences in one float tensor
        indices = [
            torch.tensor(
                [float(label.value) for label in sentence.labels], dtype=torch.float
            )
            for sentence in sentences
        ]

        vec = torch.cat(indices, 0).to(flair.device)

        return vec

    def predict(
        self,
        sentences: Union[Sentence, List[Sentence]],
        label_name: Optional[str] = None,
        mini_batch_size: int = 32,
        embedding_storage_mode="none",
    ) -> List[Sentence]:

        if label_name is None:
            label_name = self.label_type if self.label_type is not None else 'label'

        with torch.no_grad():
            if type(sentences) is Sentence:
                sentences = [sentences]

            filtered_sentences = self._filter_empty_sentences(sentences)

            # remove previous embeddings
            store_embeddings(filtered_sentences, "none")

            batches = [
                filtered_sentences[x: x + mini_batch_size]
                for x in range(0, len(filtered_sentences), mini_batch_size)
            ]

            for batch in batches:
                scores = self.forward(batch)

                for (sentence, score) in zip(batch, scores.tolist()):
                    sentence.set_label(label_name, value=str(score[0]))

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        return sentences

    def _calculate_loss(
        self, scores: torch.Tensor, sentences: List[Sentence]
    ) -> torch.Tensor:
        """
        Calculates the loss.
        :param scores: the prediction scores from the model
        :param sentences: list of sentences
        :return: loss value
        """
        return self.loss_function(scores.squeeze(1), self._labels_to_indices(sentences))

    def forward_labels_and_loss(
        self, sentences: Union[Sentence, List[Sentence]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        scores = self.forward(sentences)
        loss = self._calculate_loss(scores, sentences)
        return scores, loss

    def evaluate(
        self,
        sentences: Union[List[DataPoint], Dataset],
        out_path: Optional[Union[str, Path]] = None,
        embedding_storage_mode: str = "none",
        mini_batch_size: int = 32,
        num_workers: int = 8,
        **kwargs
    ) -> Result:

        # read Dataset into data loader (if list of sentences passed, make Dataset first)
        if not isinstance(sentences, Dataset):
            sentences = SentenceDataset(sentences)
        data_loader = DataLoader(sentences, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():
            eval_loss = 0

            metric = MetricRegression("Evaluation")

            lines: List[str] = []
            total_count = 0
            for batch_nr, batch in enumerate(data_loader):

                if isinstance(batch, Sentence):
                    batch = [batch]

                scores, loss = self.forward_labels_and_loss(batch)

                # gold regression values
                true_values = []
                for sentence in batch:
                    total_count += 1
                    for label in sentence.labels:
                        true_values.append(float(label.value))

                # predicted regression values
                results = []
                for score in scores:
                    if type(score[0]) is Label:
                        results.append(float(score[0].score))
                    else:
                        results.append(float(score[0]))

                eval_loss += loss

                metric.true.extend(true_values)
                metric.pred.extend(results)

                for sentence, prediction, true_value in zip(
                    batch, results, true_values
                ):
                    eval_line = "{}\t{}\t{}\n".format(
                        sentence.to_original_text(), true_value, prediction
                    )
                    lines.append(eval_line)

                store_embeddings(batch, embedding_storage_mode)

            eval_loss /= total_count

            # write one tab-separated line per sentence: text, gold value, prediction
            if out_path is not None:
                with open(out_path, "w", encoding="utf-8") as outfile:
                    outfile.write("".join(lines))

            log_line = f"{metric.mean_squared_error()}\t{metric.spearmanr()}\t{metric.pearsonr()}"
            log_header = "MSE\tSPEARMAN\tPEARSON"

            detailed_result = (
                f"AVG: mse: {metric.mean_squared_error():.4f} - "
                f"mae: {metric.mean_absolute_error():.4f} - "
                f"pearson: {metric.pearsonr():.4f} - "
                f"spearman: {metric.spearmanr():.4f}"
            )

            result: Result = Result(
                main_score=metric.pearsonr(),
                loss=eval_loss,
                log_header=log_header,
                log_line=log_line,
                detailed_results=detailed_result,
            )

            return result

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            "label_name": self.label_type,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        label_name = state["label_name"] if "label_name" in state else None

        model = TextRegressor(document_embeddings=state["document_embeddings"], label_name=label_name)

        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                "Ignore {} sentence(s) with no tokens.".format(
                    len(sentences) - len(filtered_sentences)
                )
            )
        return filtered_sentences
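
Since the module sits at 0% coverage, a minimal usage sketch may help orient readers; it assumes flair's standard TransformerDocumentEmbeddings and Sentence APIs and an arbitrarily chosen backbone model, none of which are prescribed by this file:

from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models.text_regression_model import TextRegressor

# embed whole documents with a transformer backbone (model choice is illustrative)
embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased")
regressor = TextRegressor(embeddings, label_name="score")

# predict() attaches the regression value to each sentence as a label
sentence = Sentence("The movie was surprisingly good.")
regressor.predict(sentence)
print(sentence.get_labels("score"))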