Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/datasets/base.py: 56%

133 statements  

import logging
from abc import abstractmethod
from pathlib import Path
from typing import List, Union, Callable

import torch.utils.data.dataloader
from torch.utils.data.dataset import Subset, ConcatDataset

from flair.data import (
    Sentence,
    Token,
    Tokenizer,
    FlairDataset,
)
from flair.tokenization import SegtokTokenizer, SpaceTokenizer

log = logging.getLogger("flair")


class DataLoader(torch.utils.data.dataloader.DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=8,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
    ):

        # in certain cases, multi-CPU data loading makes no sense and slows
        # everything down. For this reason, we detect if a dataset is in-memory:
        # if so, num_workers is set to 0 for faster processing
        flair_dataset = dataset
        while True:
            if type(flair_dataset) is Subset:
                flair_dataset = flair_dataset.dataset
            elif type(flair_dataset) is ConcatDataset:
                flair_dataset = flair_dataset.datasets[0]
            else:
                break

        if type(flair_dataset) is list:
            num_workers = 0
        elif isinstance(flair_dataset, FlairDataset) and flair_dataset.is_in_memory():
            num_workers = 0

        super(DataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=list,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
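Usage sketch (not part of the module listing): because the constructor detects in-memory datasets and collate_fn is set to list, a batch is simply a list of Sentence objects. A minimal example, assuming flair is installed; passing a plain Python list of Sentence objects also triggers the num_workers=0 shortcut.

# Minimal usage sketch (assumption: flair is installed and importable).
from flair.data import Sentence
from flair.datasets.base import DataLoader

sentences = [Sentence("The grass is green ."), Sentence("Berlin is a city .")]

# A plain list works as a map-style dataset; the constructor detects it and forces num_workers=0.
loader = DataLoader(sentences, batch_size=2, num_workers=8)

for batch in loader:
    # collate_fn=list means each batch is just a list of Sentence objects
    print(type(batch), len(batch))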

class SentenceDataset(FlairDataset):
    """
    A simple Dataset object that wraps a List of Sentence objects.
    """

    def __init__(self, sentences: Union[Sentence, List[Sentence]]):
        """
        Instantiate SentenceDataset.
        :param sentences: Sentence or List of Sentence that make up the SentenceDataset
        """
        # cast to list if necessary
        if type(sentences) == Sentence:
            sentences = [sentences]
        self.sentences = sentences

    def is_in_memory(self) -> bool:
        return True

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index: int = 0) -> Sentence:
        return self.sentences[index]
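A short usage sketch (assumes flair is installed): a single Sentence is wrapped into a one-element list by the constructor, and __getitem__ returns the stored Sentence objects unchanged.

from flair.data import Sentence
from flair.datasets.base import SentenceDataset

single = SentenceDataset(Sentence("The grass is green ."))
many = SentenceDataset([Sentence("Berlin ."), Sentence("Munich .")])

print(len(single), len(many))  # 1 2
print(single[0])               # the wrapped Sentence itself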

class StringDataset(FlairDataset):
    """
    A Dataset that takes raw strings as input and returns Sentence objects during iteration.
    """

    def __init__(
        self,
        texts: Union[str, List[str]],
        use_tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
    ):
        """
        Instantiate StringDataset.
        :param texts: a string or List of string that make up the StringDataset
        :param use_tokenizer: Custom tokenizer to use (default is SpaceTokenizer;
            more advanced options are SegtokTokenizer to use segtok, or SpacyTokenizer to use
            spaCy library models if available). Check the subclasses of Tokenizer to implement
            your own (if you need to). If this parameter is simply set to True instead of a
            tokenizer, SegtokTokenizer will be used.
        """
        # cast to list if necessary
        if type(texts) == str:
            texts = [texts]
        self.texts = texts
        self.use_tokenizer = use_tokenizer

    @abstractmethod
    def is_in_memory(self) -> bool:
        return True

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index: int = 0) -> Sentence:
        text = self.texts[index]
        return Sentence(text, use_tokenizer=self.use_tokenizer)
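A usage sketch (assumes flair is installed): the texts are stored as plain strings and only converted into Sentence objects, i.e. tokenized, when an item is accessed, which keeps construction cheap for large lists.

from flair.datasets.base import DataLoader, StringDataset
from flair.tokenization import SegtokTokenizer

dataset = StringDataset(
    ["Berlin and Munich are cities.", "The grass is green."],
    use_tokenizer=SegtokTokenizer(),  # default would be SpaceTokenizer()
)

for batch in DataLoader(dataset, batch_size=2):
    for sentence in batch:  # tokenization happened lazily in __getitem__
        print(len(sentence), sentence)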

class MongoDataset(FlairDataset):
    def __init__(
        self,
        query: str,
        host: str,
        port: int,
        database: str,
        collection: str,
        text_field: str,
        categories_field: List[str] = None,
        max_tokens_per_doc: int = -1,
        max_chars_per_doc: int = -1,
        tokenizer: Tokenizer = SegtokTokenizer(),
        in_memory: bool = True,
    ):
        """
        Reads Mongo collections. Each collection should contain one document/text per item.

        Each item should have the following format:
        {
            'Beskrivning': 'Abrahamsby. Gård i Gottröra sn, Långhundra hd, Stockholms län, nära Långsjön.',
            'Län': 'Stockholms län',
            'Härad': 'Långhundra',
            'Församling': 'Gottröra',
            'Plats': 'Abrahamsby'
        }

        :param query: Query, e.g. {'Län': 'Stockholms län'}
        :param host: Host, e.g. 'localhost'
        :param port: Port, e.g. 27017
        :param database: Database, e.g. 'rosenberg'
        :param collection: Collection, e.g. 'book'
        :param text_field: Text field, e.g. 'Beskrivning'
        :param categories_field: List of category fields, e.g. ['Län', 'Härad', 'Tingslag', 'Församling', 'Plats']
        :param max_tokens_per_doc: If set, truncates each Sentence to at most this number of tokens; -1 (default) keeps documents as is
        :param max_chars_per_doc: If set, truncates each Sentence to a maximum number of chars
        :param tokenizer: Custom tokenizer to use (default SegtokTokenizer)
        :param in_memory: If True, keeps dataset as Sentences in memory, otherwise only keeps strings
        """

        # first, check if pymongo is installed
        try:
            import pymongo
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "pymongo" is not installed!')
            log.warning(
                'To use MongoDataset, please first install with "pip install pymongo"'
            )
            log.warning("-" * 100)
            pass

        self.in_memory = in_memory
        self.tokenizer = tokenizer

        if self.in_memory:
            self.sentences = []
        else:
            self.indices = []

        self.total_sentence_count: int = 0
        self.max_chars_per_doc = max_chars_per_doc
        self.max_tokens_per_doc = max_tokens_per_doc

        self.__connection = pymongo.MongoClient(host, port)
        self.__cursor = self.__connection[database][collection]

        self.text = text_field
        self.categories = categories_field if categories_field is not None else []

        start = 0

        kwargs = lambda start: {"filter": query, "skip": start, "limit": 0}

        if self.in_memory:
            for document in self.__cursor.find(**kwargs(start)):
                sentence = self._parse_document_to_sentence(
                    document[self.text],
                    [document[_] if _ in document else "" for _ in self.categories],
                    tokenizer,
                )
                if sentence is not None and len(sentence.tokens) > 0:
                    self.sentences.append(sentence)
                    self.total_sentence_count += 1
        else:
            self.indices = self.__cursor.find().distinct("_id")
            self.total_sentence_count = self.__cursor.count_documents({})

    def _parse_document_to_sentence(
        self, text: str, labels: List[str], tokenizer: Union[Callable[[str], List[Token]], Tokenizer]
    ):
        if self.max_chars_per_doc > 0:
            text = text[: self.max_chars_per_doc]

        if text and labels:
            sentence = Sentence(text, labels=labels, use_tokenizer=tokenizer)

            if self.max_tokens_per_doc > 0:
                sentence.tokens = sentence.tokens[
                    : min(len(sentence), self.max_tokens_per_doc)
                ]

            return sentence
        return None

    def is_in_memory(self) -> bool:
        return self.in_memory

    def __len__(self):
        return self.total_sentence_count

    def __getitem__(self, index: int = 0) -> Sentence:
        if self.in_memory:
            return self.sentences[index]
        else:
            document = self.__cursor.find_one({"_id": index})
            sentence = self._parse_document_to_sentence(
                document[self.text],
                [document[_] if _ in document else "" for _ in self.categories],
                self.tokenizer,
            )
            return sentence
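A usage sketch with hypothetical connection settings (a local MongoDB instance on the default port; database, collection and field names are taken from the docstring example); pymongo must be installed and the server reachable. Note that the query is passed as a dict filter, as in the docstring example, even though the parameter is annotated as str.

from flair.datasets.base import MongoDataset

dataset = MongoDataset(
    query={"Län": "Stockholms län"},  # dict filter as in the docstring example
    host="localhost",                 # hypothetical local MongoDB instance
    port=27017,
    database="rosenberg",
    collection="book",
    text_field="Beskrivning",
    categories_field=["Län", "Härad", "Församling", "Plats"],
    max_tokens_per_doc=128,
    in_memory=True,
)

print(len(dataset))
print(dataset[0])  # a Sentence with the category fields attached as labels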

def find_train_dev_test_files(data_folder, dev_file, test_file, train_file, autofind_splits=True):
    if type(data_folder) == str:
        data_folder: Path = Path(data_folder)

    if train_file is not None:
        train_file = data_folder / train_file
    if test_file is not None:
        test_file = data_folder / test_file
    if dev_file is not None:
        dev_file = data_folder / dev_file

    suffixes_to_ignore = {".gz", ".swp"}

    # automatically identify train / test / dev files
    if train_file is None and autofind_splits:
        for file in data_folder.iterdir():
            file_name = file.name
            if not suffixes_to_ignore.isdisjoint(file.suffixes):
                continue
            if "train" in file_name and "54019" not in file_name:
                train_file = file
            if "dev" in file_name:
                dev_file = file
            if "testa" in file_name:
                dev_file = file
            if "testb" in file_name:
                test_file = file

    # if no test file is found, take any file with 'test' in name
    if test_file is None and autofind_splits:
        for file in data_folder.iterdir():
            file_name = file.name
            if not suffixes_to_ignore.isdisjoint(file.suffixes):
                continue
            if "test" in file_name:
                test_file = file

    log.info("Reading data from {}".format(data_folder))
    log.info("Train: {}".format(train_file))
    log.info("Dev: {}".format(dev_file))
    log.info("Test: {}".format(test_file))

    return dev_file, test_file, train_file
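A usage sketch with a hypothetical folder layout: when train_file is None and autofind_splits is True, the function scans the folder and picks files whose names contain "train", "dev"/"testa" and "test"/"testb", skipping .gz and .swp files; note the return order is dev, test, train.

from flair.datasets.base import find_train_dev_test_files

# Hypothetical folder containing e.g. train.txt, dev.txt and test.txt
dev_path, test_path, train_path = find_train_dev_test_files(
    data_folder="resources/tasks/my_corpus",
    dev_file=None,
    test_file=None,
    train_file=None,  # None triggers automatic split detection
)
print(train_path, dev_path, test_path)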