Coverage for flair/flair/embeddings/image.py: 28% (158 statements)


from abc import abstractmethod
from typing import List

import torch
import torch.nn.functional as F
from torch.nn import Parameter

import flair
from flair.data import Image
from flair.embeddings.base import Embeddings

import logging

from torch.nn import Sequential, Linear, Conv2d, ReLU, MaxPool2d, Dropout2d
from torch.nn import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from torch.nn import TransformerEncoderLayer, TransformerEncoder


log = logging.getLogger("flair")

class ImageEmbeddings(Embeddings):
    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        pass

    @property
    def embedding_type(self) -> str:
        return "image-level"

class IdentityImageEmbeddings(ImageEmbeddings):
    def __init__(self, transforms):
        import PIL as pythonimagelib

        self.PIL = pythonimagelib
        self.name = "Identity"
        self.transforms = transforms
        self.__embedding_length = None
        self.static_embeddings = True
        super().__init__()

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        for image in images:
            image_data = self.PIL.Image.open(image.imageURL)
            image_data.load()
            image.set_embedding(self.name, self.transforms(image_data))

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
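A minimal usage sketch for IdentityImageEmbeddings, assuming torchvision is installed; the transform pipeline and image path below are illustrative, not part of this module:

from torchvision import transforms
from flair.data import Image
from flair.embeddings.image import IdentityImageEmbeddings

# Illustrative preprocessing; any callable that accepts a PIL image works here.
preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
embedding = IdentityImageEmbeddings(transforms=preprocess)

image = Image(imageURL="my_photo.jpg")  # hypothetical local file
embedding.embed([image])  # opens the file with PIL and stores the transformed tensor under "Identity"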

class PrecomputedImageEmbeddings(ImageEmbeddings):
    def __init__(self, url2tensor_dict, name):
        self.url2tensor_dict = url2tensor_dict
        self.name = name
        self.__embedding_length = len(list(self.url2tensor_dict.values())[0])
        self.static_embeddings = True
        super().__init__()

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        for image in images:
            if image.imageURL in self.url2tensor_dict:
                image.set_embedding(self.name, self.url2tensor_dict[image.imageURL])
            else:
                image.set_embedding(
                    self.name, torch.zeros(self.__embedding_length, device=flair.device)
                )

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
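A usage sketch for PrecomputedImageEmbeddings with made-up URLs and feature vectors; URLs missing from the dictionary receive a zero vector of the same length:

import torch
from flair.data import Image
from flair.embeddings.image import PrecomputedImageEmbeddings

# Made-up 256-dimensional features keyed by image URL.
url2tensor = {
    "img/cat.jpg": torch.rand(256),
    "img/dog.jpg": torch.rand(256),
}
embedding = PrecomputedImageEmbeddings(url2tensor, name="precomputed")

images = [Image(imageURL="img/cat.jpg"), Image(imageURL="img/unknown.jpg")]
embedding.embed(images)  # the unknown URL gets torch.zeros(256) on flair.device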

class NetworkImageEmbeddings(ImageEmbeddings):
    def __init__(self, name, pretrained=True, transforms=None):
        super().__init__()

        try:
            import torchvision as torchvision
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "torchvision" is not installed!')
            log.warning(
                'To use convnets pretrained on ImageNet, please first install with "pip install torchvision"'
            )
            log.warning("-" * 100)
            pass

        model_info = {
            "resnet50": (torchvision.models.resnet50, lambda x: list(x)[:-1], 2048),
            "mobilenet_v2": (
                torchvision.models.mobilenet_v2,
                lambda x: list(x)[:-1] + [torch.nn.AdaptiveAvgPool2d((1, 1))],
                1280,
            ),
        }

        transforms = [] if transforms is None else transforms
        transforms += [torchvision.transforms.ToTensor()]
        if pretrained:
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            transforms += [
                torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
            ]
        self.transforms = torchvision.transforms.Compose(transforms)

        if name in model_info:
            model_constructor = model_info[name][0]
            model_features = model_info[name][1]
            embedding_length = model_info[name][2]

            net = model_constructor(pretrained=pretrained)
            modules = model_features(net.children())
            self.features = torch.nn.Sequential(*modules)

            self.__embedding_length = embedding_length

            self.name = name
        else:
            raise Exception(f"Image embeddings {name} not available.")

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        image_tensor = torch.stack([self.transforms(image.data) for image in images])
        image_embeddings = self.features(image_tensor)
        image_embeddings = (
            image_embeddings.view(image_embeddings.shape[:2])
            if image_embeddings.dim() == 4
            else image_embeddings
        )
        if image_embeddings.dim() != 2:
            raise Exception(
                f"Unknown embedding shape of length {image_embeddings.dim()}"
            )
        for image_id, image in enumerate(images):
            image.set_embedding(self.name, image_embeddings[image_id])

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def __str__(self):
        return self.name
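A usage sketch for NetworkImageEmbeddings, assuming torchvision is installed and that each Image's .data attribute holds a PIL image when embedding is performed; the file path and resize transform are illustrative:

from torchvision import transforms
import PIL.Image
from flair.data import Image
from flair.embeddings.image import NetworkImageEmbeddings

# resnet50 yields 2048-dimensional pooled features; mobilenet_v2 yields 1280.
embedding = NetworkImageEmbeddings(
    name="resnet50",
    pretrained=True,
    transforms=[transforms.Resize((224, 224))],  # applied before ToTensor/Normalize
)

image = Image(imageURL="my_photo.jpg")       # hypothetical local file
image.data = PIL.Image.open(image.imageURL)  # assumption: the caller provides the PIL image
embedding.embed([image])                     # stores one 2048-dim vector per image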

class ConvTransformNetworkImageEmbeddings(ImageEmbeddings):
    def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms):
        super(ConvTransformNetworkImageEmbeddings, self).__init__()

        adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d}

        # first convolutional block: optional dropout, convolution, ReLU
        convnet_arch = (
            []
            if convnet_parms["dropout"][0] <= 0
            else [Dropout2d(convnet_parms["dropout"][0])]
        )
        convnet_arch.extend(
            [
                Conv2d(
                    in_channels=feats_in,
                    out_channels=convnet_parms["n_feats_out"][0],
                    kernel_size=convnet_parms["kernel_sizes"][0],
                    padding=convnet_parms["kernel_sizes"][0][0] // 2,
                    stride=convnet_parms["strides"][0],
                    groups=convnet_parms["groups"][0],
                ),
                ReLU(),
            ]
        )
        if "0" in convnet_parms["pool_layers_map"]:
            convnet_arch.append(
                MaxPool2d(kernel_size=convnet_parms["pool_layers_map"]["0"])
            )
        # remaining convolutional blocks, configured per layer
        for layer_id, (kernel_size, n_in, n_out, groups, stride, dropout) in enumerate(
            zip(
                convnet_parms["kernel_sizes"][1:],
                convnet_parms["n_feats_out"][:-1],
                convnet_parms["n_feats_out"][1:],
                convnet_parms["groups"][1:],
                convnet_parms["strides"][1:],
                convnet_parms["dropout"][1:],
            )
        ):
            if dropout > 0:
                convnet_arch.append(Dropout2d(dropout))
            convnet_arch.append(
                Conv2d(
                    in_channels=n_in,
                    out_channels=n_out,
                    kernel_size=kernel_size,
                    padding=kernel_size[0] // 2,
                    stride=stride,
                    groups=groups,
                )
            )
            convnet_arch.append(ReLU())
            if str(layer_id + 1) in convnet_parms["pool_layers_map"]:
                convnet_arch.append(
                    MaxPool2d(
                        kernel_size=convnet_parms["pool_layers_map"][str(layer_id + 1)]
                    )
                )
        # adaptive pooling brings the feature map to a fixed spatial size
        convnet_arch.append(
            adaptive_pool_func_map[convnet_parms["adaptive_pool_func"]](
                output_size=convnet_parms["output_size"]
            )
        )
        self.conv_features = Sequential(*convnet_arch)
        conv_feat_dim = convnet_parms["n_feats_out"][-1]
        if posnet_parms is not None and transformer_parms is not None:
            # learned positional encoding (computed from 2-d pixel coordinates) plus a
            # transformer encoder over the flattened feature map
            self.use_transformer = True
            if posnet_parms["nonlinear"]:
                posnet_arch = [
                    Linear(2, posnet_parms["n_hidden"]),
                    ReLU(),
                    Linear(posnet_parms["n_hidden"], conv_feat_dim),
                ]
            else:
                posnet_arch = [Linear(2, conv_feat_dim)]
            self.position_features = Sequential(*posnet_arch)
            transformer_layer = TransformerEncoderLayer(
                d_model=conv_feat_dim, **transformer_parms["transformer_encoder_parms"]
            )
            self.transformer = TransformerEncoder(
                transformer_layer, num_layers=transformer_parms["n_blocks"]
            )
            # <cls> token initially set to 1/D, so it attends to all image features equally
            self.cls_token = Parameter(torch.ones(conv_feat_dim, 1) / conv_feat_dim)
            self._feat_dim = conv_feat_dim
        else:
            # without the transformer, the flattened pooled feature map is the embedding
            self.use_transformer = False
            self._feat_dim = (
                convnet_parms["output_size"][0]
                * convnet_parms["output_size"][1]
                * conv_feat_dim
            )

    def forward(self, x):
        x = self.conv_features(x)  # [b, d, h, w]
        b, d, h, w = x.shape
        if self.use_transformer:
            # add positional encodings
            y = torch.stack(
                [
                    torch.cat([torch.arange(h).unsqueeze(1)] * w, dim=1),
                    torch.cat([torch.arange(w).unsqueeze(0)] * h, dim=0),
                ]
            )  # [2, h, w]
            y = y.view([2, h * w]).transpose(1, 0)  # [h*w, 2]
            y = y.type(torch.float32).to(flair.device)
            y = (
                self.position_features(y).transpose(1, 0).view([d, h, w])
            )  # [h*w, d] => [d, h, w]
            y = y.unsqueeze(dim=0)  # [1, d, h, w]
            x = x + y  # [b, d, h, w] + [1, d, h, w] => [b, d, h, w]
            # reshape the pixels into the sequence
            x = x.view([b, d, h * w])  # [b, d, h*w]
            # layer norm after convolution and positional encodings
            x = F.layer_norm(x.permute([0, 2, 1]), (d,)).permute([0, 2, 1])
            # add <cls> token
            x = torch.cat(
                [x, torch.stack([self.cls_token] * b)], dim=2
            )  # [b, d, h*w+1]
            # transformer requires input in the shape [h*w+1, b, d]
            x = (
                x.view([b * d, h * w + 1]).transpose(1, 0).view([h * w + 1, b, d])
            )  # [b, d, h*w+1] => [b*d, h*w+1] => [h*w+1, b*d] => [h*w+1, b, d]
            x = self.transformer(x)  # [h*w+1, b, d]
            # the output is the embedding of the <cls> token
            x = x[-1, :, :]  # [b, d]
        else:
            x = x.view([-1, self._feat_dim])
            x = F.layer_norm(x, (self._feat_dim,))

        return x

    def _add_embeddings_internal(self, images: List[Image]) -> List[Image]:
        image_tensor = torch.stack([image.data for image in images])
        image_embeddings = self.forward(image_tensor)
        for image_id, image in enumerate(images):
            image.set_embedding(self.name, image_embeddings[image_id])

    @property
    def embedding_length(self):
        return self._feat_dim

    def __str__(self):
        return self.name
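A constructor and forward-pass sketch for ConvTransformNetworkImageEmbeddings under assumed hyper-parameters; the dictionary keys mirror the lookups in __init__ above, while the sizes themselves are illustrative:

import torch
import flair
from flair.embeddings.image import ConvTransformNetworkImageEmbeddings

# Illustrative hyper-parameters only; key names follow the lookups in __init__.
convnet_parms = {
    "n_feats_out": [32, 64],
    "kernel_sizes": [(3, 3), (3, 3)],
    "strides": [1, 1],
    "groups": [1, 1],
    "dropout": [0.0, 0.1],
    "pool_layers_map": {"0": 2},  # MaxPool2d(kernel_size=2) after the first block
    "adaptive_pool_func": "avg",
    "output_size": (4, 4),
}
posnet_parms = {"nonlinear": True, "n_hidden": 32}
transformer_parms = {"n_blocks": 2, "transformer_encoder_parms": {"nhead": 4}}

model = ConvTransformNetworkImageEmbeddings(
    feats_in=3,
    convnet_parms=convnet_parms,
    posnet_parms=posnet_parms,
    transformer_parms=transformer_parms,
).to(flair.device)

x = torch.rand(2, 3, 64, 64, device=flair.device)  # batch of two RGB images
out = model.forward(x)  # positional encodings are built on flair.device, so keep x there too
print(out.shape)  # torch.Size([2, 64]): one conv_feat_dim vector per image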