Coverage for flair/flair/datasets/text_image.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

53 statements  

1import logging 

2import os 

3import numpy as np 

4import json 

5import urllib 

6 

7from tqdm import tqdm 

8from pathlib import Path 

9from typing import List 

10 

11import torch.utils.data.dataloader 

12from torch.utils.data import Dataset 

13 

14from flair.data import ( 

15 Sentence, 

16 Corpus, 

17 FlairDataset, 

18 DataPair, 

19 Image, 

20) 

21from flair.file_utils import cached_path 

22 

23log = logging.getLogger("flair") 

24 

25 

class FeideggerCorpus(Corpus):
    """Corpus of Feidegger caption/image pairs with train, dev and test splits.

    The dataset description (a JSON file) and every image it references are
    downloaded once and cached locally; later instantiations reuse the cache.
    """

    def __init__(self, **kwargs):
        """Download (if needed) and load the Feidegger dataset.

        :param kwargs: forwarded unchanged to ``FeideggerDataset`` (for
            example ``lowercase=True`` or ``in_memory``).
        """
        # A bare ``import urllib`` does not load the ``request`` submodule, so
        # import it explicitly here instead of relying on another module
        # having imported it as a side effect.
        import urllib.request

        dataset = "feidegger"

        # cache Feidegger config file
        json_link = "https://raw.githubusercontent.com/zalandoresearch/feidegger/master/data/FEIDEGGER_release_1.1.json"
        json_local_path = cached_path(json_link, Path("datasets") / dataset)

        # Read the description file as UTF-8 and close the handle right away.
        with open(json_local_path, "r", encoding="utf-8") as json_file:
            dataset_info = json.load(json_file)

        # cache Feidegger images in an "images" folder next to the JSON file
        images_cache_folder = os.path.join(os.path.dirname(json_local_path), "images")
        os.makedirs(images_cache_folder, exist_ok=True)
        for image_info in tqdm(dataset_info):
            name = os.path.basename(image_info["url"])
            filename = os.path.join(images_cache_folder, name)
            if not os.path.isfile(filename):
                # Download to a temporary name and rename afterwards, so an
                # interrupted download is never mistaken for a cached image.
                partial_filename = filename + ".part"
                urllib.request.urlretrieve(image_info["url"], partial_filename)
                os.replace(partial_filename, filename)
            # replace image URL with local cached file
            image_info["url"] = filename

        feidegger_dataset: Dataset = FeideggerDataset(dataset_info, **kwargs)

        # One split id per data point: ids 0-7 form train, 8 dev and 9 test.
        split_ids = np.asarray(feidegger_dataset.split)

        train_indices = list(np.where(np.isin(split_ids, list(range(8))))[0])
        train = torch.utils.data.dataset.Subset(feidegger_dataset, train_indices)

        dev_indices = list(np.where(np.isin(split_ids, [8]))[0])
        dev = torch.utils.data.dataset.Subset(feidegger_dataset, dev_indices)

        test_indices = list(np.where(np.isin(split_ids, [9]))[0])
        test = torch.utils.data.dataset.Subset(feidegger_dataset, test_indices)

        super(FeideggerCorpus, self).__init__(train, dev, test, name="feidegger")

61 

62 

class FeideggerDataset(FlairDataset):
    """Dataset of (caption sentence, image) pairs built from Feidegger entries."""

    def __init__(self, dataset_info, in_memory: bool = True, **kwargs):
        """Create one data point per caption of every entry in ``dataset_info``.

        :param dataset_info: iterable of entries, each providing the keys
            ``"url"``, ``"descriptions"`` and ``"split"``.
        :param in_memory: accepted for interface compatibility; not used here.
        :param kwargs: if ``lowercase`` is present and truthy, captions are
            lower-cased before being wrapped in a ``Sentence``.
        """
        super().__init__()

        self.data_points: List[DataPair] = []
        self.split: List[int] = []

        # Decide once whether captions are lower-cased.
        lowercase = bool(kwargs.get("lowercase"))

        for entry in dataset_info:
            # All captions of an entry share the same Image object.
            picture = Image(imageURL=entry["url"])
            for description in entry["descriptions"]:
                text = description.lower() if lowercase else description
                # append Sentence-Image data point
                pair = DataPair(Sentence(text, use_tokenizer=True), picture)
                self.data_points.append(pair)
                # Keep the split id aligned with the data point at this index.
                self.split.append(int(entry["split"]))

    def __len__(self):
        """Return the number of caption/image data points."""
        return len(self.data_points)

    def __getitem__(self, index: int = 0) -> DataPair:
        """Return the data point stored at position ``index``."""
        return self.data_points[index]