Coverage for flair/flair/models/similarity_learning_model.py: 0%
from abc import abstractmethod

import flair
from flair.data import DataPoint, DataPair
from flair.embeddings import Embeddings
from flair.datasets import DataLoader
from flair.training_utils import Result, store_embeddings

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

import itertools

from typing import Union, List
from pathlib import Path
# == similarity measures ==
class SimilarityMeasure:
    @abstractmethod
    def forward(self, x):
        pass
# helper class for ModelSimilarity
class SliceReshaper(flair.nn.Model):
    def __init__(self, begin, end=None, shape=None):
        super(SliceReshaper, self).__init__()
        self.begin = begin
        self.end = end
        self.shape = shape

    def forward(self, x):
        # take one column (or a slice of columns) from the input ...
        x = x[:, self.begin] if self.end is None else x[:, self.begin : self.end]
        # ... and optionally reshape the slice, keeping the batch dimension
        x = x.view(-1, *self.shape) if self.shape is not None else x
        return x
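# Hypothetical usage sketch (not part of the original module): cut a batch of
# flat parameter vectors into a 2x2 weight matrix and a bias of size 2. The
# shapes below are illustrative assumptions only.
#
#   params = torch.arange(12.0).view(2, 6)            # two parameter vectors
#   take_weight = SliceReshaper(0, 4, shape=(2, 2))
#   take_bias = SliceReshaper(4, 6)
#   take_weight(params).shape                         # torch.Size([2, 2, 2])
#   take_bias(params).shape                           # torch.Size([2, 2])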
# -- works with binary cross entropy loss --
class ModelSimilarity(SimilarityMeasure):
    """
    Similarity defined by the model. The model parameters are given by the first element of the pair.
    The similarity is evaluated by doing the forward pass (inference) on the parametrized model with
    the second element of the pair as input.
    """
    def __init__(self, model):
        # model is a list of tuples (function, parameters), where parameters
        # is a dict {param_name: param_extract_model}
        self.model = model

    def forward(self, x):
        model_parameters = x[0]
        model_inputs = x[1]

        cur_outputs = model_inputs
        for layer_model, parameter_map in self.model:
            param_dict = {}
            for param_name, param_slice_reshape in parameter_map.items():
                # parameters given as SliceReshaper are extracted from the
                # parameter tensor; anything else is passed through as a constant
                if isinstance(param_slice_reshape, SliceReshaper):
                    val = param_slice_reshape(model_parameters)
                else:
                    val = param_slice_reshape
                param_dict[param_name] = val
            cur_outputs = layer_model(cur_outputs, **param_dict)

        return cur_outputs
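# Hypothetical sketch (not in the original module): read the source embedding
# as the weights of a tiny linear model and run it on the target embeddings.
# `batched_linear` and all shapes are assumptions made for illustration.
#
#   def batched_linear(x, weight, bias):
#       # weight: [1, d_out, d_in] from SliceReshaper, bias: [1, d_out]
#       return F.linear(x, weight[0], bias[0])
#
#   similarity = ModelSimilarity([
#       (batched_linear, {"weight": SliceReshaper(0, 4, shape=(2, 2)),
#                         "bias": SliceReshaper(4, 6)}),
#   ])
#   source = torch.randn(1, 6)                      # model parameters
#   targets = torch.randn(5, 2)                     # model inputs
#   logits = similarity.forward((source, targets))  # shape [5, 2]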
# -- works with ranking/triplet loss --
class CosineSimilarity(SimilarityMeasure):
    """
    Similarity defined by the cosine similarity of the embeddings.
    """

    def forward(self, x):
        input_modality_0 = x[0]
        input_modality_1 = x[1]

        # normalize the embeddings
        input_modality_0_norms = torch.norm(input_modality_0, dim=-1, keepdim=True)
        input_modality_1_norms = torch.norm(input_modality_1, dim=-1, keepdim=True)

        # matrix of all pairwise cosine similarities, shape [n0, n1]
        return torch.matmul(
            input_modality_0 / input_modality_0_norms,
            (input_modality_1 / input_modality_1_norms).t(),
        )
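# Hypothetical usage sketch; the embedding sizes are illustrative assumptions.
#
#   sim = CosineSimilarity()
#   a = torch.randn(3, 4)          # 3 source embeddings of dim 4
#   b = torch.randn(5, 4)          # 5 target embeddings of dim 4
#   sim.forward((a, b)).shape      # torch.Size([3, 5]), values in [-1, 1]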
# == similarity losses ==
class SimilarityLoss(nn.Module):
    def __init__(self):
        super(SimilarityLoss, self).__init__()

    @abstractmethod
    def forward(self, inputs, targets):
        pass
class PairwiseBCELoss(SimilarityLoss):
    """
    Binary cross entropy between pair similarities and pair labels.
    """

    def __init__(self, balanced=False):
        super(PairwiseBCELoss, self).__init__()
        self.balanced = balanced

    def forward(self, inputs, targets):
        n = inputs.shape[0]
        neg_targets = torch.ones_like(targets).to(flair.device) - targets
        # we want logits for corresponding pairs to be high, and for non-corresponding pairs low
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        if self.balanced:
            # TODO: this assumes eye matrix
            # reweight so that the n positives and the n*(n-1) negatives
            # contribute equally to the total loss
            weight_matrix = n * (targets / 2.0 + neg_targets / (2.0 * (n - 1)))
            bce_loss *= weight_matrix
        loss = bce_loss.mean()

        return loss
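# Hypothetical usage sketch, assuming everything lives on flair.device and
# that the i-th source matches the i-th target (the eye-matrix case the
# balanced weighting expects):
#
#   loss_fn = PairwiseBCELoss(balanced=True)
#   logits = torch.randn(4, 4).to(flair.device)   # pairwise similarity logits
#   targets = torch.eye(4).to(flair.device)
#   loss = loss_fn(logits, targets)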
class RankingLoss(SimilarityLoss):
    """
    Triplet ranking loss between pair similarities and pair labels.
    """

    def __init__(self, margin=0.1, direction_weights=[0.5, 0.5]):
        super(RankingLoss, self).__init__()
        self.margin = margin
        self.direction_weights = direction_weights

    def forward(self, inputs, targets):
        n = inputs.shape[0]
        neg_targets = torch.ones_like(targets) - targets
        # loss matrices for the two directions of alignment, modality 0 => modality 1 and vice versa
        ranking_loss_matrix_01 = neg_targets * F.relu(
            self.margin + inputs - torch.diag(inputs).view(n, 1)
        )
        ranking_loss_matrix_10 = neg_targets * F.relu(
            self.margin + inputs - torch.diag(inputs).view(1, n)
        )
        neg_targets_01_sum = torch.sum(neg_targets, dim=1)
        neg_targets_10_sum = torch.sum(neg_targets, dim=0)
        loss = self.direction_weights[0] * torch.mean(
            torch.sum(ranking_loss_matrix_01 / neg_targets_01_sum, dim=1)
        ) + self.direction_weights[1] * torch.mean(
            torch.sum(ranking_loss_matrix_10 / neg_targets_10_sum, dim=0)
        )

        return loss
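# Hypothetical worked example: with diagonal targets and each true pair
# scoring more than `margin` above its distractors, the hinge terms vanish
# and the loss is exactly zero. Numbers are illustrative assumptions.
#
#   loss_fn = RankingLoss(margin=0.2)
#   sims = torch.tensor([[0.9, 0.1], [0.3, 0.8]])
#   loss_fn(sims, torch.eye(2))   # tensor(0.)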
# == similarity learner ==
class SimilarityLearner(flair.nn.Model):
    def __init__(
        self,
        source_embeddings: Embeddings,
        target_embeddings: Embeddings,
        similarity_measure: SimilarityMeasure,
        similarity_loss: SimilarityLoss,
        eval_device=flair.device,
        source_mapping: torch.nn.Module = None,
        target_mapping: torch.nn.Module = None,
        recall_at_points: List[int] = [1, 5, 10, 20],
        recall_at_points_weights: List[float] = [0.4, 0.3, 0.2, 0.1],
        interleave_embedding_updates: bool = False,
    ):
        super(SimilarityLearner, self).__init__()
        self.source_embeddings: Embeddings = source_embeddings
        self.target_embeddings: Embeddings = target_embeddings
        self.source_mapping: torch.nn.Module = source_mapping
        self.target_mapping: torch.nn.Module = target_mapping
        self.similarity_measure: SimilarityMeasure = similarity_measure
        self.similarity_loss: SimilarityLoss = similarity_loss
        self.eval_device = eval_device
        self.recall_at_points: List[int] = recall_at_points
        self.recall_at_points_weights: List[float] = recall_at_points_weights
        self.interleave_embedding_updates = interleave_embedding_updates

        self.to(flair.device)
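    # Hypothetical construction sketch. The embedding classes exist in
    # flair.embeddings, but the concrete choices below are illustrative only:
    #
    #   from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings
    #   learner = SimilarityLearner(
    #       source_embeddings=DocumentPoolEmbeddings([WordEmbeddings("glove")]),
    #       target_embeddings=DocumentPoolEmbeddings([WordEmbeddings("glove")]),
    #       similarity_measure=CosineSimilarity(),
    #       similarity_loss=RankingLoss(margin=0.15),
    #   )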
    def _embed_source(self, data_points):

        if type(data_points[0]) == DataPair:
            data_points = [point.first for point in data_points]

        self.source_embeddings.embed(data_points)

        source_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.source_mapping is not None:
            source_embedding_tensor = self.source_mapping(source_embedding_tensor)

        return source_embedding_tensor

    def _embed_target(self, data_points):

        if type(data_points[0]) == DataPair:
            data_points = [point.second for point in data_points]

        self.target_embeddings.embed(data_points)

        target_embedding_tensor = torch.stack(
            [point.embedding for point in data_points]
        ).to(flair.device)

        if self.target_mapping is not None:
            target_embedding_tensor = self.target_mapping(target_embedding_tensor)

        return target_embedding_tensor
    def get_similarity(self, modality_0_embeddings, modality_1_embeddings):
        """
        :param modality_0_embeddings: embeddings of first modality, a tensor of shape [n0, d0]
        :param modality_1_embeddings: embeddings of second modality, a tensor of shape [n1, d1]
        :return: a similarity matrix of shape [n0, n1]
        """
        return self.similarity_measure.forward(
            [modality_0_embeddings, modality_1_embeddings]
        )
    def forward_loss(
        self, data_points: Union[List[DataPoint], DataPoint]
    ) -> torch.tensor:
        mapped_source_embeddings = self._embed_source(data_points)
        mapped_target_embeddings = self._embed_target(data_points)

        if self.interleave_embedding_updates:
            # 1/3 only source branch of model, 1/3 only target branch of model, 1/3 both
            detach_modality_id = torch.randint(0, 3, (1,)).item()
            if detach_modality_id == 0:
                # detach() is not in-place, so its result must be re-assigned
                mapped_source_embeddings = mapped_source_embeddings.detach()
            elif detach_modality_id == 1:
                mapped_target_embeddings = mapped_target_embeddings.detach()

        similarity_matrix = self.similarity_measure.forward(
            (mapped_source_embeddings, mapped_target_embeddings)
        )

        def add_to_index_map(hashmap, key, val):
            if key not in hashmap:
                hashmap[key] = [val]
            else:
                hashmap[key] += [val]

        # map each distinct source/target string to all batch positions where it occurs
        index_map = {"first": {}, "second": {}}
        for data_point_id, data_point in enumerate(data_points):
            add_to_index_map(index_map["first"], str(data_point.first), data_point_id)
            add_to_index_map(index_map["second"], str(data_point.second), data_point_id)

        # targets[i, j] == 1.0 iff the i-th source and the j-th target form a true pair
        targets = torch.zeros_like(similarity_matrix).to(flair.device)

        for data_point in data_points:
            first_indices = index_map["first"][str(data_point.first)]
            second_indices = index_map["second"][str(data_point.second)]
            for first_index, second_index in itertools.product(
                first_indices, second_indices
            ):
                targets[first_index, second_index] = 1.0

        loss = self.similarity_loss(similarity_matrix, targets)

        return loss
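    # The index-map logic above marks *all* co-occurrences of identical strings
    # as positives, not just the diagonal. Hypothetical illustration with plain
    # strings standing in for the data points:
    #
    #   pairs = [("a", "x"), ("b", "y"), ("a", "z")]
    #   # "a" occurs at source rows 0 and 2, so rows 0 and 2 are both positive
    #   # for targets "x" (column 0) and "z" (column 2):
    #   # targets = [[1, 0, 1],
    #   #            [0, 1, 0],
    #   #            [1, 0, 1]]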
    def evaluate(
        self,
        data_pairs: DataPair,
        out_path: Path = None,
        embedding_storage_mode="none",
        mini_batch_size=32,
        num_workers=8,
        **kwargs
    ) -> (Result, float):
        # assumes that for each data pair there's at least one embedding per modality

        data_loader = DataLoader(data_pairs, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():
            # pre-compute embeddings for all targets in evaluation dataset
            target_index = {}
            all_target_embeddings = []
            for data_points in data_loader:
                target_inputs = []
                for data_point in data_points:
                    # embed each distinct target only once
                    if str(data_point.second) not in target_index:
                        target_index[str(data_point.second)] = len(target_index)
                        target_inputs.append(data_point)
                if target_inputs:
                    all_target_embeddings.append(
                        self._embed_target(target_inputs).to(self.eval_device)
                    )
                store_embeddings(data_points, embedding_storage_mode)
            all_target_embeddings = torch.cat(all_target_embeddings, dim=0)  # [n0, d0]
            assert len(target_index) == all_target_embeddings.shape[0]

            ranks = []
            for data_points in data_loader:
                batch_embeddings = self._embed_source(data_points)

                batch_source_embeddings = batch_embeddings.to(self.eval_device)
                # compute the similarity
                batch_similarity_matrix = self.similarity_measure.forward(
                    [batch_source_embeddings, all_target_embeddings]
                )

                # sort the similarity matrix across modality 1
                batch_modality_1_argsort = torch.argsort(
                    batch_similarity_matrix, descending=True, dim=1
                )

                # get the ranks, so +1 to start counting ranks from 1
                batch_modality_1_ranks = (
                    torch.argsort(batch_modality_1_argsort, dim=1) + 1
                )

                batch_target_indices = [
                    target_index[str(data_point.second)] for data_point in data_points
                ]

                batch_gt_ranks = batch_modality_1_ranks[
                    torch.arange(batch_similarity_matrix.shape[0]),
                    torch.tensor(batch_target_indices),
                ]
                ranks.extend(batch_gt_ranks.tolist())

                store_embeddings(data_points, embedding_storage_mode)

        ranks = np.array(ranks)
        median_rank = np.median(ranks)
        recall_at = {k: np.mean(ranks <= k) for k in self.recall_at_points}

        results_header = ["Median rank"] + [
            "Recall@top" + str(r) for r in self.recall_at_points
        ]
        results_header_str = "\t".join(results_header)
        epoch_results = [str(median_rank)] + [
            str(recall_at[k]) for k in self.recall_at_points
        ]
        epoch_results_str = "\t".join(epoch_results)
        detailed_results = ", ".join(
            [f"{h}={v}" for h, v in zip(results_header, epoch_results)]
        )

        validated_measure = sum(
            [
                recall_at[r] * w
                for r, w in zip(self.recall_at_points, self.recall_at_points_weights)
            ]
        )

        return (
            Result(
                validated_measure,
                results_header_str,
                epoch_results_str,
                detailed_results,
            ),
            torch.tensor(0),
        )
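    # The double argsort in evaluate() converts similarities to ranks.
    # Hypothetical stand-alone illustration:
    #
    #   sims = torch.tensor([[0.1, 0.7, 0.2]])
    #   order = torch.argsort(sims, descending=True, dim=1)   # tensor([[1, 2, 0]])
    #   ranks = torch.argsort(order, dim=1) + 1               # tensor([[3, 1, 2]])
    #   # the most similar target (0.7) gets rank 1, the least similar (0.1) rank 3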
    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "input_modality_0_embedding": self.source_embeddings,
            "input_modality_1_embedding": self.target_embeddings,
            "similarity_measure": self.similarity_measure,
            "similarity_loss": self.similarity_loss,
            "source_mapping": self.source_mapping,
            "target_mapping": self.target_mapping,
            "eval_device": self.eval_device,
            "recall_at_points": self.recall_at_points,
            "recall_at_points_weights": self.recall_at_points_weights,
        }
        return model_state
    @staticmethod
    def _init_model_with_state_dict(state):
        # convert from the old model's constructor interface
        if "input_embeddings" in state:
            state["input_modality_0_embedding"] = state["input_embeddings"][0]
            state["input_modality_1_embedding"] = state["input_embeddings"][1]
        model = SimilarityLearner(
            source_embeddings=state["input_modality_0_embedding"],
            target_embeddings=state["input_modality_1_embedding"],
            source_mapping=state["source_mapping"],
            target_mapping=state["target_mapping"],
            similarity_measure=state["similarity_measure"],
            similarity_loss=state["similarity_loss"],
            eval_device=state["eval_device"],
            recall_at_points=state["recall_at_points"],
            recall_at_points_weights=state["recall_at_points_weights"],
        )

        model.load_state_dict(state["state_dict"])
        return model