Coverage for flair/flair/data.py: 25%

950 statements  

1import logging 

2import re 

3from abc import abstractmethod, ABC 

4from collections import Counter 

5from collections import defaultdict 

6from operator import itemgetter 

7from typing import List, Dict, Union, Optional 

8 

9from deprecated import deprecated 

10import torch 

11from torch.utils.data import Dataset 

12from torch.utils.data.dataset import ConcatDataset, Subset 

13 

14import flair 

15from flair.file_utils import Tqdm 

16 

17log = logging.getLogger("flair") 

18 

19 

20class Dictionary: 

21 """ 

22 This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings. 

23 """ 

24 

25 def __init__(self, add_unk=True): 

26 # init dictionaries 

27 self.item2idx: Dict[str, int] = {} 

28 self.idx2item: List[str] = [] 

29 self.add_unk = add_unk 

30 # in order to deal with unknown tokens, add <unk> 

31 if add_unk: 

32 self.add_item("<unk>") 

33 

34 def remove_item(self, item: str): 

35 

36 item = item.encode("utf-8") 

37 if item in self.item2idx: 

38 self.idx2item.remove(item) 

39 del self.item2idx[item] 

40 

41 def add_item(self, item: str) -> int: 

42 """ 

43 add string - if it is already in the dictionary, returns its existing ID; if not, it is assigned a new ID 

44 :param item: a string for which to assign an id. 

45 :return: ID of string 

46 """ 

47 item = item.encode("utf-8") 

48 if item not in self.item2idx: 

49 self.idx2item.append(item) 

50 self.item2idx[item] = len(self.idx2item) - 1 

51 return self.item2idx[item] 

52 

53 def get_idx_for_item(self, item: str) -> int: 

54 """ 

55 returns the ID of the string if it is contained in the dictionary; if not, returns 0 (the <unk> ID) when add_unk is set, otherwise raises an IndexError 

56 :param item: string for which ID is requested 

57 :return: ID of string; 0 if not found and add_unk is set 

58 """ 

59 item_encoded = item.encode("utf-8") 

60 if item_encoded in self.item2idx.keys(): 

61 return self.item2idx[item_encoded] 

62 elif self.add_unk: 

63 return 0 

64 else: 

65 log.error(f"The string '{item}' is not in dictionary! Dictionary contains only: {self.get_items()}") 

66 log.error( 

67 "You can create a Dictionary that handles unknown items with an <unk>-key by setting add_unk = True in the construction.") 

68 raise IndexError 

69 

70 def get_idx_for_items(self, items: List[str]) -> List[int]: 

71 """ 

72 returns the ID of each item in the given list of strings; 0 is returned for items that are not found 

73 :param items: List of strings for which IDs are requested 

74 :return: List of IDs of the strings 

75 """ 

76 if not hasattr(self, "item2idx_not_encoded"): 

77 d = dict( 

78 [(key.decode("UTF-8"), value) for key, value in self.item2idx.items()] 

79 ) 

80 self.item2idx_not_encoded = defaultdict(int, d) 

81 

82 if not items: 

83 return [] 

84 results = itemgetter(*items)(self.item2idx_not_encoded) 

85 if isinstance(results, int): 

86 return [results] 

87 return list(results) 

88 

89 def get_items(self) -> List[str]: 

90 items = [] 

91 for item in self.idx2item: 

92 items.append(item.decode("UTF-8")) 

93 return items 

94 

95 def __len__(self) -> int: 

96 return len(self.idx2item) 

97 

98 def get_item_for_index(self, idx): 

99 return self.idx2item[idx].decode("UTF-8") 

100 

101 def save(self, savefile): 

102 import pickle 

103 

104 with open(savefile, "wb") as f: 

105 mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx} 

106 pickle.dump(mappings, f) 

107 

108 def __setstate__(self, d): 

109 self.__dict__ = d 

110 # set 'add_unk' if the dictionary was created with a version of Flair older than 0.9 

111 if 'add_unk' not in self.__dict__.keys(): 

112 self.__dict__['add_unk'] = b'<unk>' in self.__dict__['idx2item'] 

113 

114 @classmethod 

115 def load_from_file(cls, filename: str): 

116 import pickle 

117 

118 f = open(filename, "rb") 

119 mappings = pickle.load(f, encoding="latin1") 

120 idx2item = mappings["idx2item"] 

121 item2idx = mappings["item2idx"] 

122 f.close() 

123 

124 # set 'add_unk' depending on whether <unk> is a key 

125 add_unk = b'<unk>' in idx2item 

126 

127 dictionary: Dictionary = Dictionary(add_unk=add_unk) 

128 dictionary.item2idx = item2idx 

129 dictionary.idx2item = idx2item 

130 return dictionary 

131 

132 @classmethod 

133 def load(cls, name: str): 

134 from flair.file_utils import cached_path 

135 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/characters" 

136 if name == "chars" or name == "common-chars": 

137 char_dict = cached_path(f"{hu_path}/common_characters", cache_dir="datasets") 

138 return Dictionary.load_from_file(char_dict) 

139 

140 if name == "chars-large" or name == "common-chars-large": 

141 char_dict = cached_path(f"{hu_path}/common_characters_large", cache_dir="datasets") 

142 return Dictionary.load_from_file(char_dict) 

143 

144 if name == "chars-xl" or name == "common-chars-xl": 

145 char_dict = cached_path(f"{hu_path}/common_characters_xl", cache_dir="datasets") 

146 return Dictionary.load_from_file(char_dict) 

147 

148 return Dictionary.load_from_file(name) 

149 

150 def __str__(self): 

151 tags = ', '.join(self.get_item_for_index(i) for i in range(min(len(self), 50))) 

152 return f"Dictionary with {len(self)} tags: {tags}" 

153 

154 
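# --- Usage sketch (not part of data.py): how Dictionary maps strings to IDs. ---
# A minimal example; index 0 is reserved for "<unk>" because add_unk defaults to True.
from flair.data import Dictionary

tag_dict = Dictionary(add_unk=True)
idx_per = tag_dict.add_item("PER")                   # new item -> new ID (here 1)
idx_loc = tag_dict.add_item("LOC")                   # 2
assert tag_dict.add_item("PER") == idx_per           # adding again returns the existing ID
assert tag_dict.get_idx_for_item("MISC") == 0        # unknown items fall back to <unk>
assert tag_dict.get_items() == ["<unk>", "PER", "LOC"]
assert tag_dict.get_item_for_index(idx_loc) == "LOC"
assert len(tag_dict) == 3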

155class Label: 

156 """ 

157 This class represents a label. Each label has a value and optionally a confidence score. The 

158 score needs to be between 0.0 and 1.0. Default value for the score is 1.0. 

159 """ 

160 

161 def __init__(self, value: str, score: float = 1.0): 

162 self._value = value 

163 self._score = score 

164 super().__init__() 

165 

166 def set_value(self, value: str, score: float = 1.0): 

167 self.value = value 

168 self.score = score 

169 

170 def spawn(self, value: str, score: float = 1.0): 

171 return Label(value, score) 

172 

173 @property 

174 def value(self): 

175 return self._value 

176 

177 @value.setter 

178 def value(self, value): 

179 if not value and value != "": 

180 raise ValueError( 

181 "Incorrect label value provided. Label value needs to be set." 

182 ) 

183 else: 

184 self._value = value 

185 

186 @property 

187 def score(self): 

188 return self._score 

189 

190 @score.setter 

191 def score(self, score): 

192 self._score = score 

193 

194 def to_dict(self): 

195 return {"value": self.value, "confidence": self.score} 

196 

197 def __str__(self): 

198 return f"{self._value} ({round(self._score, 4)})" 

199 

200 def __repr__(self): 

201 return f"{self._value} ({round(self._score, 4)})" 

202 

203 def __eq__(self, other): 

204 return self.value == other.value and self.score == other.score 

205 

206 @property 

207 def identifier(self): 

208 return "" 

209 

210 

211class SpanLabel(Label): 

212 def __init__(self, span, value: str, score: float = 1.0): 

213 super().__init__(value, score) 

214 self.span = span 

215 

216 def spawn(self, value: str, score: float = 1.0): 

217 return SpanLabel(self.span, value, score) 

218 

219 def __str__(self): 

220 return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})" 

221 

222 def __repr__(self): 

223 return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})" 

224 

225 def __len__(self): 

226 return len(self.span) 

227 

228 def __eq__(self, other): 

229 return self.value == other.value and self.score == other.score and self.span.id_text == other.span.id_text 

230 

231 @property 

232 def identifier(self): 

233 return f"{self.span.id_text}" 

234 

235 

236class RelationLabel(Label): 

237 def __init__(self, head, tail, value: str, score: float = 1.0): 

238 super().__init__(value, score) 

239 self.head = head 

240 self.tail = tail 

241 

242 def spawn(self, value: str, score: float = 1.0): 

243 return RelationLabel(self.head, self.tail, value, score) 

244 

245 def __str__(self): 

246 return f"{self._value} [{self.head.id_text} -> {self.tail.id_text}] ({round(self._score, 4)})" 

247 

248 def __repr__(self): 

249 return f"{self._value} from {self.head.id_text} -> {self.tail.id_text} ({round(self._score, 4)})" 

250 

251 def __len__(self): 

252 return len(self.head) + len(self.tail) 

253 

254 def __eq__(self, other): 

255 return self.value == other.value \ 

256 and self.score == other.score \ 

257 and self.head.id_text == other.head.id_text \ 

258 and self.tail.id_text == other.tail.id_text 

259 

260 @property 

261 def identifier(self): 

262 return f"{self.head.id_text} -> {self.tail.id_text}" 

263 

264 
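# --- Usage sketch (not part of data.py): Label and its span variant. ---
# A minimal example; the Span is built from plain Tokens purely for illustration,
# normally it would come from Sentence.get_spans().
from flair.data import Label, Span, SpanLabel, Token

plain = Label("POSITIVE", score=0.98)
print(plain)                      # POSITIVE (0.98)
print(plain.to_dict())            # {'value': 'POSITIVE', 'confidence': 0.98}

span = Span([Token("New", idx=1), Token("York", idx=2)])
span_label = SpanLabel(span, "LOC", score=0.75)
print(span_label)                 # LOC [New York (1,2)] (0.75)
print(span_label.identifier)      # New York (1,2)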

265class DataPoint: 

266 """ 

267 This is the parent class of all data points in Flair (including Token, Sentence, Image, etc.). Each DataPoint 

268 must be embeddable (hence the abstract property embedding() and methods to() and clear_embeddings()). Also, 

269 each DataPoint may have Labels in several layers of annotation (hence the functions add_label(), get_labels() 

270 and the property 'label') 

271 """ 

272 

273 def __init__(self): 

274 self.annotation_layers = {} 

275 

276 @property 

277 @abstractmethod 

278 def embedding(self): 

279 pass 

280 

281 @abstractmethod 

282 def to(self, device: str, pin_memory: bool = False): 

283 pass 

284 

285 @abstractmethod 

286 def clear_embeddings(self, embedding_names: List[str] = None): 

287 pass 

288 

289 def add_label(self, typename: str, value: str, score: float = 1.): 

290 

291 if typename not in self.annotation_layers: 

292 self.annotation_layers[typename] = [Label(value, score)] 

293 else: 

294 self.annotation_layers[typename].append(Label(value, score)) 

295 

296 return self 

297 

298 def add_complex_label(self, typename: str, label: Label): 

299 

300 if typename in self.annotation_layers and label in self.annotation_layers[typename]: 

301 return self 

302 

303 if typename not in self.annotation_layers: 

304 self.annotation_layers[typename] = [label] 

305 else: 

306 self.annotation_layers[typename].append(label) 

307 

308 return self 

309 

310 def set_label(self, typename: str, value: str, score: float = 1.): 

311 self.annotation_layers[typename] = [Label(value, score)] 

312 return self 

313 

314 def remove_labels(self, typename: str): 

315 if typename in self.annotation_layers.keys(): 

316 del self.annotation_layers[typename] 

317 

318 def get_labels(self, typename: str = None): 

319 if typename is None: 

320 return self.labels 

321 

322 return self.annotation_layers[typename] if typename in self.annotation_layers else [] 

323 

324 @property 

325 def labels(self) -> List[Label]: 

326 all_labels = [] 

327 for key in self.annotation_layers.keys(): 

328 all_labels.extend(self.annotation_layers[key]) 

329 return all_labels 

330 

331 
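# --- Usage sketch (not part of data.py): annotation layers on a DataPoint. ---
# DataPoint itself is abstract; a Sentence (defined later in this file) serves as a
# concrete data point here, and the "sentiment"/"language" layer names are made up.
from flair.data import Sentence

point = Sentence("a rather good example")
point.add_label("sentiment", "POSITIVE", score=0.8)   # append to a layer
point.add_label("sentiment", "NEUTRAL", score=0.2)
point.set_label("language", "en")                     # overwrite a whole layer
print(point.get_labels("sentiment"))                  # [POSITIVE (0.8), NEUTRAL (0.2)]
print(point.labels)                                   # labels of all layers, flattened
point.remove_labels("sentiment")
print(point.get_labels("sentiment"))                  # []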

332class DataPair(DataPoint): 

333 def __init__(self, first: DataPoint, second: DataPoint): 

334 super().__init__() 

335 self.first = first 

336 self.second = second 

337 

338 def to(self, device: str, pin_memory: bool = False): 

339 self.first.to(device, pin_memory) 

340 self.second.to(device, pin_memory) 

341 

342 def clear_embeddings(self, embedding_names: List[str] = None): 

343 self.first.clear_embeddings(embedding_names) 

344 self.second.clear_embeddings(embedding_names) 

345 

346 @property 

347 def embedding(self): 

348 return torch.cat([self.first.embedding, self.second.embedding]) 

349 

350 def __str__(self): 

351 return f"DataPair:\n − First {self.first}\n − Second {self.second}\n − Labels: {self.labels}" 

352 

353 def to_plain_string(self): 

354 return f"DataPair: First {self.first} || Second {self.second}" 

355 

356 def to_original_text(self): 

357 return f"{self.first.to_original_text()} || {self.second.to_original_text()}" 

358 

359 def __len__(self): 

360 return len(self.first) + len(self.second) 

361 

362 
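# --- Usage sketch (not part of data.py): a DataPair bundles two data points, e.g. the
# premise and hypothesis of a text-pair task; the pair-level label sits on the DataPair. ---
from flair.data import DataPair, Sentence

pair = DataPair(Sentence("A man is eating."), Sentence("Someone is having a meal."))
pair.add_label("entailment", "ENTAILS")
print(len(pair))                 # number of tokens in both sentences combined
print(pair.to_plain_string())    # DataPair: First ... || Second ...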

363class Token(DataPoint): 

364 """ 

365 This class represents one word in a tokenized sentence. Each token may have any number of tags. It may also point 

366 to its head in a dependency tree. 

367 """ 

368 

369 def __init__( 

370 self, 

371 text: str, 

372 idx: int = None, 

373 head_id: int = None, 

374 whitespace_after: bool = True, 

375 start_position: int = None, 

376 ): 

377 super().__init__() 

378 

379 self.text: str = text 

380 self.idx: int = idx 

381 self.head_id: int = head_id 

382 self.whitespace_after: bool = whitespace_after 

383 

384 self.start_pos = start_position 

385 self.end_pos = ( 

386 start_position + len(text) if start_position is not None else None 

387 ) 

388 

389 self.sentence: Sentence = None 

390 self._embeddings: Dict = {} 

391 self.tags_proba_dist: Dict[str, List[Label]] = {} 

392 

393 def add_tag_label(self, tag_type: str, tag: Label): 

394 self.set_label(tag_type, tag.value, tag.score) 

395 

396 def add_tags_proba_dist(self, tag_type: str, tags: List[Label]): 

397 self.tags_proba_dist[tag_type] = tags 

398 

399 def add_tag(self, tag_type: str, tag_value: str, confidence=1.0): 

400 self.set_label(tag_type, tag_value, confidence) 

401 

402 def get_tag(self, label_type): 

403 if len(self.get_labels(label_type)) == 0: return Label('') 

404 return self.get_labels(label_type)[0] 

405 

406 def get_tags_proba_dist(self, tag_type: str) -> List[Label]: 

407 if tag_type in self.tags_proba_dist: 

408 return self.tags_proba_dist[tag_type] 

409 return [] 

410 

411 def get_head(self): 

412 return self.sentence.get_token(self.head_id) 

413 

414 def set_embedding(self, name: str, vector: torch.tensor): 

415 device = flair.device 

416 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0: 

417 device = next(iter(self._embeddings.values())).device 

418 if device != vector.device: 

419 vector = vector.to(device) 

420 self._embeddings[name] = vector 

421 

422 def to(self, device: str, pin_memory: bool = False): 

423 for name, vector in self._embeddings.items(): 

424 if str(vector.device) != str(device): 

425 if pin_memory: 

426 self._embeddings[name] = vector.to( 

427 device, non_blocking=True 

428 ).pin_memory() 

429 else: 

430 self._embeddings[name] = vector.to(device, non_blocking=True) 

431 

432 def clear_embeddings(self, embedding_names: List[str] = None): 

433 if embedding_names is None: 

434 self._embeddings: Dict = {} 

435 else: 

436 for name in embedding_names: 

437 if name in self._embeddings.keys(): 

438 del self._embeddings[name] 

439 

440 def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> torch.tensor: 

441 embeddings = [] 

442 for embed in sorted(self._embeddings.keys()): 

443 if embedding_names and embed not in embedding_names: continue 

444 embed = self._embeddings[embed].to(flair.device) 

445 if (flair.embedding_storage_mode == "cpu") and embed.device != flair.device: 

446 embed = embed.to(flair.device) 

447 embeddings.append(embed) 

448 return embeddings 

449 

450 def get_embedding(self, names: Optional[List[str]] = None) -> torch.tensor: 

451 embeddings = self.get_each_embedding(names) 

452 

453 if embeddings: 

454 return torch.cat(embeddings, dim=0) 

455 

456 return torch.tensor([], device=flair.device) 

457 

458 @property 

459 def start_position(self) -> int: 

460 return self.start_pos 

461 

462 @property 

463 def end_position(self) -> int: 

464 return self.end_pos 

465 

466 @property 

467 def embedding(self): 

468 return self.get_embedding() 

469 

470 def __str__(self) -> str: 

471 return ( 

472 "Token: {} {}".format(self.idx, self.text) 

473 if self.idx is not None 

474 else "Token: {}".format(self.text) 

475 ) 

476 

477 def __repr__(self) -> str: 

478 return ( 

479 "Token: {} {}".format(self.idx, self.text) 

480 if self.idx is not None 

481 else "Token: {}".format(self.text) 

482 ) 

483 

484 
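# --- Usage sketch (not part of data.py): per-token tags and embeddings. ---
# A minimal example; the embedding tensor is a zero vector and only illustrates
# the set/get/clear round trip.
import torch
from flair.data import Token

token = Token("Berlin", idx=1, start_position=0)
token.add_tag("ner", "S-LOC", confidence=0.9)        # stored as a Label under the "ner" layer
print(token.get_tag("ner"))                          # S-LOC (0.9)
print(token.start_position, token.end_position)      # 0 6

token.set_embedding("demo", torch.zeros(4))
print(token.embedding.shape)                         # torch.Size([4])
token.clear_embeddings()
print(token.embedding.shape)                         # torch.Size([0])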

485class Span(DataPoint): 

486 """ 

487 This class represents one textual span consisting of Tokens. 

488 """ 

489 

490 def __init__(self, tokens: List[Token]): 

491 

492 super().__init__() 

493 

494 self.tokens = tokens 

495 self.start_pos = None 

496 self.end_pos = None 

497 

498 if tokens: 

499 self.start_pos = tokens[0].start_position 

500 self.end_pos = tokens[len(tokens) - 1].end_position 

501 

502 @property 

503 def text(self) -> str: 

504 return " ".join([t.text for t in self.tokens]) 

505 

506 def to_original_text(self) -> str: 

507 pos = self.tokens[0].start_pos 

508 if pos is None: 

509 return " ".join([t.text for t in self.tokens]) 

510 str = "" 

511 for t in self.tokens: 

512 while t.start_pos > pos: 

513 str += " " 

514 pos += 1 

515 

516 str += t.text 

517 pos += len(t.text) 

518 

519 return str 

520 

521 def to_plain_string(self): 

522 plain = "" 

523 for token in self.tokens: 

524 plain += token.text 

525 if token.whitespace_after: 

526 plain += " " 

527 return plain.rstrip() 

528 

529 def to_dict(self): 

530 return { 

531 "text": self.to_original_text(), 

532 "start_pos": self.start_pos, 

533 "end_pos": self.end_pos, 

534 "labels": self.labels, 

535 } 

536 

537 def __str__(self) -> str: 

538 ids = ",".join([str(t.idx) for t in self.tokens]) 

539 label_string = " ".join([str(label) for label in self.labels]) 

540 labels = f' [− Labels: {label_string}]' if self.labels else "" 

541 return ( 

542 'Span [{}]: "{}"{}'.format(ids, self.text, labels) 

543 ) 

544 

545 @property 

546 def id_text(self) -> str: 

547 return f"{' '.join([t.text for t in self.tokens])} ({','.join([str(t.idx) for t in self.tokens])})" 

548 

549 def __repr__(self) -> str: 

550 ids = ",".join([str(t.idx) for t in self.tokens]) 

551 return ( 

552 '<{}-span ({}): "{}">'.format(self.tag, ids, self.text) 

553 if len(self.labels) > 0 

554 else '<span ({}): "{}">'.format(ids, self.text) 

555 ) 

556 

557 def __getitem__(self, idx: int) -> Token: 

558 return self.tokens[idx] 

559 

560 def __iter__(self): 

561 return iter(self.tokens) 

562 

563 def __len__(self) -> int: 

564 return len(self.tokens) 

565 

566 @property 

567 def tag(self): 

568 return self.labels[0].value 

569 

570 @property 

571 def score(self): 

572 return self.labels[0].score 

573 

574 @property 

575 def position_string(self): 

576 return '-'.join([str(token.idx) for token in self]) 

577 

578 
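# --- Usage sketch (not part of data.py): a Span is just a list of Tokens with helpers. ---
# A minimal example; the character start positions are set by hand here, whereas a
# tokenized Sentence fills them in automatically.
from flair.data import Span, Token

tokens = [Token("New", idx=1, start_position=0), Token("York", idx=2, start_position=4)]
span = Span(tokens)
print(span.text)                 # New York
print(span.to_original_text())   # New York
print(span.position_string)      # 1-2
print(len(span))                 # 2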

579class Tokenizer(ABC): 

580 r"""An abstract class representing a :class:`Tokenizer`. 

581 

582 Tokenizers are used to represent algorithms and models to split plain text into 

583 individual tokens / words. All subclasses should overwrite :meth:`tokenize`, which 

584 splits the given plain text into tokens. Moreover, subclasses may overwrite 

585 :meth:`name`, returning a unique identifier representing the tokenizer's 

586 configuration. 

587 """ 

588 

589 @abstractmethod 

590 def tokenize(self, text: str) -> List[Token]: 

591 raise NotImplementedError() 

592 

593 @property 

594 def name(self) -> str: 

595 return self.__class__.__name__ 

596 

597 
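# --- Usage sketch (not part of data.py): implementing a custom Tokenizer. ---
# A hand-rolled whitespace tokenizer, purely for illustration; any Tokenizer instance
# can be passed to Sentence(use_tokenizer=...).
from typing import List
from flair.data import Sentence, Token, Tokenizer

class WhitespaceTokenizer(Tokenizer):
    def tokenize(self, text: str) -> List[Token]:
        tokens, offset = [], 0
        for word in text.split():
            start = text.index(word, offset)                  # char offset of this token
            tokens.append(Token(word, start_position=start))
            offset = start + len(word)
        return tokens

sentence = Sentence("Hello brave new world", use_tokenizer=WhitespaceTokenizer())
print([t.text for t in sentence])   # ['Hello', 'brave', 'new', 'world']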

598class Sentence(DataPoint): 

599 """ 

600 A Sentence is a list of tokens and is used to represent a sentence or text fragment. 

601 """ 

602 

603 def __init__( 

604 self, 

605 text: Union[str, List[str]] = None, 

606 use_tokenizer: Union[bool, Tokenizer] = True, 

607 language_code: str = None, 

608 start_position: int = None 

609 ): 

610 """ 

611 Class to hold all metadata related to a text (tokens, predictions, language code, ...) 

612 :param text: original string (sentence), or a list of string tokens (words) 

613 :param use_tokenizer: a custom tokenizer (default is :class:`SegtokTokenizer`); 

614 other options are :class:`SpaceTokenizer` to split on whitespace only, or :class:`SpacyTokenizer` 

615 to use the spaCy library, if available. Check the implementations of the abstract class Tokenizer or 

616 implement your own subclass (if you need it). If instead of providing a Tokenizer, this parameter 

617 is just set to True (deprecated), :class:`SegtokTokenizer` will be used. 

618 :param language_code: Language of the sentence 

619 :param start_position: Start char offset of the sentence in the superordinate document 

620 """ 

621 super().__init__() 

622 

623 self.tokens: List[Token] = [] 

624 

625 self._embeddings: Dict = {} 

626 

627 self.language_code: str = language_code 

628 

629 self.start_pos = start_position 

630 self.end_pos = ( 

631 start_position + len(text) if start_position is not None else None 

632 ) 

633 

634 if isinstance(use_tokenizer, Tokenizer): 

635 tokenizer = use_tokenizer 

636 elif hasattr(use_tokenizer, "__call__"): 

637 from flair.tokenization import TokenizerWrapper 

638 tokenizer = TokenizerWrapper(use_tokenizer) 

639 elif type(use_tokenizer) == bool: 

640 from flair.tokenization import SegtokTokenizer, SpaceTokenizer 

641 tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer() 

642 else: 

643 raise AssertionError("Unexpected type of parameter 'use_tokenizer'. " + 

644 "Parameter should be bool, Callable[[str], List[Token]] (deprecated), Tokenizer") 

645 

646 # if text is passed, instantiate sentence with tokens (words) 

647 if text is not None: 

648 if isinstance(text, (list, tuple)): 

649 [self.add_token(self._restore_windows_1252_characters(token)) 

650 for token in text] 

651 else: 

652 text = self._restore_windows_1252_characters(text) 

653 [self.add_token(token) for token in tokenizer.tokenize(text)] 

654 

655 # log a warning if the dataset is empty 

656 if text == "": 

657 log.warning( 

658 "Warning: An empty Sentence was created! Are there empty strings in your dataset?" 

659 ) 

660 

661 self.tokenized = None 

662 

663 # some sentences represent a document boundary (but most do not) 

664 self.is_document_boundary: bool = False 

665 

666 def get_token(self, token_id: int) -> Token: 

667 for token in self.tokens: 

668 if token.idx == token_id: 

669 return token 

670 

671 def add_token(self, token: Union[Token, str]): 

672 

673 if type(token) is str: 

674 token = Token(token) 

675 

676 token.text = token.text.replace('\u200c', '') 

677 token.text = token.text.replace('\u200b', '') 

678 token.text = token.text.replace('\ufe0f', '') 

679 token.text = token.text.replace('\ufeff', '') 

680 

681 # data with zero-width characters cannot be handled 

682 if token.text == '': 

683 return 

684 

685 self.tokens.append(token) 

686 

687 # set token idx if not set 

688 token.sentence = self 

689 if token.idx is None: 

690 token.idx = len(self.tokens) 

691 

692 def get_label_names(self): 

693 label_names = [] 

694 for label in self.labels: 

695 label_names.append(label.value) 

696 return label_names 

697 

698 def _add_spans_internal(self, spans: List[Span], label_type: str, min_score): 

699 

700 current_span = [] 

701 

702 tags = defaultdict(lambda: 0.0) 

703 

704 previous_tag_value: str = "O" 

705 for token in self: 

706 

707 tag: Label = token.get_tag(label_type) 

708 tag_value = tag.value 

709 

710 # non-set tags are OUT tags 

711 if tag_value == "" or tag_value == "O" or tag_value == "_": 

712 tag_value = "O-" 

713 

714 # anything that is not a BIOES tag is a SINGLE tag 

715 if tag_value[0:2] not in ["B-", "I-", "O-", "E-", "S-"]: 

716 tag_value = "S-" + tag_value 

717 

718 # anything that is not OUT is IN 

719 in_span = False 

720 if tag_value[0:2] not in ["O-"]: 

721 in_span = True 

722 

723 # single and begin tags start a new span 

724 starts_new_span = False 

725 if tag_value[0:2] in ["B-", "S-"]: 

726 starts_new_span = True 

727 

728 if ( 

729 previous_tag_value[0:2] in ["S-"] 

730 and previous_tag_value[2:] != tag_value[2:] 

731 and in_span 

732 ): 

733 starts_new_span = True 

734 

735 if (starts_new_span or not in_span) and len(current_span) > 0: 

736 scores = [t.get_labels(label_type)[0].score for t in current_span] 

737 span_score = sum(scores) / len(scores) 

738 if span_score > min_score: 

739 span = Span(current_span) 

740 span.add_label( 

741 typename=label_type, 

742 value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0], 

743 score=span_score) 

744 spans.append(span) 

745 

746 current_span = [] 

747 tags = defaultdict(lambda: 0.0) 

748 

749 if in_span: 

750 current_span.append(token) 

751 weight = 1.1 if starts_new_span else 1.0 

752 tags[tag_value[2:]] += weight 

753 

754 # remember previous tag 

755 previous_tag_value = tag_value 

756 

757 if len(current_span) > 0: 

758 scores = [t.get_labels(label_type)[0].score for t in current_span] 

759 span_score = sum(scores) / len(scores) 

760 if span_score > min_score: 

761 span = Span(current_span) 

762 span.add_label( 

763 typename=label_type, 

764 value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0], 

765 score=span_score) 

766 spans.append(span) 

767 

768 return spans 

769 

770 def get_spans(self, label_type: Optional[str] = None, min_score=-1) -> List[Span]: 

771 

772 spans: List[Span] = [] 

773 

774 # if label type is explicitly specified, get spans for this label type 

775 if label_type: 

776 return self._add_spans_internal(spans, label_type, min_score) 

777 

778 # else determine all label types in sentence and get all spans 

779 label_types = [] 

780 for token in self: 

781 for annotation in token.annotation_layers.keys(): 

782 if annotation not in label_types: label_types.append(annotation) 

783 

784 for label_type in label_types: 

785 self._add_spans_internal(spans, label_type, min_score) 

786 return spans 

787 

788 @property 

789 def embedding(self): 

790 return self.get_embedding() 

791 

792 def set_embedding(self, name: str, vector: torch.tensor): 

793 device = flair.device 

794 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0: 

795 device = next(iter(self._embeddings.values())).device 

796 if device != vector.device: 

797 vector = vector.to(device) 

798 self._embeddings[name] = vector 

799 

800 def get_embedding(self, names: Optional[List[str]] = None) -> torch.tensor: 

801 embeddings = [] 

802 for embed in sorted(self._embeddings.keys()): 

803 if names and embed not in names: continue 

804 embedding = self._embeddings[embed] 

805 embeddings.append(embedding) 

806 

807 if embeddings: 

808 return torch.cat(embeddings, dim=0) 

809 

810 return torch.Tensor() 

811 

812 def to(self, device: str, pin_memory: bool = False): 

813 

814 # move sentence embeddings to device 

815 for name, vector in self._embeddings.items(): 

816 if str(vector.device) != str(device): 

817 if pin_memory: 

818 self._embeddings[name] = vector.to( 

819 device, non_blocking=True 

820 ).pin_memory() 

821 else: 

822 self._embeddings[name] = vector.to(device, non_blocking=True) 

823 

824 # move token embeddings to device 

825 for token in self: 

826 token.to(device, pin_memory) 

827 

828 def clear_embeddings(self, embedding_names: List[str] = None): 

829 

830 # clear sentence embeddings 

831 if embedding_names is None: 

832 self._embeddings: Dict = {} 

833 else: 

834 for name in embedding_names: 

835 if name in self._embeddings.keys(): 

836 del self._embeddings[name] 

837 

838 # clear token embeddings 

839 for token in self: 

840 token.clear_embeddings(embedding_names) 

841 

842 def to_tagged_string(self, main_tag=None) -> str: 

843 list = [] 

844 for token in self.tokens: 

845 list.append(token.text) 

846 

847 tags: List[str] = [] 

848 for label_type in token.annotation_layers.keys(): 

849 

850 if main_tag is not None and main_tag != label_type: 

851 continue 

852 

853 if token.get_labels(label_type)[0].value == "O": 

854 continue 

855 if token.get_labels(label_type)[0].value == "_": 

856 continue 

857 

858 tags.append(token.get_labels(label_type)[0].value) 

859 all_tags = "<" + "/".join(tags) + ">" 

860 if all_tags != "<>": 

861 list.append(all_tags) 

862 return " ".join(list) 

863 

864 def to_tokenized_string(self) -> str: 

865 

866 if self.tokenized is None: 

867 self.tokenized = " ".join([t.text for t in self.tokens]) 

868 

869 return self.tokenized 

870 

871 def to_plain_string(self): 

872 plain = "" 

873 for token in self.tokens: 

874 plain += token.text 

875 if token.whitespace_after: 

876 plain += " " 

877 return plain.rstrip() 

878 

879 def convert_tag_scheme(self, tag_type: str = "ner", target_scheme: str = "iob"): 

880 

881 tags: List[Label] = [] 

882 for token in self.tokens: 

883 tags.append(token.get_tag(tag_type)) 

884 

885 if target_scheme == "iob": 

886 iob2(tags) 

887 

888 if target_scheme == "iobes": 

889 iob2(tags) 

890 tags = iob_iobes(tags) 

891 

892 for index, tag in enumerate(tags): 

893 self.tokens[index].set_label(tag_type, tag) 

894 

895 def infer_space_after(self): 

896 """ 

897 Heuristics in case you wish to infer whitespace_after values for tokenized text. This is useful for some old NLP 

898 tasks (such as CoNLL-03 and CoNLL-2000) that provide only tokenized data with no information about the original whitespace. 

899 :return: 

900 """ 

901 last_token = None 

902 quote_count: int = 0 

903 # infer whitespace after field 

904 

905 for token in self.tokens: 

906 if token.text == '"': 

907 quote_count += 1 

908 if quote_count % 2 != 0: 

909 token.whitespace_after = False 

910 elif last_token is not None: 

911 last_token.whitespace_after = False 

912 

913 if last_token is not None: 

914 

915 if token.text in [".", ":", ",", ";", ")", "n't", "!", "?"]: 

916 last_token.whitespace_after = False 

917 

918 if token.text.startswith("'"): 

919 last_token.whitespace_after = False 

920 

921 if token.text in ["("]: 

922 token.whitespace_after = False 

923 

924 last_token = token 

925 return self 

926 

927 def to_original_text(self) -> str: 

928 if len(self.tokens) > 0 and (self.tokens[0].start_pos is None): 

929 return " ".join([t.text for t in self.tokens]) 

930 str = "" 

931 pos = 0 

932 for t in self.tokens: 

933 while t.start_pos > pos: 

934 str += " " 

935 pos += 1 

936 

937 str += t.text 

938 pos += len(t.text) 

939 

940 return str 

941 

942 def to_dict(self, tag_type: str = None): 

943 labels = [] 

944 entities = [] 

945 

946 if tag_type: 

947 entities = [span.to_dict() for span in self.get_spans(tag_type)] 

948 if self.labels: 

949 labels = [l.to_dict() for l in self.labels] 

950 

951 return {"text": self.to_original_text(), "labels": labels, "entities": entities} 

952 

953 def __getitem__(self, idx: int) -> Token: 

954 return self.tokens[idx] 

955 

956 def __iter__(self): 

957 return iter(self.tokens) 

958 

959 def __len__(self) -> int: 

960 return len(self.tokens) 

961 

962 def __repr__(self): 

963 tagged_string = self.to_tagged_string() 

964 tokenized_string = self.to_tokenized_string() 

965 

966 # add Sentence labels to output if they exist 

967 sentence_labels = f" − Sentence-Labels: {self.annotation_layers}" if self.annotation_layers != {} else "" 

968 

969 # add Token labels to output if they exist 

970 token_labels = f' − Token-Labels: "{tagged_string}"' if tokenized_string != tagged_string else "" 

971 

972 return f'Sentence: "{tokenized_string}" [− Tokens: {len(self)}{token_labels}{sentence_labels}]' 

973 

974 def __copy__(self): 

975 s = Sentence() 

976 for token in self.tokens: 

977 nt = Token(token.text) 

978 for tag_type in token.tags: 

979 nt.add_label( 

980 tag_type, 

981 token.get_tag(tag_type).value, 

982 token.get_tag(tag_type).score, 

983 ) 

984 

985 s.add_token(nt) 

986 return s 

987 

988 def __str__(self) -> str: 

989 

990 tagged_string = self.to_tagged_string() 

991 tokenized_string = self.to_tokenized_string() 

992 

993 # add Sentence labels to output if they exist 

994 sentence_labels = f" − Sentence-Labels: {self.annotation_layers}" if self.annotation_layers != {} else "" 

995 

996 # add Token labels to output if they exist 

997 token_labels = f' − Token-Labels: "{tagged_string}"' if tokenized_string != tagged_string else "" 

998 

999 return f'Sentence: "{tokenized_string}" [− Tokens: {len(self)}{token_labels}{sentence_labels}]' 

1000 

1001 def get_language_code(self) -> str: 

1002 if self.language_code is None: 

1003 import langdetect 

1004 

1005 try: 

1006 self.language_code = langdetect.detect(self.to_plain_string()) 

1007 except: 

1008 self.language_code = "en" 

1009 

1010 return self.language_code 

1011 

1012 @staticmethod 

1013 def _restore_windows_1252_characters(text: str) -> str: 

1014 def to_windows_1252(match): 

1015 try: 

1016 return bytes([ord(match.group(0))]).decode("windows-1252") 

1017 except UnicodeDecodeError: 

1018 # No character at the corresponding code point: remove it 

1019 return "" 

1020 

1021 return re.sub(r"[\u0080-\u0099]", to_windows_1252, text) 

1022 

1023 def next_sentence(self): 

1024 """ 

1025 Get the next sentence in the document (works only if context is set through dataloader or elsewhere) 

1026 :return: next Sentence in document if set, otherwise None 

1027 """ 

1028 if '_next_sentence' in self.__dict__.keys(): 

1029 return self._next_sentence 

1030 

1031 if '_position_in_dataset' in self.__dict__.keys(): 

1032 dataset = self._position_in_dataset[0] 

1033 index = self._position_in_dataset[1] + 1 

1034 if index < len(dataset): 

1035 return dataset[index] 

1036 

1037 return None 

1038 

1039 def previous_sentence(self): 

1040 """ 

1041 Get the previous sentence in the document (works only if context is set through dataloader or elsewhere) 

1042 :return: previous Sentence in document if set, otherwise None 

1043 """ 

1044 if '_previous_sentence' in self.__dict__.keys(): 

1045 return self._previous_sentence 

1046 

1047 if '_position_in_dataset' in self.__dict__.keys(): 

1048 dataset = self._position_in_dataset[0] 

1049 index = self._position_in_dataset[1] - 1 

1050 if index >= 0: 

1051 return dataset[index] 

1052 

1053 return None 

1054 

1055 def is_context_set(self) -> bool: 

1056 """ 

1057 Return True or False depending on whether context is set (for instance in dataloader or elsewhere) 

1058 :return: True if context is set, else False 

1059 """ 

1060 return '_previous_sentence' in self.__dict__.keys() or '_position_in_dataset' in self.__dict__.keys() 

1061 

1062 def get_labels(self, label_type: str = None): 

1063 

1064 # TODO: crude hack - replace with something better 

1065 if label_type: 

1066 spans = self.get_spans(label_type) 

1067 for span in spans: 

1068 self.add_complex_label(label_type, label=SpanLabel(span, span.tag, span.score)) 

1069 

1070 if label_type is None: 

1071 return self.labels 

1072 

1073 return self.annotation_layers[label_type] if label_type in self.annotation_layers else [] 

1074 

1075 
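# --- Usage sketch (not part of data.py): building and labelling a Sentence. ---
# Tokenized with the default SegtokTokenizer; the "ner" tags below are set by hand
# (a trained tagger would normally predict them) and the expected outputs are comments.
from flair.data import Sentence

sentence = Sentence("George Washington went to Washington .")
sentence.add_label("topic", "history", score=0.9)        # sentence-level label

sentence[0].add_tag("ner", "B-PER", confidence=0.95)     # token-level BIOES tags
sentence[1].add_tag("ner", "E-PER", confidence=0.90)
sentence[4].add_tag("ner", "S-LOC", confidence=0.85)

for span in sentence.get_spans("ner"):                   # token tags folded into Spans
    print(span)
# Span [1,2]: "George Washington" [− Labels: PER (0.925)]
# Span [5]: "Washington" [− Labels: LOC (0.85)]

print(sentence.to_tagged_string())
# George <B-PER> Washington <E-PER> went to Washington <S-LOC> .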

1076class Image(DataPoint): 

1077 

1078 def __init__(self, data=None, imageURL=None): 

1079 super().__init__() 

1080 

1081 self.data = data 

1082 self._embeddings: Dict = {} 

1083 self.imageURL = imageURL 

1084 

1085 @property 

1086 def embedding(self): 

1087 return self.get_embedding() 

1088 

1089 def __str__(self): 

1090 

1091 image_repr = self.data.size() if self.data else "" 

1092 image_url = self.imageURL if self.imageURL else "" 

1093 

1094 return f"Image: {image_repr} {image_url}" 

1095 

1096 def get_embedding(self) -> torch.tensor: 

1097 embeddings = [ 

1098 self._embeddings[embed] for embed in sorted(self._embeddings.keys()) 

1099 ] 

1100 

1101 if embeddings: 

1102 return torch.cat(embeddings, dim=0) 

1103 

1104 return torch.tensor([], device=flair.device) 

1105 

1106 def set_embedding(self, name: str, vector: torch.tensor): 

1107 device = flair.device 

1108 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0: 

1109 device = next(iter(self._embeddings.values())).device 

1110 if device != vector.device: 

1111 vector = vector.to(device) 

1112 self._embeddings[name] = vector 

1113 

1114 def to(self, device: str, pin_memory: bool = False): 

1115 for name, vector in self._embeddings.items(): 

1116 if str(vector.device) != str(device): 

1117 if pin_memory: 

1118 self._embeddings[name] = vector.to( 

1119 device, non_blocking=True 

1120 ).pin_memory() 

1121 else: 

1122 self._embeddings[name] = vector.to(device, non_blocking=True) 

1123 

1124 def clear_embeddings(self, embedding_names: List[str] = None): 

1125 if embedding_names is None: 

1126 self._embeddings: Dict = {} 

1127 else: 

1128 for name in embedding_names: 

1129 if name in self._embeddings.keys(): 

1130 del self._embeddings[name] 

1131 

1132 

1133class FlairDataset(Dataset): 

1134 @abstractmethod 

1135 def is_in_memory(self) -> bool: 

1136 pass 

1137 

1138 

1139class Corpus: 

1140 def __init__( 

1141 self, 

1142 train: FlairDataset = None, 

1143 dev: FlairDataset = None, 

1144 test: FlairDataset = None, 

1145 name: str = "corpus", 

1146 sample_missing_splits: Union[bool, str] = True, 

1147 ): 

1148 # set name 

1149 self.name: str = name 

1150 

1151 # abort if no data is provided 

1152 if not train and not dev and not test: 

1153 raise RuntimeError('No data provided when initializing corpus object.') 

1154 

1155 # sample test data from train if none is provided 

1156 if test is None and sample_missing_splits and train and not sample_missing_splits == 'only_dev': 

1157 train_length = len(train) 

1158 test_size: int = round(train_length / 10) 

1159 splits = randomly_split_into_two_datasets(train, test_size) 

1160 test = splits[0] 

1161 train = splits[1] 

1162 

1163 # sample dev data from train if none is provided 

1164 if dev is None and sample_missing_splits and train and not sample_missing_splits == 'only_test': 

1165 train_length = len(train) 

1166 dev_size: int = round(train_length / 10) 

1167 splits = randomly_split_into_two_datasets(train, dev_size) 

1168 dev = splits[0] 

1169 train = splits[1] 

1170 

1171 # set train dev and test data 

1172 self._train: FlairDataset = train 

1173 self._test: FlairDataset = test 

1174 self._dev: FlairDataset = dev 

1175 

1176 @property 

1177 def train(self) -> FlairDataset: 

1178 return self._train 

1179 

1180 @property 

1181 def dev(self) -> FlairDataset: 

1182 return self._dev 

1183 

1184 @property 

1185 def test(self) -> FlairDataset: 

1186 return self._test 

1187 

1188 def downsample(self, percentage: float = 0.1, downsample_train=True, downsample_dev=True, downsample_test=True): 

1189 

1190 if downsample_train and self._train: 

1191 self._train = self._downsample_to_proportion(self.train, percentage) 

1192 

1193 if downsample_dev and self._dev: 

1194 self._dev = self._downsample_to_proportion(self.dev, percentage) 

1195 

1196 if downsample_test and self._test: 

1197 self._test = self._downsample_to_proportion(self.test, percentage) 

1198 

1199 return self 

1200 

1201 def filter_empty_sentences(self): 

1202 log.info("Filtering empty sentences") 

1203 self._train = Corpus._filter_empty_sentences(self._train) 

1204 self._test = Corpus._filter_empty_sentences(self._test) 

1205 self._dev = Corpus._filter_empty_sentences(self._dev) 

1206 log.info(self) 

1207 

1208 def filter_long_sentences(self, max_charlength: int): 

1209 log.info("Filtering long sentences") 

1210 self._train = Corpus._filter_long_sentences(self._train, max_charlength) 

1211 self._test = Corpus._filter_long_sentences(self._test, max_charlength) 

1212 self._dev = Corpus._filter_long_sentences(self._dev, max_charlength) 

1213 log.info(self) 

1214 

1215 @staticmethod 

1216 def _filter_long_sentences(dataset, max_charlength: int) -> Dataset: 

1217 

1218 # find out empty sentence indices 

1219 empty_sentence_indices = [] 

1220 non_empty_sentence_indices = [] 

1221 index = 0 

1222 

1223 from flair.datasets import DataLoader 

1224 

1225 for batch in DataLoader(dataset): 

1226 for sentence in batch: 

1227 if len(sentence.to_plain_string()) > max_charlength: 

1228 empty_sentence_indices.append(index) 

1229 else: 

1230 non_empty_sentence_indices.append(index) 

1231 index += 1 

1232 

1233 # create subset of non-empty sentence indices 

1234 subset = Subset(dataset, non_empty_sentence_indices) 

1235 

1236 return subset 

1237 

1238 @staticmethod 

1239 def _filter_empty_sentences(dataset) -> Dataset: 

1240 

1241 # find out empty sentence indices 

1242 empty_sentence_indices = [] 

1243 non_empty_sentence_indices = [] 

1244 index = 0 

1245 

1246 from flair.datasets import DataLoader 

1247 

1248 for batch in DataLoader(dataset): 

1249 for sentence in batch: 

1250 if len(sentence) == 0: 

1251 empty_sentence_indices.append(index) 

1252 else: 

1253 non_empty_sentence_indices.append(index) 

1254 index += 1 

1255 

1256 # create subset of non-empty sentence indices 

1257 subset = Subset(dataset, non_empty_sentence_indices) 

1258 

1259 return subset 

1260 

1261 def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary: 

1262 """ 

1263 Creates a dictionary of all tokens contained in the corpus. 

1264 By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary. 

1265 If there are more than `max_tokens` tokens in the corpus, the most frequent tokens are added first. 

1266 If `min_freq` is set to a value greater than 1, only tokens occurring at least `min_freq` times are considered 

1267 to be added to the dictionary. 

1268 :param max_tokens: the maximum number of tokens that should be added to the dictionary (-1 = take all tokens) 

1269 :param min_freq: a token needs to occur at least `min_freq` times to be added to the dictionary (-1 = there is no limitation) 

1270 :return: dictionary of tokens 

1271 """ 

1272 tokens = self._get_most_common_tokens(max_tokens, min_freq) 

1273 

1274 vocab_dictionary: Dictionary = Dictionary() 

1275 for token in tokens: 

1276 vocab_dictionary.add_item(token) 

1277 

1278 return vocab_dictionary 

1279 

1280 def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]: 

1281 tokens_and_frequencies = Counter(self._get_all_tokens()) 

1282 tokens_and_frequencies = tokens_and_frequencies.most_common() 

1283 

1284 tokens = [] 

1285 for token, freq in tokens_and_frequencies: 

1286 if (min_freq != -1 and freq < min_freq) or ( 

1287 max_tokens != -1 and len(tokens) == max_tokens 

1288 ): 

1289 break 

1290 tokens.append(token) 

1291 return tokens 

1292 

1293 def _get_all_tokens(self) -> List[str]: 

1294 tokens = list(map((lambda s: s.tokens), self.train)) 

1295 tokens = [token for sublist in tokens for token in sublist] 

1296 return list(map((lambda t: t.text), tokens)) 

1297 

1298 @staticmethod 

1299 def _downsample_to_proportion(dataset: Dataset, proportion: float): 

1300 

1301 sampled_size: int = round(len(dataset) * proportion) 

1302 splits = randomly_split_into_two_datasets(dataset, sampled_size) 

1303 return splits[0] 

1304 

1305 def obtain_statistics( 

1306 self, label_type: str = None, pretty_print: bool = True 

1307 ) -> dict: 

1308 """ 

1309 Print statistics about the class distribution (only labels of sentences are taken into account) and sentence 

1310 sizes. 

1311 """ 

1312 json_string = { 

1313 "TRAIN": self._obtain_statistics_for(self.train, "TRAIN", label_type), 

1314 "TEST": self._obtain_statistics_for(self.test, "TEST", label_type), 

1315 "DEV": self._obtain_statistics_for(self.dev, "DEV", label_type), 

1316 } 

1317 if pretty_print: 

1318 import json 

1319 

1320 json_string = json.dumps(json_string, indent=4) 

1321 return json_string 

1322 

1323 @staticmethod 

1324 def _obtain_statistics_for(sentences, name, tag_type) -> dict: 

1325 if len(sentences) == 0: 

1326 return {} 

1327 

1328 classes_to_count = Corpus._count_sentence_labels(sentences) 

1329 tags_to_count = Corpus._count_token_labels(sentences, tag_type) 

1330 tokens_per_sentence = Corpus._get_tokens_per_sentence(sentences) 

1331 

1332 label_size_dict = {} 

1333 for l, c in classes_to_count.items(): 

1334 label_size_dict[l] = c 

1335 

1336 tag_size_dict = {} 

1337 for l, c in tags_to_count.items(): 

1338 tag_size_dict[l] = c 

1339 

1340 return { 

1341 "dataset": name, 

1342 "total_number_of_documents": len(sentences), 

1343 "number_of_documents_per_class": label_size_dict, 

1344 "number_of_tokens_per_tag": tag_size_dict, 

1345 "number_of_tokens": { 

1346 "total": sum(tokens_per_sentence), 

1347 "min": min(tokens_per_sentence), 

1348 "max": max(tokens_per_sentence), 

1349 "avg": sum(tokens_per_sentence) / len(sentences), 

1350 }, 

1351 } 

1352 

1353 @staticmethod 

1354 def _get_tokens_per_sentence(sentences): 

1355 return list(map(lambda x: len(x.tokens), sentences)) 

1356 

1357 @staticmethod 

1358 def _count_sentence_labels(sentences): 

1359 label_count = defaultdict(lambda: 0) 

1360 for sent in sentences: 

1361 for label in sent.labels: 

1362 label_count[label.value] += 1 

1363 return label_count 

1364 

1365 @staticmethod 

1366 def _count_token_labels(sentences, label_type): 

1367 label_count = defaultdict(lambda: 0) 

1368 for sent in sentences: 

1369 for token in sent.tokens: 

1370 if label_type in token.annotation_layers.keys(): 

1371 label = token.get_tag(label_type) 

1372 label_count[label.value] += 1 

1373 return label_count 

1374 

1375 def __str__(self) -> str: 

1376 return "Corpus: %d train + %d dev + %d test sentences" % ( 

1377 len(self.train) if self.train else 0, 

1378 len(self.dev) if self.dev else 0, 

1379 len(self.test) if self.test else 0, 

1380 ) 

1381 

1382 def make_label_dictionary(self, label_type: str) -> Dictionary: 

1383 """ 

1384 Creates a dictionary of all labels assigned to the sentences in the corpus. 

1385 :return: dictionary of labels 

1386 """ 

1387 label_dictionary: Dictionary = Dictionary(add_unk=True) 

1388 label_dictionary.multi_label = False 

1389 

1390 from flair.datasets import DataLoader 

1391 

1392 datasets = [self.train] 

1393 

1394 data = ConcatDataset(datasets) 

1395 

1396 loader = DataLoader(data, batch_size=1) 

1397 

1398 log.info("Computing label dictionary. Progress:") 

1399 

1400 # if there are token labels of provided type, use these. Otherwise use sentence labels 

1401 token_labels_exist = False 

1402 

1403 all_label_types = Counter() 

1404 all_sentence_labels = [] 

1405 for batch in Tqdm.tqdm(iter(loader)): 

1406 

1407 for sentence in batch: 

1408 

1409 # check for labels of words 

1410 if isinstance(sentence, Sentence): 

1411 for token in sentence.tokens: 

1412 all_label_types.update(token.annotation_layers.keys()) 

1413 for label in token.get_labels(label_type): 

1414 label_dictionary.add_item(label.value) 

1415 token_labels_exist = True 

1416 

1417 # if we are looking for sentence-level labels 

1418 if not token_labels_exist: 

1419 # check if sentence itself has labels 

1420 labels = sentence.get_labels(label_type) 

1421 all_label_types.update(sentence.annotation_layers.keys()) 

1422 

1423 for label in labels: 

1424 if label.value not in all_sentence_labels: all_sentence_labels.append(label.value) 

1425 

1426 if not label_dictionary.multi_label: 

1427 if len(labels) > 1: 

1428 label_dictionary.multi_label = True 

1429 

1430 # if this is not a token-level prediction problem, add sentence-level labels to dictionary 

1431 if not token_labels_exist: 

1432 for label in all_sentence_labels: 

1433 label_dictionary.add_item(label) 

1434 

1435 if len(label_dictionary.idx2item) == 0: 

1436 log.error( 

1437 f"Corpus contains only the labels: {', '.join([f'{label[0]} (#{label[1]})' for label in all_label_types.most_common()])}") 

1438 log.error(f"You specified as label_type='{label_type}' which is not in this dataset!") 

1439 

1440 raise Exception 

1441 

1442 log.info( 

1443 f"Corpus contains the labels: {', '.join([label[0] + f' (#{label[1]})' for label in all_label_types.most_common()])}") 

1444 log.info(f"Created (for label '{label_type}') {label_dictionary}") 

1445 

1446 return label_dictionary 

1447 

1448 def get_label_distribution(self): 

1449 class_to_count = defaultdict(lambda: 0) 

1450 for sent in self.train: 

1451 for label in sent.labels: 

1452 class_to_count[label.value] += 1 

1453 return class_to_count 

1454 

1455 def get_all_sentences(self) -> Dataset: 

1456 parts = [] 

1457 if self.train: parts.append(self.train) 

1458 if self.dev: parts.append(self.dev) 

1459 if self.test: parts.append(self.test) 

1460 return ConcatDataset(parts) 

1461 

1462 @deprecated(version="0.8", reason="Use 'make_label_dictionary' instead.") 

1463 def make_tag_dictionary(self, tag_type: str) -> Dictionary: 

1464 

1465 # Make the tag dictionary 

1466 tag_dictionary: Dictionary = Dictionary(add_unk=False) 

1467 tag_dictionary.add_item("O") 

1468 for sentence in self.get_all_sentences(): 

1469 for token in sentence.tokens: 

1470 tag_dictionary.add_item(token.get_tag(tag_type).value) 

1471 tag_dictionary.add_item("<START>") 

1472 tag_dictionary.add_item("<STOP>") 

1473 return tag_dictionary 

1474 

1475 
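# --- Usage sketch (not part of data.py): wrapping datasets in a Corpus. ---
# A tiny in-memory example; SentenceDataset (assumed to be available from flair.datasets)
# simply wraps a list of Sentence objects. Automatic dev/test sampling is switched off
# here because the toy train set is so small.
from flair.data import Corpus, Sentence
from flair.datasets import SentenceDataset

train = SentenceDataset([
    Sentence("this movie is great").add_label("sentiment", "POSITIVE"),
    Sentence("this movie is terrible").add_label("sentiment", "NEGATIVE"),
])
test = SentenceDataset([Sentence("not bad at all").add_label("sentiment", "POSITIVE")])

corpus = Corpus(train=train, test=test, name="toy-sentiment", sample_missing_splits=False)
print(corpus)                                             # Corpus: 2 train + 0 dev + 1 test sentences
label_dict = corpus.make_label_dictionary(label_type="sentiment")
print(label_dict.get_items())                             # ['<unk>', 'POSITIVE', 'NEGATIVE']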

1476class MultiCorpus(Corpus): 

1477 def __init__(self, corpora: List[Corpus], name: str = "multicorpus", **corpusargs): 

1478 self.corpora: List[Corpus] = corpora 

1479 

1480 train_parts = [] 

1481 dev_parts = [] 

1482 test_parts = [] 

1483 for corpus in self.corpora: 

1484 if corpus.train: train_parts.append(corpus.train) 

1485 if corpus.dev: dev_parts.append(corpus.dev) 

1486 if corpus.test: test_parts.append(corpus.test) 

1487 

1488 super(MultiCorpus, self).__init__( 

1489 ConcatDataset(train_parts) if len(train_parts) > 0 else None, 

1490 ConcatDataset(dev_parts) if len(dev_parts) > 0 else None, 

1491 ConcatDataset(test_parts) if len(test_parts) > 0 else None, 

1492 name=name, 

1493 **corpusargs, 

1494 ) 

1495 

1496 def __str__(self): 

1497 output = f"MultiCorpus: " \ 

1498 f"{len(self.train) if self.train else 0} train + " \ 

1499 f"{len(self.dev) if self.dev else 0} dev + " \ 

1500 f"{len(self.test) if self.test else 0} test sentences\n - " 

1501 output += "\n - ".join([f'{type(corpus).__name__} {str(corpus)} - {corpus.name}' for corpus in self.corpora]) 

1502 return output 

1503 

1504 
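# --- Usage sketch (not part of data.py): combining corpora with MultiCorpus. ---
# Builds on the toy Corpus idea above; SentenceDataset is again assumed from flair.datasets.
from flair.data import Corpus, MultiCorpus, Sentence
from flair.datasets import SentenceDataset

corpus_a = Corpus(train=SentenceDataset([Sentence("first corpus sentence")]),
                  name="corpus-a", sample_missing_splits=False)
corpus_b = Corpus(train=SentenceDataset([Sentence("second corpus sentence")]),
                  name="corpus-b", sample_missing_splits=False)

multi = MultiCorpus([corpus_a, corpus_b], name="joint", sample_missing_splits=False)
print(multi)   # MultiCorpus: 2 train + 0 dev + 0 test sentences, then one line per corpus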

1505def iob2(tags): 

1506 """ 

1507 Check that tags have a valid IOB format. 

1508 Tags in IOB1 format are converted to IOB2. 

1509 """ 

1510 for i, tag in enumerate(tags): 

1511 if tag.value == "O": 

1512 continue 

1513 split = tag.value.split("-") 

1514 if len(split) != 2 or split[0] not in ["I", "B"]: 

1515 return False 

1516 if split[0] == "B": 

1517 continue 

1518 elif i == 0 or tags[i - 1].value == "O": # conversion IOB1 to IOB2 

1519 tags[i].value = "B" + tag.value[1:] 

1520 elif tags[i - 1].value[1:] == tag.value[1:]: 

1521 continue 

1522 else: # conversion IOB1 to IOB2 

1523 tags[i].value = "B" + tag.value[1:] 

1524 return True 

1525 

1526 

1527def iob_iobes(tags): 

1528 """ 

1529 IOB -> IOBES 

1530 """ 

1531 new_tags = [] 

1532 for i, tag in enumerate(tags): 

1533 if tag.value == "O" or tag.value == "": 

1534 new_tags.append("O") 

1535 elif tag.value.split("-")[0] == "B": 

1536 if i + 1 != len(tags) and tags[i + 1].value.split("-")[0] == "I": 

1537 new_tags.append(tag.value) 

1538 else: 

1539 new_tags.append(tag.value.replace("B-", "S-")) 

1540 elif tag.value.split("-")[0] == "I": 

1541 if i + 1 < len(tags) and tags[i + 1].value.split("-")[0] == "I": 

1542 new_tags.append(tag.value) 

1543 else: 

1544 new_tags.append(tag.value.replace("I-", "E-")) 

1545 else: 

1546 raise Exception("Invalid IOB format!") 

1547 return new_tags 

1548 

1549 
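# --- Usage sketch (not part of data.py): converting tag schemes. ---
# iob2() repairs IOB1 tags in place; iob_iobes() then maps them to IOBES strings.
from flair.data import Label, iob2, iob_iobes

tags = [Label("I-PER"), Label("I-PER"), Label("O"), Label("I-LOC")]
assert iob2(tags)                        # IOB1 -> IOB2, modifies the Label values
print([t.value for t in tags])           # ['B-PER', 'I-PER', 'O', 'B-LOC']
print(iob_iobes(tags))                   # ['B-PER', 'E-PER', 'O', 'S-LOC']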

1550def randomly_split_into_two_datasets(dataset, length_of_first): 

1551 import random 

1552 indices = [i for i in range(len(dataset))] 

1553 random.shuffle(indices) 

1554 

1555 first_dataset = indices[:length_of_first] 

1556 second_dataset = indices[length_of_first:] 

1557 first_dataset.sort() 

1558 second_dataset.sort() 

1559 

1560 return [Subset(dataset, first_dataset), Subset(dataset, second_dataset)]
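# --- Usage sketch (not part of data.py): splitting a dataset in two. ---
# Works on any torch Dataset; here a SentenceDataset (assumed from flair.datasets)
# of ten throw-away sentences is split into a held-out part and the rest.
from flair.data import Sentence, randomly_split_into_two_datasets
from flair.datasets import SentenceDataset

dataset = SentenceDataset([Sentence(f"sentence number {i}") for i in range(10)])
held_out, rest = randomly_split_into_two_datasets(dataset, length_of_first=3)
print(len(held_out), len(rest))    # 3 7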