Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/datasets/base.py: 56%
import logging
from abc import abstractmethod
from pathlib import Path
from typing import List, Union, Callable

import torch.utils.data.dataloader
from torch.utils.data.dataset import Subset, ConcatDataset

from flair.data import (
    Sentence,
    Token,
    Tokenizer,
    FlairDataset
)
from flair.tokenization import SegtokTokenizer, SpaceTokenizer

log = logging.getLogger("flair")


class DataLoader(torch.utils.data.dataloader.DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=8,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
    ):

        # in certain cases, multi-CPU data loading makes no sense and slows
        # everything down. For this reason, we detect if a dataset is in-memory:
        # if so, num_workers is set to 0 for faster processing
        flair_dataset = dataset
        while True:
            if type(flair_dataset) is Subset:
                flair_dataset = flair_dataset.dataset
            elif type(flair_dataset) is ConcatDataset:
                flair_dataset = flair_dataset.datasets[0]
            else:
                break

        if type(flair_dataset) is list:
            num_workers = 0
        elif isinstance(flair_dataset, FlairDataset) and flair_dataset.is_in_memory():
            num_workers = 0

        super(DataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=list,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
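
# Usage sketch (illustrative): passing an in-memory list of Sentence objects
# triggers the num_workers=0 fallback above, and each batch comes back as a
# plain list of Sentence objects because collate_fn=list.
#
#   sentences = [Sentence("Berlin is a city."), Sentence("So is Paris.")]
#   loader = DataLoader(sentences, batch_size=2)
#   for batch in loader:
#       print(len(batch))  # -> 2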


class SentenceDataset(FlairDataset):
    """
    A simple Dataset object to wrap a List of Sentence
    """

    def __init__(self, sentences: Union[Sentence, List[Sentence]]):
        """
        Instantiate SentenceDataset
        :param sentences: Sentence or List of Sentence that make up SentenceDataset
        """
        # cast to list if necessary
        if type(sentences) == Sentence:
            sentences = [sentences]
        self.sentences = sentences

    def is_in_memory(self) -> bool:
        return True

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index: int = 0) -> Sentence:
        return self.sentences[index]
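
# Usage sketch (illustrative): SentenceDataset simply indexes the wrapped list,
# so pre-built Sentence objects stay in memory and is_in_memory() is True.
#
#   dataset = SentenceDataset(Sentence("The grass is green."))
#   assert len(dataset) == 1 and dataset.is_in_memory()
#   print(dataset[0])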


class StringDataset(FlairDataset):
    """
    A Dataset taking strings as input and returning Sentence objects during iteration
    """

    def __init__(
        self,
        texts: Union[str, List[str]],
        use_tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
    ):
        """
        Instantiate StringDataset
        :param texts: a string or List of strings that make up the StringDataset
        :param use_tokenizer: Custom tokenizer to use (default is SpaceTokenizer;
            more advanced options are SegtokTokenizer to use segtok or SpacyTokenizer to use spaCy library models
            if available). Check the code of subclasses of Tokenizer to implement your own (if you need it).
            If, instead of providing a callable, this parameter is just set to True, SegtokTokenizer will be used.
        """
        # cast to list if necessary
        if type(texts) == str:
            texts = [texts]
        self.texts = texts
        self.use_tokenizer = use_tokenizer

    @abstractmethod
    def is_in_memory(self) -> bool:
        return True

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index: int = 0) -> Sentence:
        text = self.texts[index]
        return Sentence(text, use_tokenizer=self.use_tokenizer)
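
# Usage sketch (illustrative): StringDataset stores raw strings and tokenizes
# lazily, so each __getitem__ call builds a fresh Sentence with the configured
# tokenizer.
#
#   dataset = StringDataset(["I love Berlin", "I love Paris"], use_tokenizer=SegtokTokenizer())
#   sentence = dataset[0]  # tokenized on access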


class MongoDataset(FlairDataset):
    def __init__(
        self,
        query: str,
        host: str,
        port: int,
        database: str,
        collection: str,
        text_field: str,
        categories_field: List[str] = None,
        max_tokens_per_doc: int = -1,
        max_chars_per_doc: int = -1,
        tokenizer: Tokenizer = SegtokTokenizer(),
        in_memory: bool = True,
    ):
        """
        Reads Mongo collections. Each collection should contain one document/text per item.

        Each item should have the following format:
        {
            'Beskrivning': 'Abrahamsby. Gård i Gottröra sn, Långhundra hd, Stockholms län, nära Långsjön.',
            'Län': 'Stockholms län',
            'Härad': 'Långhundra',
            'Församling': 'Gottröra',
            'Plats': 'Abrahamsby'
        }

        :param query: Query, e.g. {'Län': 'Stockholms län'}
        :param host: Host, e.g. 'localhost'
        :param port: Port, e.g. 27017
        :param database: Database, e.g. 'rosenberg'
        :param collection: Collection, e.g. 'book'
        :param text_field: Text field, e.g. 'Beskrivning'
        :param categories_field: List of category fields, e.g. ['Län', 'Härad', 'Tingslag', 'Församling', 'Plats']
        :param max_tokens_per_doc: If set, truncates each Sentence to a maximum number of tokens; -1 keeps all tokens
        :param max_chars_per_doc: If set, truncates each Sentence to a maximum number of chars; -1 keeps all chars
        :param tokenizer: Custom tokenizer to use (default SegtokTokenizer)
        :param in_memory: If True, keeps dataset as Sentences in memory, otherwise only keeps strings
        """
        # first, check if pymongo is installed
        try:
            import pymongo
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "pymongo" is not installed!')
            log.warning(
                'To use MongoDataset, please first install with "pip install pymongo"'
            )
            log.warning("-" * 100)
            raise

        self.in_memory = in_memory
        self.tokenizer = tokenizer

        if self.in_memory:
            self.sentences = []
        else:
            self.indices = []

        self.total_sentence_count: int = 0
        self.max_chars_per_doc = max_chars_per_doc
        self.max_tokens_per_doc = max_tokens_per_doc

        self.__connection = pymongo.MongoClient(host, port)
        self.__cursor = self.__connection[database][collection]

        self.text = text_field
        self.categories = categories_field if categories_field is not None else []

        start = 0

        kwargs = lambda start: {"filter": query, "skip": start, "limit": 0}

        if self.in_memory:
            for document in self.__cursor.find(**kwargs(start)):
                sentence = self._parse_document_to_sentence(
                    document[self.text],
                    [document[_] if _ in document else "" for _ in self.categories],
                    tokenizer,
                )
                if sentence is not None and len(sentence.tokens) > 0:
                    self.sentences.append(sentence)
                    self.total_sentence_count += 1
        else:
            self.indices = self.__cursor.find().distinct("_id")
            # count_documents requires a filter argument; {} counts the whole collection
            self.total_sentence_count = self.__cursor.count_documents({})

    def _parse_document_to_sentence(
        self, text: str, labels: List[str], tokenizer: Union[Callable[[str], List[Token]], Tokenizer]
    ):
        if self.max_chars_per_doc > 0:
            text = text[: self.max_chars_per_doc]

        if text and labels:
            sentence = Sentence(text, labels=labels, use_tokenizer=tokenizer)

            if self.max_tokens_per_doc > 0:
                sentence.tokens = sentence.tokens[
                    : min(len(sentence), self.max_tokens_per_doc)
                ]

            return sentence
        return None

    def is_in_memory(self) -> bool:
        return self.in_memory

    def __len__(self):
        return self.total_sentence_count

    def __getitem__(self, index: int = 0) -> Sentence:
        if self.in_memory:
            return self.sentences[index]
        else:
            # map the positional index to the stored document _id
            document = self.__cursor.find_one({"_id": self.indices[index]})
            sentence = self._parse_document_to_sentence(
                document[self.text],
                [document[_] if _ in document else "" for _ in self.categories],
                self.tokenizer,
            )
            return sentence
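
# Usage sketch (illustrative; assumes a MongoDB instance reachable on
# localhost:27017 holding the 'rosenberg'/'book' example collection from the
# docstring above; all names are placeholders):
#
#   dataset = MongoDataset(
#       query={"Län": "Stockholms län"},
#       host="localhost",
#       port=27017,
#       database="rosenberg",
#       collection="book",
#       text_field="Beskrivning",
#       categories_field=["Län", "Härad", "Församling", "Plats"],
#       in_memory=True,
#   )
#   print(len(dataset), dataset[0])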


def find_train_dev_test_files(data_folder, dev_file, test_file, train_file, autofind_splits=True):
    if type(data_folder) == str:
        data_folder: Path = Path(data_folder)

    if train_file is not None:
        train_file = data_folder / train_file
    if test_file is not None:
        test_file = data_folder / test_file
    if dev_file is not None:
        dev_file = data_folder / dev_file

    suffixes_to_ignore = {".gz", ".swp"}

    # automatically identify train / test / dev files
    if train_file is None and autofind_splits:
        for file in data_folder.iterdir():
            file_name = file.name
            if not suffixes_to_ignore.isdisjoint(file.suffixes):
                continue
            if "train" in file_name and "54019" not in file_name:
                train_file = file
            if "dev" in file_name:
                dev_file = file
            if "testa" in file_name:
                dev_file = file
            if "testb" in file_name:
                test_file = file

    # if no test file is found, take any file with 'test' in name
    if test_file is None and autofind_splits:
        for file in data_folder.iterdir():
            file_name = file.name
            if not suffixes_to_ignore.isdisjoint(file.suffixes):
                continue
            if "test" in file_name:
                test_file = file

    log.info("Reading data from {}".format(data_folder))
    log.info("Train: {}".format(train_file))
    log.info("Dev: {}".format(dev_file))
    log.info("Test: {}".format(test_file))

    return dev_file, test_file, train_file
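

# Usage sketch (illustrative; the folder path is a placeholder): given a corpus
# directory containing e.g. 'train.txt', 'dev.txt' and 'test.txt', the helper
# resolves the three splits, while explicitly passed file names bypass
# auto-detection.
#
#   dev, test, train = find_train_dev_test_files(
#       "resources/tasks/my_corpus", dev_file=None, test_file=None, train_file=None
#   )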