Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/models/text_classification_model.py: 69%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2from pathlib import Path
3from typing import List, Union
5import torch
6import torch.nn as nn
8import flair.embeddings
9import flair.nn
10from flair.data import Label, DataPoint
11from flair.file_utils import cached_path
# Module-level logger; attaches to the shared "flair" logger by name.
log: logging.Logger = logging.getLogger("flair")
class TextClassifier(flair.nn.DefaultClassifier):
    """
    Text Classification Model

    The model embeds every data point with the supplied document embeddings and puts the resulting
    text representation into a linear layer to get the actual class label.
    The model can handle single and multi class data sets.
    """

    def __init__(
        self,
        document_embeddings: flair.embeddings.DocumentEmbeddings,
        label_type: str,
        **classifierargs,
    ):
        """
        Initializes a TextClassifier

        :param document_embeddings: embeddings used to embed each data point
        :param label_type: name of the label type to predict; gold labels are read from each data point
            under this name and it is exposed through the ``label_type`` property
        :param classifierargs: keyword arguments forwarded unchanged to ``flair.nn.DefaultClassifier``.
            The options below are presumably interpreted by that base class (confirm against it):

            * ``label_dictionary``: dictionary of labels you want to predict
            * ``multi_label``: auto-detected by default, but you can set this to True to force multi-label
              prediction or False to force single-label prediction
            * ``multi_label_threshold``: If multi-label you can set the threshold to make predictions
            * ``beta``: Parameter for F-beta score for evaluation and training annealing
            * ``loss_weights``: Dictionary of weights for labels for the loss function
              (if any label's weight is unspecified it will default to 1.0)
        """

        # The base class must be initialised first: the decoder below is sized from
        # self.label_dictionary, which this class never assigns itself.
        super().__init__(**classifierargs)

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings

        self._label_type = label_type

        # One output unit per entry of the label dictionary.
        self.decoder = nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary))
        nn.init.xavier_uniform_(self.decoder.weight)

        # auto-spawn on GPU if available
        self.to(flair.device)

    def forward_pass(
        self,
        sentences: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        """
        Embeds the given data points and sends the embeddings through the linear decoder.

        :param sentences: a list of data points or a single data point
        :param return_label_candidates: if True, the data points and one empty ``Label`` per data point
            are appended to the returned tuple
        :return: ``(scores, labels)`` where ``scores`` is the decoder output for the batch and ``labels``
            holds, per data point, the list of gold label values of type ``self.label_type``;
            ``(scores, labels, sentences, label_candidates)`` if ``return_label_candidates`` is True
        """

        # The signature allows a single data point. Wrap it so the loops below iterate over data
        # points instead of iterating over the bare data point object itself.
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        # embed sentences
        self.document_embeddings.embed(sentences)

        # make tensor for all embedded sentences in batch
        embedding_names = self.document_embeddings.get_names()
        text_embedding_list = [sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in sentences]
        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

        # send through decoder to get logits
        scores = self.decoder(text_embedding_tensor)

        # gold label values, one list per data point
        labels = [[label.value for label in sentence.get_labels(self.label_type)] for sentence in sentences]

        # minimal return is scores and labels
        return_tuple = (scores, labels)

        if return_label_candidates:
            # one placeholder label per data point
            label_candidates = [Label(value=None) for _ in sentences]
            return_tuple += (sentences, label_candidates)

        return return_tuple

    def _get_state_dict(self):
        """
        Collects everything ``_init_model_with_state_dict`` needs to rebuild this model.

        :return: dictionary with the module weights, the document embeddings and the classifier settings
        """
        return {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "multi_label": self.multi_label,
            "multi_label_threshold": self.multi_label_threshold,
            "weight_dict": self.weight_dict,
        }

    @staticmethod
    def _init_model_with_state_dict(state):
        """
        Rebuilds a TextClassifier from a dictionary as produced by ``_get_state_dict``.

        The keys ``weight_dict``, ``label_type`` and ``multi_label_threshold`` are optional; when absent
        they fall back to ``None``, ``None`` and ``0.5`` respectively.

        :param state: dictionary holding the saved model
        :return: the restored TextClassifier with its weights loaded
        """
        model = TextClassifier(
            document_embeddings=state["document_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state.get("label_type"),
            multi_label=state["multi_label"],
            multi_label_threshold=state.get("multi_label_threshold", 0.5),
            loss_weights=state.get("weight_dict"),
        )
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:
        """
        Resolves a known model name to a cached copy of the corresponding hosted model file.

        :param model_name: either one of the names known to this method or any other value
        :return: the result of ``cached_path`` for known names (NOTE(review): confirm that this matches
            the ``str`` annotation), otherwise ``model_name`` unchanged
        """
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        # "sentiment" and "en-sentiment" are aliases for the same file
        sentiment_url = f"{hu_path}/sentiment-curated-distilbert/sentiment-en-mix-distillbert_4.pt"

        model_map = {
            "de-offensive-language": f"{hu_path}/de-offensive-language/germ-eval-2018-task-1-v0.8.pt",
            # English sentiment models
            "sentiment": sentiment_url,
            "en-sentiment": sentiment_url,
            "sentiment-fast": f"{hu_path}/sentiment-curated-fasttext-rnn/sentiment-en-mix-ft-rnn_v8.pt",
            # Communicative Functions Model
            "communicative-functions": f"{hu_path}/comfunc/communicative-functions.pt",
        }

        if model_name in model_map:
            return cached_path(model_map[model_name], cache_dir=Path("models"))

        return model_name

    @property
    def label_type(self):
        """Name of the label type this classifier predicts, as passed to the constructor."""
        return self._label_type