Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/models/similarity_learning_model.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

190 statements  

import itertools
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import flair
from flair.data import DataPoint, DataPair
from flair.datasets import DataLoader
from flair.embeddings import Embeddings
from flair.training_utils import Result
from flair.training_utils import store_embeddings

20 

21 

22# == similarity measures == 

class SimilarityMeasure(ABC):
    """Abstract base class for pairwise similarity measures.

    Subclasses implement ``forward``, which maps a pair of (batched) inputs
    to a similarity score matrix.

    Note: inheriting from ``ABC`` makes ``@abstractmethod`` actually effective;
    without the ABC metaclass the decorator is silently ignored and the
    abstract base could be instantiated by mistake.
    """

    @abstractmethod
    def forward(self, x):
        """Compute similarities for the input pair ``x = (inputs_0, inputs_1)``."""
        pass

27 

28 

29# helper class for ModelSimilarity 

class SliceReshaper(flair.nn.Model):
    """Extracts a column slice from a 2-D tensor and optionally reshapes it.

    Helper for :class:`ModelSimilarity`: carves an individual parameter tensor
    out of one flat per-example parameter vector.
    """

    def __init__(self, begin, end=None, shape=None):
        super(SliceReshaper, self).__init__()
        self.begin = begin  # first column of the slice
        self.end = end  # one-past-last column; None -> take a single column
        self.shape = shape  # target shape (without batch dim); None -> keep as-is

    def forward(self, x):
        # Select either one column or the half-open column range [begin, end).
        if self.end is None:
            sliced = x[:, self.begin]
        else:
            sliced = x[:, self.begin : self.end]
        # Optionally reshape, inferring the batch dimension with -1.
        if self.shape is None:
            return sliced
        return sliced.view(-1, *self.shape)

41 

42 

43# -- works with binary cross entropy loss -- 

# -- works with binary cross entropy loss --
class ModelSimilarity(SimilarityMeasure):
    """
    Similarity defined by the model. The model parameters are given by the first element of the pair.
    The similarity is evaluated by doing the forward pass (inference) on the parametrized model with
    the second element of the pair as input.
    """

    def __init__(self, model):
        # model is a list of tuples (function, parameters), where parameters is a dict {param_name: param_extract_model}
        self.model = model

    def forward(self, x):
        model_parameters, model_inputs = x

        outputs = model_inputs
        for layer_fn, parameter_map in self.model:
            # Materialize each layer argument: either extract it from the flat
            # parameter tensor via a SliceReshaper, or pass the value verbatim.
            layer_kwargs = {
                name: (spec(model_parameters) if isinstance(spec, SliceReshaper) else spec)
                for name, spec in parameter_map.items()
            }
            outputs = layer_fn(outputs, **layer_kwargs)

        return outputs

72 

73 

74# -- works with ranking/triplet loss -- 

# -- works with ranking/triplet loss --
class CosineSimilarity(SimilarityMeasure):
    """
    Similarity defined by the cosine distance.
    """

    def forward(self, x):
        embeddings_a, embeddings_b = x

        # L2-normalize each row, then a single matrix product yields the
        # full [n0, n1] matrix of cosine similarities.
        normalized_a = embeddings_a / torch.norm(embeddings_a, dim=-1, keepdim=True)
        normalized_b = embeddings_b / torch.norm(embeddings_b, dim=-1, keepdim=True)

        return torch.matmul(normalized_a, normalized_b.t())

92 

93 

94# == similarity losses == 

# == similarity losses ==
class SimilarityLoss(nn.Module):
    """Abstract base for losses defined over a batch similarity matrix."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, inputs, targets):
        """Compute the loss from similarity logits ``inputs`` and 0/1 ``targets``."""
        pass

102 

103 

class PairwiseBCELoss(SimilarityLoss):
    """
    Binary cross entropy between pair similarities and pair labels.

    ``inputs`` is an ``[n, n]`` matrix of pair logits; ``targets`` is an
    ``[n, n]`` 0/1 matrix marking the corresponding (positive) pairs.
    """

    def __init__(self, balanced=False):
        """
        :param balanced: if True, reweight so positives and negatives contribute
            equally overall (assumes roughly one positive per row, i.e. an
            eye-like target matrix).
        """
        super(PairwiseBCELoss, self).__init__()
        self.balanced = balanced

    def forward(self, inputs, targets):
        n = inputs.shape[0]
        # torch.ones_like already matches the device/dtype of `targets`; the
        # previous explicit `.to(flair.device)` was redundant and could create a
        # device mismatch with `targets` (RankingLoss already omits it).
        neg_targets = torch.ones_like(targets) - targets
        # we want that logits for corresponding pairs are high, and for non-corresponding low
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        if self.balanced:
            # TODO: this assumes eye matrix
            # max(n - 1, 1) guards the degenerate n == 1 batch against a 0/0 -> NaN.
            weight_matrix = n * (targets / 2.0 + neg_targets / (2.0 * max(n - 1, 1)))
            bce_loss = bce_loss * weight_matrix
        loss = bce_loss.mean()

        return loss

125 

126 

class RankingLoss(SimilarityLoss):
    """
    Triplet ranking loss between pair similarities and pair labels.
    """

    def __init__(self, margin=0.1, direction_weights=(0.5, 0.5)):
        """
        :param margin: ranking margin enforced between positive and negative pairs.
        :param direction_weights: weights for the two alignment directions
            (modality 0 -> 1 and 1 -> 0). The default is a tuple to avoid the
            shared mutable-default-argument pitfall; list arguments still work.
        """
        super(RankingLoss, self).__init__()
        self.margin = margin
        self.direction_weights = direction_weights

    def forward(self, inputs, targets):
        """
        :param inputs: [n, n] similarity matrix; diagonal entries hold the
            ground-truth (positive) pair similarities.
        :param targets: [n, n] 0/1 matrix of corresponding pairs.
        :return: scalar ranking loss averaged over both directions.
        """
        n = inputs.shape[0]
        neg_targets = torch.ones_like(targets) - targets
        # loss matrices for two directions of alignment, from modality 0 => modality 1 and vice versa
        ranking_loss_matrix_01 = neg_targets * F.relu(
            self.margin + inputs - torch.diag(inputs).view(n, 1)
        )
        ranking_loss_matrix_10 = neg_targets * F.relu(
            self.margin + inputs - torch.diag(inputs).view(1, n)
        )
        # NOTE(review): these [n]-shaped sums broadcast along the LAST axis of
        # the [n, n] loss matrices below; for eye-like targets both sums are
        # identical so it is equivalent, but for asymmetric targets this may
        # normalize along the unintended axis — confirm against callers.
        neg_targets_01_sum = torch.sum(neg_targets, dim=1)
        neg_targets_10_sum = torch.sum(neg_targets, dim=0)
        loss = self.direction_weights[0] * torch.mean(
            torch.sum(ranking_loss_matrix_01 / neg_targets_01_sum, dim=1)
        ) + self.direction_weights[1] * torch.mean(
            torch.sum(ranking_loss_matrix_10 / neg_targets_10_sum, dim=0)
        )

        return loss

156 

157 

158# == similarity learner == 

# == similarity learner ==
class SimilarityLearner(flair.nn.Model):
    """Learns to align two embedding spaces (a source and a target modality).

    Each example is a :class:`DataPair`. Both elements are embedded, optionally
    projected through learnable mappings, scored pairwise within the batch by
    ``similarity_measure``, and trained with ``similarity_loss``. Evaluation
    ranks every known target against each source and reports median rank and
    recall@k.
    """

    def __init__(
        self,
        source_embeddings: Embeddings,
        target_embeddings: Embeddings,
        similarity_measure: SimilarityMeasure,
        similarity_loss: SimilarityLoss,
        eval_device=flair.device,
        source_mapping: torch.nn.Module = None,
        target_mapping: torch.nn.Module = None,
        recall_at_points: List[int] = (1, 5, 10, 20),
        recall_at_points_weights: List[float] = (0.4, 0.3, 0.2, 0.1),
        interleave_embedding_updates: bool = False,
    ):
        """
        :param source_embeddings: embeddings for the first element of each pair.
        :param target_embeddings: embeddings for the second element of each pair.
        :param similarity_measure: scores batches of embedded pairs.
        :param similarity_loss: loss over the batch similarity matrix.
        :param eval_device: device used for ranking during evaluation.
        :param source_mapping: optional projection for source embeddings.
        :param target_mapping: optional projection for target embeddings.
        :param recall_at_points: cutoffs k for recall@k during evaluation.
        :param recall_at_points_weights: weights that combine the recall@k
            values into the single validation score.
        :param interleave_embedding_updates: if True, each batch randomly
            detaches one modality so only one branch receives gradients.
        """
        super(SimilarityLearner, self).__init__()
        self.source_embeddings: Embeddings = source_embeddings
        self.target_embeddings: Embeddings = target_embeddings
        self.source_mapping: torch.nn.Module = source_mapping
        self.target_mapping: torch.nn.Module = target_mapping
        self.similarity_measure: SimilarityMeasure = similarity_measure
        self.similarity_loss: SimilarityLoss = similarity_loss
        self.eval_device = eval_device
        # Copy to fresh lists: defaults are immutable tuples (avoids the shared
        # mutable-default pitfall) and caller-provided lists are not aliased.
        self.recall_at_points: List[int] = list(recall_at_points)
        self.recall_at_points_weights: List[float] = list(recall_at_points_weights)
        self.interleave_embedding_updates = interleave_embedding_updates

        self.to(flair.device)

    def _embed_source(self, data_points):
        """Embed the source side of the given points; apply the source mapping if set."""
        # isinstance (rather than exact type comparison) also unwraps DataPair subclasses
        if isinstance(data_points[0], DataPair):
            data_points = [point.first for point in data_points]

        self.source_embeddings.embed(data_points)

        source_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.source_mapping is not None:
            source_embedding_tensor = self.source_mapping(source_embedding_tensor)

        return source_embedding_tensor

    def _embed_target(self, data_points):
        """Embed the target side of the given points; apply the target mapping if set."""
        if isinstance(data_points[0], DataPair):
            data_points = [point.second for point in data_points]

        self.target_embeddings.embed(data_points)

        target_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.target_mapping is not None:
            target_embedding_tensor = self.target_mapping(target_embedding_tensor)

        return target_embedding_tensor

    def get_similarity(self, modality_0_embeddings, modality_1_embeddings):
        """
        :param modality_0_embeddings: embeddings of first modality, a tensor of shape [n0, d0]
        :param modality_1_embeddings: embeddings of second modality, a tensor of shape [n1, d1]
        :return: a similarity matrix of shape [n0, n1]
        """
        return self.similarity_measure.forward(
            [modality_0_embeddings, modality_1_embeddings]
        )

    def forward_loss(
        self, data_points: Union[List[DataPoint], DataPoint]
    ) -> torch.tensor:
        """Compute the training loss for a batch of data pairs."""
        mapped_source_embeddings = self._embed_source(data_points)
        mapped_target_embeddings = self._embed_target(data_points)

        if self.interleave_embedding_updates:
            # 1/3 only source branch of model, 1/3 only target branch of model, 1/3 both
            detach_modality_id = torch.randint(0, 3, (1,)).item()
            # BUGFIX: Tensor.detach() is NOT in-place -- its result must be
            # assigned back. The original discarded it, so both branches always
            # received gradients and this option silently did nothing.
            if detach_modality_id == 0:
                mapped_source_embeddings = mapped_source_embeddings.detach()
            elif detach_modality_id == 1:
                mapped_target_embeddings = mapped_target_embeddings.detach()

        similarity_matrix = self.similarity_measure.forward(
            (mapped_source_embeddings, mapped_target_embeddings)
        )

        def add_to_index_map(hashmap, key, val):
            # Group batch indices by string content so duplicated sources or
            # targets within the batch are all marked as matching pairs.
            if key not in hashmap:
                hashmap[key] = [val]
            else:
                hashmap[key] += [val]

        index_map = {"first": {}, "second": {}}
        for data_point_id, data_point in enumerate(data_points):
            add_to_index_map(index_map["first"], str(data_point.first), data_point_id)
            add_to_index_map(index_map["second"], str(data_point.second), data_point_id)

        # targets[i, j] == 1 iff source i and target j form a true pair
        targets = torch.zeros_like(similarity_matrix).to(flair.device)

        for data_point in data_points:
            first_indices = index_map["first"][str(data_point.first)]
            second_indices = index_map["second"][str(data_point.second)]
            for first_index, second_index in itertools.product(
                first_indices, second_indices
            ):
                targets[first_index, second_index] = 1.0

        loss = self.similarity_loss(similarity_matrix, targets)

        return loss

    def evaluate(
        self,
        data_pairs: DataPair,
        out_path: Path = None,
        embedding_storage_mode="none",
        mini_batch_size=32,
        num_workers=8,
        **kwargs
    ) -> (Result, float):
        """Rank all targets for every source and report recall metrics.

        :param data_pairs: dataset of DataPair evaluation examples.
        :param out_path: unused; kept for interface compatibility.
        :param embedding_storage_mode: forwarded to ``store_embeddings``.
        :param mini_batch_size: batch size for embedding computation.
        :param num_workers: DataLoader worker count.
        :return: (Result whose main score is the weighted recall@k sum, zero loss tensor).
        """
        # assumes that for each data pair there's at least one embedding per modality

        data_loader = DataLoader(data_pairs, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():
            # pre-compute embeddings for all targets in evaluation dataset
            target_index = {}
            all_target_embeddings = []
            for data_points in data_loader:
                target_inputs = []
                for data_point in data_points:
                    # embed each distinct target exactly once
                    if str(data_point.second) not in target_index:
                        target_index[str(data_point.second)] = len(target_index)
                        target_inputs.append(data_point)
                if target_inputs:
                    all_target_embeddings.append(
                        self._embed_target(target_inputs).to(self.eval_device)
                    )
                store_embeddings(data_points, embedding_storage_mode)
            all_target_embeddings = torch.cat(all_target_embeddings, dim=0)  # [n0, d0]
            assert len(target_index) == all_target_embeddings.shape[0]

            ranks = []
            for data_points in data_loader:
                batch_embeddings = self._embed_source(data_points)

                batch_source_embeddings = batch_embeddings.to(self.eval_device)
                # compute the similarity
                batch_similarity_matrix = self.similarity_measure.forward(
                    [batch_source_embeddings, all_target_embeddings]
                )

                # sort the similarity matrix across modality 1
                batch_modality_1_argsort = torch.argsort(
                    batch_similarity_matrix, descending=True, dim=1
                )

                # get the ranks, so +1 to start counting ranks from 1
                batch_modality_1_ranks = (
                    torch.argsort(batch_modality_1_argsort, dim=1) + 1
                )

                batch_target_indices = [
                    target_index[str(data_point.second)] for data_point in data_points
                ]

                # rank of the ground-truth target for each source in the batch
                batch_gt_ranks = batch_modality_1_ranks[
                    torch.arange(batch_similarity_matrix.shape[0]),
                    torch.tensor(batch_target_indices),
                ]
                ranks.extend(batch_gt_ranks.tolist())

                store_embeddings(data_points, embedding_storage_mode)

        ranks = np.array(ranks)
        median_rank = np.median(ranks)
        recall_at = {k: np.mean(ranks <= k) for k in self.recall_at_points}

        results_header = ["Median rank"] + [
            "Recall@top" + str(r) for r in self.recall_at_points
        ]
        results_header_str = "\t".join(results_header)
        epoch_results = [str(median_rank)] + [
            str(recall_at[k]) for k in self.recall_at_points
        ]
        epoch_results_str = "\t".join(epoch_results)
        detailed_results = ", ".join(
            [f"{h}={v}" for h, v in zip(results_header, epoch_results)]
        )

        # single validation score: weighted combination of the recall@k values
        validated_measure = sum(
            [
                recall_at[r] * w
                for r, w in zip(self.recall_at_points, self.recall_at_points_weights)
            ]
        )

        return (
            Result(
                validated_measure,
                results_header_str,
                epoch_results_str,
                detailed_results,
            ),
            torch.tensor(0),
        )

    def _get_state_dict(self):
        """Serialize model weights plus construction arguments for checkpointing."""
        model_state = {
            "state_dict": self.state_dict(),
            "input_modality_0_embedding": self.source_embeddings,
            "input_modality_1_embedding": self.target_embeddings,
            "similarity_measure": self.similarity_measure,
            "similarity_loss": self.similarity_loss,
            "source_mapping": self.source_mapping,
            "target_mapping": self.target_mapping,
            "eval_device": self.eval_device,
            "recall_at_points": self.recall_at_points,
            "recall_at_points_weights": self.recall_at_points_weights,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """Rebuild a SimilarityLearner from a state produced by ``_get_state_dict``."""
        # The conversion from old model's constructor interface
        if "input_embeddings" in state:
            state["input_modality_0_embedding"] = state["input_embeddings"][0]
            state["input_modality_1_embedding"] = state["input_embeddings"][1]
        model = SimilarityLearner(
            source_embeddings=state["input_modality_0_embedding"],
            target_embeddings=state["input_modality_1_embedding"],
            source_mapping=state["source_mapping"],
            target_mapping=state["target_mapping"],
            similarity_measure=state["similarity_measure"],
            similarity_loss=state["similarity_loss"],
            eval_device=state["eval_device"],
            recall_at_points=state["recall_at_points"],
            recall_at_points_weights=state["recall_at_points_weights"],
        )

        model.load_state_dict(state["state_dict"])
        return model