Coverage for flair/flair/data.py: 24%
1import logging
2import re
3from abc import abstractmethod, ABC
4from collections import Counter
5from collections import defaultdict
6from operator import itemgetter
7from typing import List, Dict, Union, Optional
9from deprecated import deprecated
10import torch
11from torch.utils.data import Dataset
12from torch.utils.data.dataset import ConcatDataset, Subset
14import flair
15from flair.file_utils import Tqdm
17log = logging.getLogger("flair")
20class Dictionary:
21 """
22 This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings.
23 """
25 def __init__(self, add_unk=True):
26 # init dictionaries
27 self.item2idx: Dict[str, int] = {}
28 self.idx2item: List[str] = []
29 self.add_unk = add_unk
30 # in order to deal with unknown tokens, add <unk>
31 if add_unk:
32 self.add_item("<unk>")
34 def remove_item(self, item: str):
36 item = item.encode("utf-8")
37 if item in self.item2idx:
38 self.idx2item.remove(item)
39 del self.item2idx[item]
41 def add_item(self, item: str) -> int:
42 """
43 Add a string to the dictionary. If the string is already in the dictionary, its existing ID is returned; otherwise it is assigned a new ID.
44 :param item: the string for which to assign an ID
45 :return: ID of string
46 """
47 item = item.encode("utf-8")
48 if item not in self.item2idx:
49 self.idx2item.append(item)
50 self.item2idx[item] = len(self.idx2item) - 1
51 return self.item2idx[item]
53 def get_idx_for_item(self, item: str) -> int:
54 """
55 Returns the ID of the given string. If the string is not in the dictionary, returns 0 (the <unk> ID) when add_unk is set, otherwise raises an IndexError.
56 :param item: string for which the ID is requested
57 :return: ID of the string
58 """
59 item_encoded = item.encode("utf-8")
60 if item_encoded in self.item2idx.keys():
61 return self.item2idx[item_encoded]
62 elif self.add_unk:
63 return 0
64 else:
65 log.error(f"The string '{item}' is not in dictionary! Dictionary contains only: {self.get_items()}")
66 log.error(
67 "You can create a Dictionary that handles unknown items with an <unk>-key by setting add_unk = True in the construction.")
68 raise IndexError
70 def get_idx_for_items(self, items: List[str]) -> List[int]:
71 """
72 Returns the ID for each string in the given list; strings that are not in the dictionary map to 0.
73 :param items: list of strings for which IDs are requested
74 :return: list of IDs, one per input string
75 """
76 if not hasattr(self, "item2idx_not_encoded"):
77 d = dict(
78 [(key.decode("UTF-8"), value) for key, value in self.item2idx.items()]
79 )
80 self.item2idx_not_encoded = defaultdict(int, d)
82 if not items:
83 return []
84 results = itemgetter(*items)(self.item2idx_not_encoded)
85 if isinstance(results, int):
86 return [results]
87 return list(results)
89 def get_items(self) -> List[str]:
90 items = []
91 for item in self.idx2item:
92 items.append(item.decode("UTF-8"))
93 return items
95 def __len__(self) -> int:
96 return len(self.idx2item)
98 def get_item_for_index(self, idx):
99 return self.idx2item[idx].decode("UTF-8")
101 def save(self, savefile):
102 import pickle
104 with open(savefile, "wb") as f:
105 mappings = {"idx2item": self.idx2item, "item2idx": self.item2idx}
106 pickle.dump(mappings, f)
108 def __setstate__(self, d):
109 self.__dict__ = d
110 # set 'add_unk' if the dictionary was created with a version of Flair older than 0.9
111 if 'add_unk' not in self.__dict__.keys():
112 self.__dict__['add_unk'] = True if b'<unk>' in self.__dict__['idx2item'] else False
114 @classmethod
115 def load_from_file(cls, filename: str):
116 import pickle
118 f = open(filename, "rb")
119 mappings = pickle.load(f, encoding="latin1")
120 idx2item = mappings["idx2item"]
121 item2idx = mappings["item2idx"]
122 f.close()
124 # set 'add_unk' depending on whether <unk> is a key
125 add_unk = True if b'<unk>' in idx2item else False
127 dictionary: Dictionary = Dictionary(add_unk=add_unk)
128 dictionary.item2idx = item2idx
129 dictionary.idx2item = idx2item
130 return dictionary
132 @classmethod
133 def load(cls, name: str):
134 from flair.file_utils import cached_path
135 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/characters"
136 if name == "chars" or name == "common-chars":
137 char_dict = cached_path(f"{hu_path}/common_characters", cache_dir="datasets")
138 return Dictionary.load_from_file(char_dict)
140 if name == "chars-large" or name == "common-chars-large":
141 char_dict = cached_path(f"{hu_path}/common_characters_large", cache_dir="datasets")
142 return Dictionary.load_from_file(char_dict)
144 if name == "chars-xl" or name == "common-chars-xl":
145 char_dict = cached_path(f"{hu_path}/common_characters_xl", cache_dir="datasets")
146 return Dictionary.load_from_file(char_dict)
148 return Dictionary.load_from_file(name)
150 def __str__(self):
151 tags = ', '.join(self.get_item_for_index(i) for i in range(min(len(self), 50)))
152 return f"Dictionary with {len(self)} tags: {tags}"
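# Illustrative usage sketch (added to this listing, not part of the original flair module):
# a quick round trip between strings and integer IDs using only the Dictionary methods above.
def _example_dictionary_usage() -> None:
    vocab = Dictionary(add_unk=True)                  # index 0 is reserved for "<unk>"
    idx = vocab.add_item("hello")                     # new items receive the next free ID
    assert vocab.get_idx_for_item("hello") == idx
    assert vocab.get_idx_for_item("never-seen") == 0  # unknown strings fall back to <unk>
    assert vocab.get_item_for_index(idx) == "hello"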
155class Label:
156 """
157 This class represents a label. Each label has a value and optionally a confidence score. The
158 score needs to be between 0.0 and 1.0. Default value for the score is 1.0.
159 """
161 def __init__(self, value: str, score: float = 1.0):
162 self._value = value
163 self._score = score
164 super().__init__()
166 def set_value(self, value: str, score: float = 1.0):
167 self.value = value
168 self.score = score
170 def spawn(self, value: str, score: float = 1.0):
171 return Label(value, score)
173 @property
174 def value(self):
175 return self._value
177 @value.setter
178 def value(self, value):
179 if not value and value != "":
180 raise ValueError(
181 "Incorrect label value provided. Label value needs to be set."
182 )
183 else:
184 self._value = value
186 @property
187 def score(self):
188 return self._score
190 @score.setter
191 def score(self, score):
192 self._score = score
194 def to_dict(self):
195 return {"value": self.value, "confidence": self.score}
197 def __str__(self):
198 return f"{self._value} ({round(self._score, 4)})"
200 def __repr__(self):
201 return f"{self._value} ({round(self._score, 4)})"
203 def __eq__(self, other):
204 return self.value == other.value and self.score == other.score
206 @property
207 def identifier(self):
208 return ""
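# Illustrative sketch (not part of the original module): creating a Label and deriving a new
# one with spawn(), which creates a label of the same class with a new value and score
# (the subclasses below also carry over their span or relation arguments).
def _example_label_usage() -> None:
    label = Label("POSITIVE", score=0.98)
    assert label.value == "POSITIVE" and label.score == 0.98
    relabeled = label.spawn("NEGATIVE", score=0.51)
    assert relabeled.value == "NEGATIVE" and relabeled.score == 0.51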
211class SpanLabel(Label):
212 def __init__(self, span, value: str, score: float = 1.0):
213 super().__init__(value, score)
214 self.span = span
216 def spawn(self, value: str, score: float = 1.0):
217 return SpanLabel(self.span, value, score)
219 def __str__(self):
220 return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})"
222 def __repr__(self):
223 return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})"
225 def __len__(self):
226 return len(self.span)
228 def __eq__(self, other):
229 return self.value == other.value and self.score == other.score and self.span.id_text == other.span.id_text
231 @property
232 def identifier(self):
233 return f"{self.span.id_text}"
236class RelationLabel(Label):
237 def __init__(self, head, tail, value: str, score: float = 1.0):
238 super().__init__(value, score)
239 self.head = head
240 self.tail = tail
242 def spawn(self, value: str, score: float = 1.0):
243 return RelationLabel(self.head, self.tail, value, score)
245 def __str__(self):
246 return f"{self._value} [{self.head.id_text} -> {self.tail.id_text}] ({round(self._score, 4)})"
248 def __repr__(self):
249 return f"{self._value} from {self.head.id_text} -> {self.tail.id_text} ({round(self._score, 4)})"
251 def __len__(self):
252 return len(self.head) + len(self.tail)
254 def __eq__(self, other):
255 return self.value == other.value \
256 and self.score == other.score \
257 and self.head.id_text == other.head.id_text \
258 and self.tail.id_text == other.tail.id_text
260 @property
261 def identifier(self):
262 return f"{self.head.id_text} -> {self.tail.id_text}"
265class DataPoint:
266 """
267 This is the parent class of all data points in Flair (including Token, Sentence, Image, etc.). Each DataPoint
268 must be embeddable (hence the abstract property embedding() and methods to() and clear_embeddings()). Also,
269 each DataPoint may have Labels in several layers of annotation (hence the functions add_label(), get_labels()
270 and the property 'labels').
271 """
273 def __init__(self):
274 self.annotation_layers = {}
276 @property
277 @abstractmethod
278 def embedding(self):
279 pass
281 @abstractmethod
282 def to(self, device: str, pin_memory: bool = False):
283 pass
285 @abstractmethod
286 def clear_embeddings(self, embedding_names: List[str] = None):
287 pass
289 def add_label(self, typename: str, value: str, score: float = 1.):
291 if typename not in self.annotation_layers:
292 self.annotation_layers[typename] = [Label(value, score)]
293 else:
294 self.annotation_layers[typename].append(Label(value, score))
296 return self
298 def add_complex_label(self, typename: str, label: Label):
300 if typename in self.annotation_layers and label in self.annotation_layers[typename]:
301 return self
303 if typename not in self.annotation_layers:
304 self.annotation_layers[typename] = [label]
305 else:
306 self.annotation_layers[typename].append(label)
308 return self
310 def set_label(self, typename: str, value: str, score: float = 1.):
311 self.annotation_layers[typename] = [Label(value, score)]
312 return self
314 def remove_labels(self, typename: str):
315 if typename in self.annotation_layers.keys():
316 del self.annotation_layers[typename]
318 def get_labels(self, typename: str = None):
319 if typename is None:
320 return self.labels
322 return self.annotation_layers[typename] if typename in self.annotation_layers else []
324 @property
325 def labels(self) -> List[Label]:
326 all_labels = []
327 for key in self.annotation_layers.keys():
328 all_labels.extend(self.annotation_layers[key])
329 return all_labels
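# Illustrative sketch (not part of the original module): every DataPoint keeps its labels in
# named annotation layers. The example uses the Sentence subclass defined further below.
def _example_annotation_layers() -> None:
    sentence = Sentence("Berlin is a city .", use_tokenizer=False)
    sentence.add_label("sentiment", "NEUTRAL", score=0.7)
    sentence.add_label("topic", "geography")
    assert [label.value for label in sentence.get_labels("sentiment")] == ["NEUTRAL"]
    assert len(sentence.labels) == 2   # labels from all annotation layers combined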
332class DataPair(DataPoint):
333 def __init__(self, first: DataPoint, second: DataPoint):
334 super().__init__()
335 self.first = first
336 self.second = second
338 def to(self, device: str, pin_memory: bool = False):
339 self.first.to(device, pin_memory)
340 self.second.to(device, pin_memory)
342 def clear_embeddings(self, embedding_names: List[str] = None):
343 self.first.clear_embeddings(embedding_names)
344 self.second.clear_embeddings(embedding_names)
346 @property
347 def embedding(self):
348 return torch.cat([self.first.embedding, self.second.embedding])
350 def __str__(self):
351 return f"DataPair:\n − First {self.first}\n − Second {self.second}\n − Labels: {self.labels}"
353 def to_plain_string(self):
354 return f"DataPair: First {self.first} || Second {self.second}"
356 def to_original_text(self):
357 return f"{self.first.to_original_text()} || {self.second.to_original_text()}"
359 def __len__(self):
360 return len(self.first) + len(self.second)
363class Token(DataPoint):
364 """
365 This class represents one word in a tokenized sentence. Each token may have any number of tags. It may also point
366 to its head in a dependency tree.
367 """
369 def __init__(
370 self,
371 text: str,
372 idx: int = None,
373 head_id: int = None,
374 whitespace_after: bool = True,
375 start_position: int = None,
376 ):
377 super().__init__()
379 self.text: str = text
380 self.idx: int = idx
381 self.head_id: int = head_id
382 self.whitespace_after: bool = whitespace_after
384 self.start_pos = start_position
385 self.end_pos = (
386 start_position + len(text) if start_position is not None else None
387 )
389 self.sentence: Sentence = None
390 self._embeddings: Dict = {}
391 self.tags_proba_dist: Dict[str, List[Label]] = {}
393 def add_tag_label(self, tag_type: str, tag: Label):
394 self.set_label(tag_type, tag.value, tag.score)
396 def add_tags_proba_dist(self, tag_type: str, tags: List[Label]):
397 self.tags_proba_dist[tag_type] = tags
399 def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
400 self.set_label(tag_type, tag_value, confidence)
402 def get_tag(self, label_type):
403 if len(self.get_labels(label_type)) == 0: return Label('')
404 return self.get_labels(label_type)[0]
406 def get_tags_proba_dist(self, tag_type: str) -> List[Label]:
407 if tag_type in self.tags_proba_dist:
408 return self.tags_proba_dist[tag_type]
409 return []
411 def get_head(self):
412 return self.sentence.get_token(self.head_id)
414 def set_embedding(self, name: str, vector: torch.tensor):
415 device = flair.device
416 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0:
417 device = next(iter(self._embeddings.values())).device
418 if device != vector.device:
419 vector = vector.to(device)
420 self._embeddings[name] = vector
422 def to(self, device: str, pin_memory: bool = False):
423 for name, vector in self._embeddings.items():
424 if str(vector.device) != str(device):
425 if pin_memory:
426 self._embeddings[name] = vector.to(
427 device, non_blocking=True
428 ).pin_memory()
429 else:
430 self._embeddings[name] = vector.to(device, non_blocking=True)
432 def clear_embeddings(self, embedding_names: List[str] = None):
433 if embedding_names is None:
434 self._embeddings: Dict = {}
435 else:
436 for name in embedding_names:
437 if name in self._embeddings.keys():
438 del self._embeddings[name]
440 def get_each_embedding(self, embedding_names: Optional[List[str]] = None) -> torch.tensor:
441 embeddings = []
442 for embed in sorted(self._embeddings.keys()):
443 if embedding_names and embed not in embedding_names: continue
444 embed = self._embeddings[embed].to(flair.device)
445 if (flair.embedding_storage_mode == "cpu") and embed.device != flair.device:
446 embed = embed.to(flair.device)
447 embeddings.append(embed)
448 return embeddings
450 def get_embedding(self, names: Optional[List[str]] = None) -> torch.tensor:
451 embeddings = self.get_each_embedding(names)
453 if embeddings:
454 return torch.cat(embeddings, dim=0)
456 return torch.tensor([], device=flair.device)
458 @property
459 def start_position(self) -> int:
460 return self.start_pos
462 @property
463 def end_position(self) -> int:
464 return self.end_pos
466 @property
467 def embedding(self):
468 return self.get_embedding()
470 def __str__(self) -> str:
471 return (
472 "Token: {} {}".format(self.idx, self.text)
473 if self.idx is not None
474 else "Token: {}".format(self.text)
475 )
477 def __repr__(self) -> str:
478 return (
479 "Token: {} {}".format(self.idx, self.text)
480 if self.idx is not None
481 else "Token: {}".format(self.text)
482 )
485class Span(DataPoint):
486 """
487 This class represents one textual span consisting of Tokens.
488 """
490 def __init__(self, tokens: List[Token]):
492 super().__init__()
494 self.tokens = tokens
495 self.start_pos = None
496 self.end_pos = None
498 if tokens:
499 self.start_pos = tokens[0].start_position
500 self.end_pos = tokens[len(tokens) - 1].end_position
502 @property
503 def text(self) -> str:
504 return " ".join([t.text for t in self.tokens])
506 def to_original_text(self) -> str:
507 pos = self.tokens[0].start_pos
508 if pos is None:
509 return " ".join([t.text for t in self.tokens])
510 str = ""
511 for t in self.tokens:
512 while t.start_pos > pos:
513 str += " "
514 pos += 1
516 str += t.text
517 pos += len(t.text)
519 return str
521 def to_plain_string(self):
522 plain = ""
523 for token in self.tokens:
524 plain += token.text
525 if token.whitespace_after:
526 plain += " "
527 return plain.rstrip()
529 def to_dict(self):
530 return {
531 "text": self.to_original_text(),
532 "start_pos": self.start_pos,
533 "end_pos": self.end_pos,
534 "labels": self.labels,
535 }
537 def __str__(self) -> str:
538 ids = ",".join([str(t.idx) for t in self.tokens])
539 label_string = " ".join([str(label) for label in self.labels])
540 labels = f' [− Labels: {label_string}]' if self.labels else ""
541 return (
542 'Span [{}]: "{}"{}'.format(ids, self.text, labels)
543 )
545 @property
546 def id_text(self) -> str:
547 return f"{' '.join([t.text for t in self.tokens])} ({','.join([str(t.idx) for t in self.tokens])})"
549 def __repr__(self) -> str:
550 ids = ",".join([str(t.idx) for t in self.tokens])
551 return (
552 '<{}-span ({}): "{}">'.format(self.tag, ids, self.text)
553 if len(self.labels) > 0
554 else '<span ({}): "{}">'.format(ids, self.text)
555 )
557 def __getitem__(self, idx: int) -> Token:
558 return self.tokens[idx]
560 def __iter__(self):
561 return iter(self.tokens)
563 def __len__(self) -> int:
564 return len(self.tokens)
566 @property
567 def tag(self):
568 return self.labels[0].value
570 @property
571 def score(self):
572 return self.labels[0].score
574 @property
575 def position_string(self):
576 return '-'.join([str(token.idx) for token in self])
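# Illustrative sketch (not part of the original module): building a Span directly from Tokens
# and labeling it; Spans are usually produced by Sentence.get_spans() further below.
def _example_span() -> None:
    tokens = [Token("New", idx=1, start_position=0), Token("York", idx=2, start_position=4)]
    span = Span(tokens)
    span.add_label("ner", "LOC", score=0.9)
    assert span.text == "New York"
    assert span.tag == "LOC" and span.score == 0.9
    assert span.position_string == "1-2"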
579class Tokenizer(ABC):
580 r"""An abstract class representing a :class:`Tokenizer`.
582 Tokenizers are used to represent algorithms and models to split plain text into
583 individual tokens / words. All subclasses should overwrite :meth:`tokenize`, which
584 splits the given plain text into tokens. Moreover, subclasses may overwrite
585 :meth:`name`, returning a unique identifier representing the tokenizer's
586 configuration.
587 """
589 @abstractmethod
590 def tokenize(self, text: str) -> List[Token]:
591 raise NotImplementedError()
593 @property
594 def name(self) -> str:
595 return self.__class__.__name__
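# Illustrative sketch (not part of the original module): a minimal Tokenizer subclass that
# splits on whitespace; a Sentence (defined below) would accept it via use_tokenizer=... .
class _ExampleWhitespaceTokenizer(Tokenizer):
    def tokenize(self, text: str) -> List[Token]:
        # one Token per whitespace-separated word; start positions are left unset here
        return [Token(word) for word in text.split()]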
598class Sentence(DataPoint):
599 """
600 A Sentence is a list of tokens and is used to represent a sentence or text fragment.
601 """
603 def __init__(
604 self,
605 text: Union[str, List[str]] = None,
606 use_tokenizer: Union[bool, Tokenizer] = True,
607 language_code: str = None,
608 start_position: int = None
609 ):
610 """
611 Class to hold all metadata related to a text (tokens, predictions, language code, ...)
612 :param text: original string (sentence), or a list of string tokens (words)
613 :param use_tokenizer: a custom tokenizer (default is :class:`SpaceTokenizer`);
614 more advanced options are :class:`SegtokTokenizer` to use segtok, or :class:`SpacyTokenizer`
615 to use the spaCy library (if available). Check the implementations of the abstract class Tokenizer or
616 implement your own subclass if you need to. If, instead of a Tokenizer, this parameter
617 is simply set to True (deprecated), :class:`SegtokTokenizer` will be used.
618 :param language_code: language of the sentence
619 :param start_position: start char offset of the sentence in the superordinate document
620 """
621 super().__init__()
623 self.tokens: List[Token] = []
625 self._embeddings: Dict = {}
627 self.language_code: str = language_code
629 self.start_pos = start_position
630 self.end_pos = (
631 start_position + len(text) if start_position is not None else None
632 )
634 if isinstance(use_tokenizer, Tokenizer):
635 tokenizer = use_tokenizer
636 elif hasattr(use_tokenizer, "__call__"):
637 from flair.tokenization import TokenizerWrapper
638 tokenizer = TokenizerWrapper(use_tokenizer)
639 elif type(use_tokenizer) == bool:
640 from flair.tokenization import SegtokTokenizer, SpaceTokenizer
641 tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
642 else:
643 raise AssertionError("Unexpected type of parameter 'use_tokenizer'. " +
644 "Parameter should be bool, Callable[[str], List[Token]] (deprecated), Tokenizer")
646 # if text is passed, instantiate sentence with tokens (words)
647 if text is not None:
648 if isinstance(text, (list, tuple)):
649 [self.add_token(self._restore_windows_1252_characters(token))
650 for token in text]
651 else:
652 text = self._restore_windows_1252_characters(text)
653 [self.add_token(token) for token in tokenizer.tokenize(text)]
655 # log a warning if an empty Sentence was created (this usually points to empty strings in the dataset)
656 if text == "":
657 log.warning(
658 "Warning: An empty Sentence was created! Are there empty strings in your dataset?"
659 )
661 self.tokenized = None
663 # some sentences represent a document boundary (but most do not)
664 self.is_document_boundary: bool = False
666 def get_token(self, token_id: int) -> Token:
667 for token in self.tokens:
668 if token.idx == token_id:
669 return token
671 def add_token(self, token: Union[Token, str]):
673 if type(token) is str:
674 token = Token(token)
676 token.text = token.text.replace('\u200c', '')
677 token.text = token.text.replace('\u200b', '')
678 token.text = token.text.replace('\ufe0f', '')
679 token.text = token.text.replace('\ufeff', '')
681 # tokens that consisted only of zero-width characters are now empty and cannot be handled
682 if token.text == '':
683 return
685 self.tokens.append(token)
687 # set token idx if not set
688 token.sentence = self
689 if token.idx is None:
690 token.idx = len(self.tokens)
692 def get_label_names(self):
693 label_names = []
694 for label in self.labels:
695 label_names.append(label.value)
696 return label_names
698 def _add_spans_internal(self, spans: List[Span], label_type: str, min_score):
700 current_span = []
702 tags = defaultdict(lambda: 0.0)
704 previous_tag_value: str = "O"
705 for token in self:
707 tag: Label = token.get_tag(label_type)
708 tag_value = tag.value
710 # non-set tags are OUT tags
711 if tag_value == "" or tag_value == "O" or tag_value == "_":
712 tag_value = "O-"
714 # anything that is not a BIOES tag is a SINGLE tag
715 if tag_value[0:2] not in ["B-", "I-", "O-", "E-", "S-"]:
716 tag_value = "S-" + tag_value
718 # anything that is not OUT is IN
719 in_span = False
720 if tag_value[0:2] not in ["O-"]:
721 in_span = True
723 # single and begin tags start a new span
724 starts_new_span = False
725 if tag_value[0:2] in ["B-", "S-"]:
726 starts_new_span = True
728 if (
729 previous_tag_value[0:2] in ["S-"]
730 and previous_tag_value[2:] != tag_value[2:]
731 and in_span
732 ):
733 starts_new_span = True
735 if (starts_new_span or not in_span) and len(current_span) > 0:
736 scores = [t.get_labels(label_type)[0].score for t in current_span]
737 span_score = sum(scores) / len(scores)
738 if span_score > min_score:
739 span = Span(current_span)
740 span.add_label(
741 typename=label_type,
742 value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
743 score=span_score)
744 spans.append(span)
746 current_span = []
747 tags = defaultdict(lambda: 0.0)
749 if in_span:
750 current_span.append(token)
751 weight = 1.1 if starts_new_span else 1.0
752 tags[tag_value[2:]] += weight
754 # remember previous tag
755 previous_tag_value = tag_value
757 if len(current_span) > 0:
758 scores = [t.get_labels(label_type)[0].score for t in current_span]
759 span_score = sum(scores) / len(scores)
760 if span_score > min_score:
761 span = Span(current_span)
762 span.add_label(
763 typename=label_type,
764 value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
765 score=span_score)
766 spans.append(span)
768 return spans
770 def get_spans(self, label_type: Optional[str] = None, min_score=-1) -> List[Span]:
772 spans: List[Span] = []
774 # if label type is explicitly specified, get spans for this label type
775 if label_type:
776 return self._add_spans_internal(spans, label_type, min_score)
778 # else determine all label types in sentence and get all spans
779 label_types = []
780 for token in self:
781 for annotation in token.annotation_layers.keys():
782 if annotation not in label_types: label_types.append(annotation)
784 for label_type in label_types:
785 self._add_spans_internal(spans, label_type, min_score)
786 return spans
788 @property
789 def embedding(self):
790 return self.get_embedding()
792 def set_embedding(self, name: str, vector: torch.tensor):
793 device = flair.device
794 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0:
795 device = next(iter(self._embeddings.values())).device
796 if device != vector.device:
797 vector = vector.to(device)
798 self._embeddings[name] = vector
800 def get_embedding(self, names: Optional[List[str]] = None) -> torch.tensor:
801 embeddings = []
802 for embed in sorted(self._embeddings.keys()):
803 if names and embed not in names: continue
804 embedding = self._embeddings[embed]
805 embeddings.append(embedding)
807 if embeddings:
808 return torch.cat(embeddings, dim=0)
810 return torch.Tensor()
812 def to(self, device: str, pin_memory: bool = False):
814 # move sentence embeddings to device
815 for name, vector in self._embeddings.items():
816 if str(vector.device) != str(device):
817 if pin_memory:
818 self._embeddings[name] = vector.to(
819 device, non_blocking=True
820 ).pin_memory()
821 else:
822 self._embeddings[name] = vector.to(device, non_blocking=True)
824 # move token embeddings to device
825 for token in self:
826 token.to(device, pin_memory)
828 def clear_embeddings(self, embedding_names: List[str] = None):
830 # clear sentence embeddings
831 if embedding_names is None:
832 self._embeddings: Dict = {}
833 else:
834 for name in embedding_names:
835 if name in self._embeddings.keys():
836 del self._embeddings[name]
838 # clear token embeddings
839 for token in self:
840 token.clear_embeddings(embedding_names)
842 def to_tagged_string(self, main_tag=None) -> str:
843 list = []
844 for token in self.tokens:
845 list.append(token.text)
847 tags: List[str] = []
848 for label_type in token.annotation_layers.keys():
850 if main_tag is not None and main_tag != label_type:
851 continue
853 if token.get_labels(label_type)[0].value == "O":
854 continue
855 if token.get_labels(label_type)[0].value == "_":
856 continue
858 tags.append(token.get_labels(label_type)[0].value)
859 all_tags = "<" + "/".join(tags) + ">"
860 if all_tags != "<>":
861 list.append(all_tags)
862 return " ".join(list)
864 def to_tokenized_string(self) -> str:
866 if self.tokenized is None:
867 self.tokenized = " ".join([t.text for t in self.tokens])
869 return self.tokenized
871 def to_plain_string(self):
872 plain = ""
873 for token in self.tokens:
874 plain += token.text
875 if token.whitespace_after:
876 plain += " "
877 return plain.rstrip()
879 def convert_tag_scheme(self, tag_type: str = "ner", target_scheme: str = "iob"):
881 tags: List[Label] = []
882 for token in self.tokens:
883 tags.append(token.get_tag(tag_type))
885 if target_scheme == "iob":
886 iob2(tags)
888 if target_scheme == "iobes":
889 iob2(tags)
890 tags = iob_iobes(tags)
892 for index, tag in enumerate(tags):
893 self.tokens[index].set_label(tag_type, tag)
895 def infer_space_after(self):
896 """
897 Heuristics in case you wish to infer whitespace_after values for tokenized text. This is useful for some older NLP
898 tasks (such as CoNLL-03 and CoNLL-2000) that provide only tokenized data with no information about the original whitespace.
899 :return: self
900 """
901 last_token = None
902 quote_count: int = 0
903 # infer whitespace after field
905 for token in self.tokens:
906 if token.text == '"':
907 quote_count += 1
908 if quote_count % 2 != 0:
909 token.whitespace_after = False
910 elif last_token is not None:
911 last_token.whitespace_after = False
913 if last_token is not None:
915 if token.text in [".", ":", ",", ";", ")", "n't", "!", "?"]:
916 last_token.whitespace_after = False
918 if token.text.startswith("'"):
919 last_token.whitespace_after = False
921 if token.text in ["("]:
922 token.whitespace_after = False
924 last_token = token
925 return self
927 def to_original_text(self) -> str:
928 if len(self.tokens) > 0 and (self.tokens[0].start_pos is None):
929 return " ".join([t.text for t in self.tokens])
930 str = ""
931 pos = 0
932 for t in self.tokens:
933 while t.start_pos > pos:
934 str += " "
935 pos += 1
937 str += t.text
938 pos += len(t.text)
940 return str
942 def to_dict(self, tag_type: str = None):
943 labels = []
944 entities = []
946 if tag_type:
947 entities = [span.to_dict() for span in self.get_spans(tag_type)]
948 if self.labels:
949 labels = [l.to_dict() for l in self.labels]
951 return {"text": self.to_original_text(), "labels": labels, "entities": entities}
953 def __getitem__(self, idx: int) -> Token:
954 return self.tokens[idx]
956 def __iter__(self):
957 return iter(self.tokens)
959 def __len__(self) -> int:
960 return len(self.tokens)
962 def __repr__(self):
963 tagged_string = self.to_tagged_string()
964 tokenized_string = self.to_tokenized_string()
966 # add Sentence labels to output if they exist
967 sentence_labels = f" − Sentence-Labels: {self.annotation_layers}" if self.annotation_layers != {} else ""
969 # add Token labels to output if they exist
970 token_labels = f' − Token-Labels: "{tagged_string}"' if tokenized_string != tagged_string else ""
972 return f'Sentence: "{tokenized_string}" [− Tokens: {len(self)}{token_labels}{sentence_labels}]'
974 def __copy__(self):
975 s = Sentence()
976 for token in self.tokens:
977 nt = Token(token.text)
978 for tag_type in token.tags:
979 nt.add_label(
980 tag_type,
981 token.get_tag(tag_type).value,
982 token.get_tag(tag_type).score,
983 )
985 s.add_token(nt)
986 return s
988 def __str__(self) -> str:
990 tagged_string = self.to_tagged_string()
991 tokenized_string = self.to_tokenized_string()
993 # add Sentence labels to output if they exist
994 sentence_labels = f" − Sentence-Labels: {self.annotation_layers}" if self.annotation_layers != {} else ""
996 # add Token labels to output if they exist
997 token_labels = f' − Token-Labels: "{tagged_string}"' if tokenized_string != tagged_string else ""
999 return f'Sentence: "{tokenized_string}" [− Tokens: {len(self)}{token_labels}{sentence_labels}]'
1001 def get_language_code(self) -> str:
1002 if self.language_code is None:
1003 import langdetect
1005 try:
1006 self.language_code = langdetect.detect(self.to_plain_string())
1007 except:
1008 self.language_code = "en"
1010 return self.language_code
1012 @staticmethod
1013 def _restore_windows_1252_characters(text: str) -> str:
1014 def to_windows_1252(match):
1015 try:
1016 return bytes([ord(match.group(0))]).decode("windows-1252")
1017 except UnicodeDecodeError:
1018 # No character at the corresponding code point: remove it
1019 return ""
1021 return re.sub(r"[\u0080-\u0099]", to_windows_1252, text)
1023 def next_sentence(self):
1024 """
1025 Get the next sentence in the document (works only if context is set through dataloader or elsewhere)
1026 :return: next Sentence in document if set, otherwise None
1027 """
1028 if '_next_sentence' in self.__dict__.keys():
1029 return self._next_sentence
1031 if '_position_in_dataset' in self.__dict__.keys():
1032 dataset = self._position_in_dataset[0]
1033 index = self._position_in_dataset[1] + 1
1034 if index < len(dataset):
1035 return dataset[index]
1037 return None
1039 def previous_sentence(self):
1040 """
1041 Get the previous sentence in the document (works only if context is set through dataloader or elsewhere)
1042 :return: previous Sentence in document if set, otherwise None
1043 """
1044 if '_previous_sentence' in self.__dict__.keys():
1045 return self._previous_sentence
1047 if '_position_in_dataset' in self.__dict__.keys():
1048 dataset = self._position_in_dataset[0]
1049 index = self._position_in_dataset[1] - 1
1050 if index >= 0:
1051 return dataset[index]
1053 return None
1055 def is_context_set(self) -> bool:
1056 """
1057 Return True or False depending on whether context is set (for instance in dataloader or elsewhere)
1058 :return: True if context is set, else False
1059 """
1060 return '_previous_sentence' in self.__dict__.keys() or '_position_in_dataset' in self.__dict__.keys()
1062 def get_labels(self, label_type: str = None):
1064 # TODO: crude hack - replace with something better
1065 if label_type:
1066 spans = self.get_spans(label_type)
1067 for span in spans:
1068 self.add_complex_label(label_type, label=SpanLabel(span, span.tag, span.score))
1070 if label_type is None:
1071 return self.labels
1073 return self.annotation_layers[label_type] if label_type in self.annotation_layers else []
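# Illustrative sketch (not part of the original module): token-level tags attached to a Sentence
# are aggregated into labeled Spans by get_spans().
def _example_sentence_spans() -> None:
    sentence = Sentence("George Washington went to Washington .", use_tokenizer=False)
    sentence[0].add_tag("ner", "B-PER")
    sentence[1].add_tag("ner", "E-PER")
    spans = sentence.get_spans("ner")
    assert len(spans) == 1 and spans[0].text == "George Washington"
    assert spans[0].tag == "PER"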
1076class Image(DataPoint):
1078 def __init__(self, data=None, imageURL=None):
1079 super().__init__()
1081 self.data = data
1082 self._embeddings: Dict = {}
1083 self.imageURL = imageURL
1085 @property
1086 def embedding(self):
1087 return self.get_embedding()
1089 def __str__(self):
1091 image_repr = self.data.size() if self.data else ""
1092 image_url = self.imageURL if self.imageURL else ""
1094 return f"Image: {image_repr} {image_url}"
1096 def get_embedding(self) -> torch.tensor:
1097 embeddings = [
1098 self._embeddings[embed] for embed in sorted(self._embeddings.keys())
1099 ]
1101 if embeddings:
1102 return torch.cat(embeddings, dim=0)
1104 return torch.tensor([], device=flair.device)
1106 def set_embedding(self, name: str, vector: torch.tensor):
1107 device = flair.device
1108 if (flair.embedding_storage_mode == "cpu") and len(self._embeddings.keys()) > 0:
1109 device = next(iter(self._embeddings.values())).device
1110 if device != vector.device:
1111 vector = vector.to(device)
1112 self._embeddings[name] = vector
1114 def to(self, device: str, pin_memory: bool = False):
1115 for name, vector in self._embeddings.items():
1116 if str(vector.device) != str(device):
1117 if pin_memory:
1118 self._embeddings[name] = vector.to(
1119 device, non_blocking=True
1120 ).pin_memory()
1121 else:
1122 self._embeddings[name] = vector.to(device, non_blocking=True)
1124 def clear_embeddings(self, embedding_names: List[str] = None):
1125 if embedding_names is None:
1126 self._embeddings: Dict = {}
1127 else:
1128 for name in embedding_names:
1129 if name in self._embeddings.keys():
1130 del self._embeddings[name]
1133class FlairDataset(Dataset):
1134 @abstractmethod
1135 def is_in_memory(self) -> bool:
1136 pass
1139class Corpus:
1140 def __init__(
1141 self,
1142 train: FlairDataset = None,
1143 dev: FlairDataset = None,
1144 test: FlairDataset = None,
1145 name: str = "corpus",
1146 sample_missing_splits: Union[bool, str] = True,
1147 ):
1148 # set name
1149 self.name: str = name
1151 # abort if no data is provided
1152 if not train and not dev and not test:
1153 raise RuntimeError('No data provided when initializing corpus object.')
1155 # sample test data from train if none is provided
1156 if test is None and sample_missing_splits and train and not sample_missing_splits == 'only_dev':
1157 train_length = len(train)
1158 test_size: int = round(train_length / 10)
1159 splits = randomly_split_into_two_datasets(train, test_size)
1160 test = splits[0]
1161 train = splits[1]
1163 # sample dev data from train if none is provided
1164 if dev is None and sample_missing_splits and train and not sample_missing_splits == 'only_test':
1165 train_length = len(train)
1166 dev_size: int = round(train_length / 10)
1167 splits = randomly_split_into_two_datasets(train, dev_size)
1168 dev = splits[0]
1169 train = splits[1]
1171 # set train dev and test data
1172 self._train: FlairDataset = train
1173 self._test: FlairDataset = test
1174 self._dev: FlairDataset = dev
1176 @property
1177 def train(self) -> FlairDataset:
1178 return self._train
1180 @property
1181 def dev(self) -> FlairDataset:
1182 return self._dev
1184 @property
1185 def test(self) -> FlairDataset:
1186 return self._test
1188 def downsample(self, percentage: float = 0.1, downsample_train=True, downsample_dev=True, downsample_test=True):
1190 if downsample_train and self._train:
1191 self._train = self._downsample_to_proportion(self.train, percentage)
1193 if downsample_dev and self._dev:
1194 self._dev = self._downsample_to_proportion(self.dev, percentage)
1196 if downsample_test and self._test:
1197 self._test = self._downsample_to_proportion(self.test, percentage)
1199 return self
1201 def filter_empty_sentences(self):
1202 log.info("Filtering empty sentences")
1203 self._train = Corpus._filter_empty_sentences(self._train)
1204 self._test = Corpus._filter_empty_sentences(self._test)
1205 self._dev = Corpus._filter_empty_sentences(self._dev)
1206 log.info(self)
1208 def filter_long_sentences(self, max_charlength: int):
1209 log.info("Filtering long sentences")
1210 self._train = Corpus._filter_long_sentences(self._train, max_charlength)
1211 self._test = Corpus._filter_long_sentences(self._test, max_charlength)
1212 self._dev = Corpus._filter_long_sentences(self._dev, max_charlength)
1213 log.info(self)
1215 @staticmethod
1216 def _filter_long_sentences(dataset, max_charlength: int) -> Dataset:
1218 # find indices of sentences that exceed the maximum character length
1219 long_sentence_indices = []
1220 short_sentence_indices = []
1221 index = 0
1223 from flair.datasets import DataLoader
1225 for batch in DataLoader(dataset):
1226 for sentence in batch:
1227 if len(sentence.to_plain_string()) > max_charlength:
1228 long_sentence_indices.append(index)
1229 else:
1230 short_sentence_indices.append(index)
1231 index += 1
1233 # create subset of sentences that do not exceed the maximum length
1234 subset = Subset(dataset, short_sentence_indices)
1236 return subset
1238 @staticmethod
1239 def _filter_empty_sentences(dataset) -> Dataset:
1241 # find out empty sentence indices
1242 empty_sentence_indices = []
1243 non_empty_sentence_indices = []
1244 index = 0
1246 from flair.datasets import DataLoader
1248 for batch in DataLoader(dataset):
1249 for sentence in batch:
1250 if len(sentence) == 0:
1251 empty_sentence_indices.append(index)
1252 else:
1253 non_empty_sentence_indices.append(index)
1254 index += 1
1256 # create subset of non-empty sentence indices
1257 subset = Subset(dataset, non_empty_sentence_indices)
1259 return subset
1261 def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:
1262 """
1263 Creates a dictionary of all tokens contained in the corpus.
1264 By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary.
1265 If there are more than `max_tokens` tokens in the corpus, the most frequent tokens are added first.
1266 If `min_freq` is set to a value greater than 1, only tokens occurring at least `min_freq` times are
1267 added to the dictionary.
1268 :param max_tokens: the maximum number of tokens that should be added to the dictionary (-1 = take all tokens)
1269 :param min_freq: a token needs to occur at least `min_freq` times to be added to the dictionary (-1 = there is no limitation)
1270 :return: dictionary of tokens
1271 """
1272 tokens = self._get_most_common_tokens(max_tokens, min_freq)
1274 vocab_dictionary: Dictionary = Dictionary()
1275 for token in tokens:
1276 vocab_dictionary.add_item(token)
1278 return vocab_dictionary
1280 def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]:
1281 tokens_and_frequencies = Counter(self._get_all_tokens())
1282 tokens_and_frequencies = tokens_and_frequencies.most_common()
1284 tokens = []
1285 for token, freq in tokens_and_frequencies:
1286 if (min_freq != -1 and freq < min_freq) or (
1287 max_tokens != -1 and len(tokens) == max_tokens
1288 ):
1289 break
1290 tokens.append(token)
1291 return tokens
1293 def _get_all_tokens(self) -> List[str]:
1294 tokens = list(map((lambda s: s.tokens), self.train))
1295 tokens = [token for sublist in tokens for token in sublist]
1296 return list(map((lambda t: t.text), tokens))
1298 @staticmethod
1299 def _downsample_to_proportion(dataset: Dataset, proportion: float):
1301 sampled_size: int = round(len(dataset) * proportion)
1302 splits = randomly_split_into_two_datasets(dataset, sampled_size)
1303 return splits[0]
1305 def obtain_statistics(
1306 self, label_type: str = None, pretty_print: bool = True
1307 ) -> dict:
1308 """
1309 Collect statistics about the class distribution (only sentence-level labels are taken into account) and sentence
1310 sizes. Returns a JSON string if pretty_print is True, otherwise a dict.
1311 """
1312 json_string = {
1313 "TRAIN": self._obtain_statistics_for(self.train, "TRAIN", label_type),
1314 "TEST": self._obtain_statistics_for(self.test, "TEST", label_type),
1315 "DEV": self._obtain_statistics_for(self.dev, "DEV", label_type),
1316 }
1317 if pretty_print:
1318 import json
1320 json_string = json.dumps(json_string, indent=4)
1321 return json_string
1323 @staticmethod
1324 def _obtain_statistics_for(sentences, name, tag_type) -> dict:
1325 if len(sentences) == 0:
1326 return {}
1328 classes_to_count = Corpus._count_sentence_labels(sentences)
1329 tags_to_count = Corpus._count_token_labels(sentences, tag_type)
1330 tokens_per_sentence = Corpus._get_tokens_per_sentence(sentences)
1332 label_size_dict = {}
1333 for l, c in classes_to_count.items():
1334 label_size_dict[l] = c
1336 tag_size_dict = {}
1337 for l, c in tags_to_count.items():
1338 tag_size_dict[l] = c
1340 return {
1341 "dataset": name,
1342 "total_number_of_documents": len(sentences),
1343 "number_of_documents_per_class": label_size_dict,
1344 "number_of_tokens_per_tag": tag_size_dict,
1345 "number_of_tokens": {
1346 "total": sum(tokens_per_sentence),
1347 "min": min(tokens_per_sentence),
1348 "max": max(tokens_per_sentence),
1349 "avg": sum(tokens_per_sentence) / len(sentences),
1350 },
1351 }
1353 @staticmethod
1354 def _get_tokens_per_sentence(sentences):
1355 return list(map(lambda x: len(x.tokens), sentences))
1357 @staticmethod
1358 def _count_sentence_labels(sentences):
1359 label_count = defaultdict(lambda: 0)
1360 for sent in sentences:
1361 for label in sent.labels:
1362 label_count[label.value] += 1
1363 return label_count
1365 @staticmethod
1366 def _count_token_labels(sentences, label_type):
1367 label_count = defaultdict(lambda: 0)
1368 for sent in sentences:
1369 for token in sent.tokens:
1370 if label_type in token.annotation_layers.keys():
1371 label = token.get_tag(label_type)
1372 label_count[label.value] += 1
1373 return label_count
1375 def __str__(self) -> str:
1376 return "Corpus: %d train + %d dev + %d test sentences" % (
1377 len(self.train) if self.train else 0,
1378 len(self.dev) if self.dev else 0,
1379 len(self.test) if self.test else 0,
1380 )
1382 def make_label_dictionary(self, label_type: str) -> Dictionary:
1383 """
1384 Creates a dictionary of all labels assigned to the sentences in the corpus.
1385 :return: dictionary of labels
1386 """
1387 label_dictionary: Dictionary = Dictionary(add_unk=True)
1388 label_dictionary.multi_label = False
1390 from flair.datasets import DataLoader
1392 datasets = [self.train]
1394 data = ConcatDataset(datasets)
1396 loader = DataLoader(data, batch_size=1)
1398 log.info("Computing label dictionary. Progress:")
1400 # if there are token labels of the provided type, use these; otherwise use sentence labels
1401 token_labels_exist = False
1403 all_label_types = Counter()
1404 all_sentence_labels = []
1405 for batch in Tqdm.tqdm(iter(loader)):
1407 for sentence in batch:
1409 # check for labels of words
1410 if isinstance(sentence, Sentence):
1411 for token in sentence.tokens:
1412 all_label_types.update(token.annotation_layers.keys())
1413 for label in token.get_labels(label_type):
1414 label_dictionary.add_item(label.value)
1415 token_labels_exist = True
1417 # if we are looking for sentence-level labels
1418 if not token_labels_exist:
1419 # check if sentence itself has labels
1420 labels = sentence.get_labels(label_type)
1421 all_label_types.update(sentence.annotation_layers.keys())
1423 for label in labels:
1424 if label.value not in all_sentence_labels: all_sentence_labels.append(label.value)
1426 if not label_dictionary.multi_label:
1427 if len(labels) > 1:
1428 label_dictionary.multi_label = True
1430 # if this is not a token-level prediction problem, add sentence-level labels to dictionary
1431 if not token_labels_exist:
1432 for label in all_sentence_labels:
1433 label_dictionary.add_item(label)
1435 if len(label_dictionary.idx2item) == 0:
1436 log.error(
1437 f"Corpus contains only the labels: {', '.join([f'{label[0]} (#{label[1]})' for label in all_label_types.most_common()])}")
1438 log.error(f"You specified as label_type='{label_type}' which is not in this dataset!")
1440 raise Exception
1442 log.info(
1443 f"Corpus contains the labels: {', '.join([label[0] + f' (#{label[1]})' for label in all_label_types.most_common()])}")
1444 log.info(f"Created (for label '{label_type}') {label_dictionary}")
1446 return label_dictionary
1448 def get_label_distribution(self):
1449 class_to_count = defaultdict(lambda: 0)
1450 for sent in self.train:
1451 for label in sent.labels:
1452 class_to_count[label.value] += 1
1453 return class_to_count
1455 def get_all_sentences(self) -> Dataset:
1456 parts = []
1457 if self.train: parts.append(self.train)
1458 if self.dev: parts.append(self.dev)
1459 if self.test: parts.append(self.test)
1460 return ConcatDataset(parts)
1462 @deprecated(version="0.8", reason="Use 'make_label_dictionary' instead.")
1463 def make_tag_dictionary(self, tag_type: str) -> Dictionary:
1465 # Make the tag dictionary
1466 tag_dictionary: Dictionary = Dictionary(add_unk=False)
1467 tag_dictionary.add_item("O")
1468 for sentence in self.get_all_sentences():
1469 for token in sentence.tokens:
1470 tag_dictionary.add_item(token.get_tag(tag_type).value)
1471 tag_dictionary.add_item("<START>")
1472 tag_dictionary.add_item("<STOP>")
1473 return tag_dictionary
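# Illustrative sketch (not part of the original module): when only a train split is given,
# Corpus samples test and dev splits from it (roughly 10% each, by default). A plain list of
# Sentence objects stands in for a FlairDataset here, since only len() and indexing are needed.
def _example_corpus_splits() -> None:
    train = [Sentence(f"example sentence {i}", use_tokenizer=False) for i in range(20)]
    corpus = Corpus(train=train, sample_missing_splits=True)
    assert (len(corpus.train), len(corpus.dev), len(corpus.test)) == (16, 2, 2)
    corpus.downsample(0.5)               # keeps roughly half of every split, in place
    assert len(corpus.train) == 8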
1476class MultiCorpus(Corpus):
1477 def __init__(self, corpora: List[Corpus], name: str = "multicorpus", **corpusargs):
1478 self.corpora: List[Corpus] = corpora
1480 train_parts = []
1481 dev_parts = []
1482 test_parts = []
1483 for corpus in self.corpora:
1484 if corpus.train: train_parts.append(corpus.train)
1485 if corpus.dev: dev_parts.append(corpus.dev)
1486 if corpus.test: test_parts.append(corpus.test)
1488 super(MultiCorpus, self).__init__(
1489 ConcatDataset(train_parts) if len(train_parts) > 0 else None,
1490 ConcatDataset(dev_parts) if len(dev_parts) > 0 else None,
1491 ConcatDataset(test_parts) if len(test_parts) > 0 else None,
1492 name=name,
1493 **corpusargs,
1494 )
1496 def __str__(self):
1497 output = f"MultiCorpus: " \
1498 f"{len(self.train) if self.train else 0} train + " \
1499 f"{len(self.dev) if self.dev else 0} dev + " \
1500 f"{len(self.test) if self.test else 0} test sentences\n - "
1501 output += "\n - ".join([f'{type(corpus).__name__} {str(corpus)} - {corpus.name}' for corpus in self.corpora])
1502 return output
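# Illustrative sketch (not part of the original module): a MultiCorpus simply concatenates the
# train/dev/test splits of its member corpora.
def _example_multicorpus() -> None:
    corpus_a = Corpus(train=[Sentence(f"corpus a sentence {i}", use_tokenizer=False) for i in range(10)], name="a")
    corpus_b = Corpus(train=[Sentence(f"corpus b sentence {i}", use_tokenizer=False) for i in range(10)], name="b")
    multi = MultiCorpus([corpus_a, corpus_b], name="a+b")
    assert len(multi.train) == len(corpus_a.train) + len(corpus_b.train)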
1505def iob2(tags):
1506 """
1507 Check that tags have a valid IOB format.
1508 Tags in IOB1 format are converted to IOB2.
1509 """
1510 for i, tag in enumerate(tags):
1511 if tag.value == "O":
1512 continue
1513 split = tag.value.split("-")
1514 if len(split) != 2 or split[0] not in ["I", "B"]:
1515 return False
1516 if split[0] == "B":
1517 continue
1518 elif i == 0 or tags[i - 1].value == "O": # conversion IOB1 to IOB2
1519 tags[i].value = "B" + tag.value[1:]
1520 elif tags[i - 1].value[1:] == tag.value[1:]:
1521 continue
1522 else: # conversion IOB1 to IOB2
1523 tags[i].value = "B" + tag.value[1:]
1524 return True
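# Illustrative sketch (not part of the original module): iob2() rewrites IOB1 tags to IOB2
# in place, turning entity-initial "I-" tags into "B-".
def _example_iob2() -> None:
    tags = [Label("I-PER"), Label("I-PER"), Label("O"), Label("I-LOC")]
    assert iob2(tags) is True
    assert [tag.value for tag in tags] == ["B-PER", "I-PER", "O", "B-LOC"]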
1527def iob_iobes(tags):
1528 """
1529 IOB -> IOBES
1530 """
1531 new_tags = []
1532 for i, tag in enumerate(tags):
1533 if tag.value == "O" or tag.value == "":
1534 new_tags.append("O")
1535 elif tag.value.split("-")[0] == "B":
1536 if i + 1 != len(tags) and tags[i + 1].value.split("-")[0] == "I":
1537 new_tags.append(tag.value)
1538 else:
1539 new_tags.append(tag.value.replace("B-", "S-"))
1540 elif tag.value.split("-")[0] == "I":
1541 if i + 1 < len(tags) and tags[i + 1].value.split("-")[0] == "I":
1542 new_tags.append(tag.value)
1543 else:
1544 new_tags.append(tag.value.replace("I-", "E-"))
1545 else:
1546 raise Exception("Invalid IOB format!")
1547 return new_tags
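# Illustrative sketch (not part of the original module): iob_iobes() maps IOB2 tags to the
# IOBES scheme, marking single-token entities with "S-" and entity-final tokens with "E-".
def _example_iob_iobes() -> None:
    tags = [Label("B-PER"), Label("I-PER"), Label("O"), Label("B-LOC")]
    assert iob_iobes(tags) == ["B-PER", "E-PER", "O", "S-LOC"]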
1550def randomly_split_into_two_datasets(dataset, length_of_first):
1551 import random
1552 indices = [i for i in range(len(dataset))]
1553 random.shuffle(indices)
1555 first_dataset = indices[:length_of_first]
1556 second_dataset = indices[length_of_first:]
1557 first_dataset.sort()
1558 second_dataset.sort()
1560 return [Subset(dataset, first_dataset), Subset(dataset, second_dataset)]
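# Illustrative sketch (not part of the original module): the helper returns two disjoint,
# index-sorted Subsets that together cover the input dataset.
def _example_random_split() -> None:
    data = list(range(10))
    first, second = randomly_split_into_two_datasets(data, length_of_first=3)
    assert len(first) == 3 and len(second) == 7
    assert sorted(list(first) + list(second)) == data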