Coverage for flair/flair/embeddings/image.py: 28%
from abc import abstractmethod
from typing import List

import torch
import torch.nn.functional as F
from torch.nn import Parameter

import flair
from flair.data import Image
from flair.embeddings.base import Embeddings

import logging

from torch.nn import Sequential, Linear, Conv2d, ReLU, MaxPool2d, Dropout2d
from torch.nn import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from torch.nn import TransformerEncoderLayer, TransformerEncoder


log = logging.getLogger("flair")

class ImageEmbeddings(Embeddings):
    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        pass

    @property
    def embedding_type(self) -> str:
        return "image-level"
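
For orientation, here is a minimal sketch of what a concrete subclass has to provide: the abstract embedding_length property plus _add_embeddings_internal. The class name and behavior are illustrative only and not part of flair; the sketch follows the same pattern as the subclasses in this file (set name and static_embeddings, then call super().__init__()).

class ConstantImageEmbeddings(ImageEmbeddings):
    """Illustrative only: assigns the same zero vector to every image."""

    def __init__(self, length: int = 8):
        self.name = "constant-example"
        self.static_embeddings = True
        self.__length = length
        super().__init__()

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        for image in images:
            image.set_embedding(
                self.name, torch.zeros(self.__length, device=flair.device)
            )

    @property
    def embedding_length(self) -> int:
        return self.__length
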

class IdentityImageEmbeddings(ImageEmbeddings):
    def __init__(self, transforms):
        import PIL as pythonimagelib

        self.PIL = pythonimagelib
        self.name = "Identity"
        self.transforms = transforms
        self.__embedding_length = None
        self.static_embeddings = True
        super().__init__()

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        for image in images:
            image_data = self.PIL.Image.open(image.imageURL)
            image_data.load()
            image.set_embedding(self.name, self.transforms(image_data))

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
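
A possible usage sketch, assuming torchvision is installed, that the usual embed() entry point of the Embeddings base class is used, and that the image path is a hypothetical local file. IdentityImageEmbeddings simply opens the file and stores the transformed image tensor as its "embedding".

from torchvision import transforms
from flair.data import Image

preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
embedder = IdentityImageEmbeddings(transforms=preprocess)

img = Image(imageURL="/path/to/local/photo.jpg")  # hypothetical path
embedder.embed([img])  # stores a [3, 224, 224] tensor under the name "Identity"
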

class PrecomputedImageEmbeddings(ImageEmbeddings):
    def __init__(self, url2tensor_dict, name):
        self.url2tensor_dict = url2tensor_dict
        self.name = name
        self.__embedding_length = len(list(self.url2tensor_dict.values())[0])
        self.static_embeddings = True
        super().__init__()

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        for image in images:
            if image.imageURL in self.url2tensor_dict:
                image.set_embedding(self.name, self.url2tensor_dict[image.imageURL])
            else:
                image.set_embedding(
                    self.name, torch.zeros(self.__embedding_length, device=flair.device)
                )

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
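
A brief usage sketch with made-up URLs and random tensors; as in _add_embeddings_internal above, any image whose URL is missing from the dictionary receives a zero vector of the same length.

url2tensor = {
    "http://example.com/a.jpg": torch.rand(256),
    "http://example.com/b.jpg": torch.rand(256),
}
embedder = PrecomputedImageEmbeddings(url2tensor, name="precomputed-example")

known = Image(imageURL="http://example.com/a.jpg")
unknown = Image(imageURL="http://example.com/c.jpg")
embedder.embed([known, unknown])  # unknown gets torch.zeros(256)
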

class NetworkImageEmbeddings(ImageEmbeddings):
    def __init__(self, name, pretrained=True, transforms=None):
        super().__init__()

        try:
            import torchvision as torchvision
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "torchvision" is not installed!')
            log.warning(
                'To use convnets pretrained on ImageNet, please first install with "pip install torchvision"'
            )
            log.warning("-" * 100)
            pass

        model_info = {
            "resnet50": (torchvision.models.resnet50, lambda x: list(x)[:-1], 2048),
            "mobilenet_v2": (
                torchvision.models.mobilenet_v2,
                lambda x: list(x)[:-1] + [torch.nn.AdaptiveAvgPool2d((1, 1))],
                1280,
            ),
        }

        transforms = [] if transforms is None else transforms
        transforms += [torchvision.transforms.ToTensor()]
        if pretrained:
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            transforms += [
                torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
            ]
        self.transforms = torchvision.transforms.Compose(transforms)

        if name in model_info:
            model_constructor = model_info[name][0]
            model_features = model_info[name][1]
            embedding_length = model_info[name][2]

            net = model_constructor(pretrained=pretrained)
            modules = model_features(net.children())
            self.features = torch.nn.Sequential(*modules)

            self.__embedding_length = embedding_length

            self.name = name
        else:
            raise Exception(f"Image embeddings {name} not available.")

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        image_tensor = torch.stack([self.transforms(image.data) for image in images])
        image_embeddings = self.features(image_tensor)
        image_embeddings = (
            image_embeddings.view(image_embeddings.shape[:2])
            if image_embeddings.dim() == 4
            else image_embeddings
        )
        if image_embeddings.dim() != 2:
            raise Exception(
                f"Unknown embedding shape of length {image_embeddings.dim()}"
            )
        for image_id, image in enumerate(images):
            image.set_embedding(self.name, image_embeddings[image_id])

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
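
A usage sketch under the same assumptions as above (torchvision installed, hypothetical file path, embed() from the Embeddings base class). Note that _add_embeddings_internal applies self.transforms to image.data, so image.data is expected to hold a PIL image.

from PIL import Image as PILImage
from flair.data import Image

embedder = NetworkImageEmbeddings("resnet50", pretrained=True)  # downloads weights on first use
print(embedder.embedding_length)  # 2048 for resnet50

img = Image(imageURL="/path/to/local/photo.jpg")  # hypothetical path
img.data = PILImage.open(img.imageURL).convert("RGB")
embedder.embed([img])  # stores a 2048-dimensional vector under the name "resnet50"
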

class ConvTransformNetworkImageEmbeddings(ImageEmbeddings):
    def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms):
        super(ConvTransformNetworkImageEmbeddings, self).__init__()

        adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d}

        convnet_arch = (
            []
            if convnet_parms["dropout"][0] <= 0
            else [Dropout2d(convnet_parms["dropout"][0])]
        )
        convnet_arch.extend(
            [
                Conv2d(
                    in_channels=feats_in,
                    out_channels=convnet_parms["n_feats_out"][0],
                    kernel_size=convnet_parms["kernel_sizes"][0],
                    padding=convnet_parms["kernel_sizes"][0][0] // 2,
                    stride=convnet_parms["strides"][0],
                    groups=convnet_parms["groups"][0],
                ),
                ReLU(),
            ]
        )
        if "0" in convnet_parms["pool_layers_map"]:
            convnet_arch.append(
                MaxPool2d(kernel_size=convnet_parms["pool_layers_map"]["0"])
            )
        for layer_id, (kernel_size, n_in, n_out, groups, stride, dropout) in enumerate(
            zip(
                convnet_parms["kernel_sizes"][1:],
                convnet_parms["n_feats_out"][:-1],
                convnet_parms["n_feats_out"][1:],
                convnet_parms["groups"][1:],
                convnet_parms["strides"][1:],
                convnet_parms["dropout"][1:],
            )
        ):
            if dropout > 0:
                convnet_arch.append(Dropout2d(dropout))
            convnet_arch.append(
                Conv2d(
                    in_channels=n_in,
                    out_channels=n_out,
                    kernel_size=kernel_size,
                    padding=kernel_size[0] // 2,
                    stride=stride,
                    groups=groups,
                )
            )
            convnet_arch.append(ReLU())
            if str(layer_id + 1) in convnet_parms["pool_layers_map"]:
                convnet_arch.append(
                    MaxPool2d(
                        kernel_size=convnet_parms["pool_layers_map"][str(layer_id + 1)]
                    )
                )
        convnet_arch.append(
            adaptive_pool_func_map[convnet_parms["adaptive_pool_func"]](
                output_size=convnet_parms["output_size"]
            )
        )
        self.conv_features = Sequential(*convnet_arch)
        conv_feat_dim = convnet_parms["n_feats_out"][-1]

        if posnet_parms is not None and transformer_parms is not None:
            self.use_transformer = True
            if posnet_parms["nonlinear"]:
                posnet_arch = [
                    Linear(2, posnet_parms["n_hidden"]),
                    ReLU(),
                    Linear(posnet_parms["n_hidden"], conv_feat_dim),
                ]
            else:
                posnet_arch = [Linear(2, conv_feat_dim)]
            self.position_features = Sequential(*posnet_arch)
            transformer_layer = TransformerEncoderLayer(
                d_model=conv_feat_dim, **transformer_parms["transformer_encoder_parms"]
            )
            self.transformer = TransformerEncoder(
                transformer_layer, num_layers=transformer_parms["n_blocks"]
            )
            # <cls> token initially set to 1/D, so it attends to all image features equally
            self.cls_token = Parameter(torch.ones(conv_feat_dim, 1) / conv_feat_dim)
            self._feat_dim = conv_feat_dim
        else:
            self.use_transformer = False
            self._feat_dim = (
                convnet_parms["output_size"][0]
                * convnet_parms["output_size"][1]
                * conv_feat_dim
            )

    def forward(self, x):
        x = self.conv_features(x)  # [b, d, h, w]
        b, d, h, w = x.shape
        if self.use_transformer:
            # add positional encodings
            y = torch.stack(
                [
                    torch.cat([torch.arange(h).unsqueeze(1)] * w, dim=1),
                    torch.cat([torch.arange(w).unsqueeze(0)] * h, dim=0),
                ]
            )  # [2, h, w]
            y = y.view([2, h * w]).transpose(1, 0)  # [h*w, 2]
            y = y.type(torch.float32).to(flair.device)
            y = (
                self.position_features(y).transpose(1, 0).view([d, h, w])
            )  # [h*w, d] => [d, h, w]
            y = y.unsqueeze(dim=0)  # [1, d, h, w]
            x = x + y  # [b, d, h, w] + [1, d, h, w] => [b, d, h, w]
            # reshape the pixels into a sequence
            x = x.view([b, d, h * w])  # [b, d, h*w]
            # layer norm after convolution and positional encodings
            x = F.layer_norm(x.permute([0, 2, 1]), (d,)).permute([0, 2, 1])
            # add <cls> token
            x = torch.cat(
                [x, torch.stack([self.cls_token] * b)], dim=2
            )  # [b, d, h*w+1]
            # transformer requires input in the shape [h*w+1, b, d]
            x = (
                x.view([b * d, h * w + 1]).transpose(1, 0).view([h * w + 1, b, d])
            )  # [b, d, h*w+1] => [b*d, h*w+1] => [h*w+1, b*d] => [h*w+1, b, d]
            x = self.transformer(x)  # [h*w+1, b, d]
            # the output is the embedding of the <cls> token
            x = x[-1, :, :]  # [b, d]
        else:
            x = x.view([-1, self._feat_dim])
            x = F.layer_norm(x, (self._feat_dim,))

        return x

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        image_tensor = torch.stack([image.data for image in images])
        image_embeddings = self.forward(image_tensor)
        for image_id, image in enumerate(images):
            image.set_embedding(self.name, image_embeddings[image_id])

    @property
    def embedding_length(self):
        return self._feat_dim

    def __str__(self):
        return self.name
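
The parameter dictionaries are easiest to see by example. Below is a sketch with illustrative values only: a small two-layer convnet with max pooling after each layer, adaptive average pooling to a 4x4 grid, learned positional encodings, and a two-block transformer head. The forward pass is called directly on a random batch, so nothing beyond the code shown above is assumed.

convnet_parms = {
    "dropout": [0.0, 0.1],                # per-layer Dropout2d probabilities
    "n_feats_out": [32, 64],              # output channels per conv layer
    "kernel_sizes": [(3, 3), (3, 3)],
    "strides": [1, 1],
    "groups": [1, 1],
    "pool_layers_map": {"0": 2, "1": 2},  # MaxPool2d kernel size after layers 0 and 1
    "adaptive_pool_func": "avg",          # or "max"
    "output_size": (4, 4),
}
posnet_parms = {"nonlinear": True, "n_hidden": 32}
transformer_parms = {"n_blocks": 2, "transformer_encoder_parms": {"nhead": 4}}

embedder = ConvTransformNetworkImageEmbeddings(
    feats_in=3,
    convnet_parms=convnet_parms,
    posnet_parms=posnet_parms,
    transformer_parms=transformer_parms,
)
out = embedder.forward(torch.rand(2, 3, 32, 32))  # two random RGB "images"
print(out.shape)                  # torch.Size([2, 64])
print(embedder.embedding_length)  # 64, the last entry of n_feats_out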