Coverage for flair/flair/datasets/text_image.py: 36%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2import os
3import numpy as np
4import json
5import urllib
7from tqdm import tqdm
8from pathlib import Path
9from typing import List
11import torch.utils.data.dataloader
12from torch.utils.data import Dataset
14from flair.data import (
15 Sentence,
16 Corpus,
17 FlairDataset,
18 DataPair,
19 Image,
20)
21from flair.file_utils import cached_path
23log = logging.getLogger("flair")
class FeideggerCorpus(Corpus):
    """Corpus of Feidegger caption/image pairs with train, dev and test splits.

    The dataset description (JSON) and every referenced image are cached
    locally; each record's ``url`` entry is then rewritten to point at the
    cached image file before the dataset is built.
    """

    def __init__(self, **kwargs):
        """Fetch (or reuse) the Feidegger data and build the three splits.

        Records whose ``split`` value is 0-7 go to train, 8 to dev and 9 to
        test.

        :param kwargs: forwarded unchanged to ``FeideggerDataset``
            (for example ``lowercase=True``).
        """
        # Imported locally: a bare ``import urllib`` does not guarantee that
        # the ``urllib.request`` submodule has been loaded.
        import urllib.request

        dataset = "feidegger"

        # cache Feidegger config file
        json_link = "https://raw.githubusercontent.com/zalandoresearch/feidegger/master/data/FEIDEGGER_release_1.1.json"
        json_local_path = cached_path(json_link, Path("datasets") / dataset)

        # Read the description with a context manager so the handle is closed,
        # and decode explicitly as UTF-8 (the JSON interchange encoding)
        # instead of relying on the platform default.
        with open(json_local_path, "r", encoding="utf-8") as json_file:
            dataset_info = json.load(json_file)

        # cache Feidegger images next to the JSON file
        images_cache_folder = os.path.join(os.path.dirname(json_local_path), "images")
        os.makedirs(images_cache_folder, exist_ok=True)

        for image_info in tqdm(dataset_info):
            name = os.path.basename(image_info["url"])
            filename = os.path.join(images_cache_folder, name)
            if not os.path.isfile(filename):
                # Download to a temporary name and rename afterwards, so an
                # interrupted download is never mistaken for a cached image.
                partial_filename = filename + ".part"
                urllib.request.urlretrieve(image_info["url"], partial_filename)
                os.replace(partial_filename, filename)
            # replace image URL with local cached file
            image_info["url"] = filename

        feidegger_dataset: Dataset = FeideggerDataset(dataset_info, **kwargs)

        def subset_for(split_ids):
            # Positions of all data points whose split id is in split_ids.
            indices = list(np.where(np.isin(feidegger_dataset.split, split_ids))[0])
            return torch.utils.data.dataset.Subset(feidegger_dataset, indices)

        train = subset_for(list(range(8)))
        dev = subset_for([8])
        test = subset_for([9])

        super(FeideggerCorpus, self).__init__(train, dev, test, name="feidegger")
class FeideggerDataset(FlairDataset):
    """Flat list of caption/image ``DataPair`` objects built from Feidegger records.

    ``data_points`` and ``split`` are parallel lists: ``split[i]`` holds the
    integer split id of ``data_points[i]``.
    """

    def __init__(self, dataset_info, in_memory: bool = True, **kwargs):
        """Create one data point per caption of every record.

        :param dataset_info: iterable of records, each providing the keys
            ``"url"``, ``"descriptions"`` and ``"split"``.
        :param in_memory: accepted but not used by this class.
        :param kwargs: only ``lowercase`` is read; when truthy, each caption
            is lower-cased before it is wrapped in a ``Sentence``.
        """
        super(FeideggerDataset, self).__init__()

        self.data_points: List[DataPair] = []
        self.split: List[int] = []

        # A missing key and a falsy value both mean "keep captions as they are".
        lowercase = bool(kwargs.get("lowercase"))

        for record in dataset_info:
            # One Image object is shared by all captions of the same record.
            picture = Image(imageURL=record["url"])
            for caption in record["descriptions"]:
                text = caption.lower() if lowercase else caption
                sentence = Sentence(text, use_tokenizer=True)
                # append Sentence-Image data point
                self.data_points.append(DataPair(sentence, picture))
                self.split.append(int(record["split"]))

    def __len__(self):
        """Return the number of caption/image pairs."""
        return len(self.data_points)

    def __getitem__(self, index: int = 0) -> DataPair:
        """Return the caption/image pair stored at position ``index``."""
        return self.data_points[index]