Coverage for flair/flair/models/text_classification_model.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1import logging 

2from pathlib import Path 

3from typing import List, Union 

4 

5import torch 

6import torch.nn as nn 

7 

8import flair.embeddings 

9import flair.nn 

10from flair.data import Label, DataPoint 

11from flair.file_utils import cached_path 

12 

# Module logger; all flair components log through the shared "flair" logger.
log = logging.getLogger(name="flair")

14 

15 

class TextClassifier(flair.nn.DefaultClassifier):
    """
    Text Classification Model

    Each data point is embedded with the supplied DocumentEmbeddings (whatever document
    embedding is passed in, e.g. one built from word embeddings). The resulting text
    representation is fed into a single linear layer that produces one score per label
    in the label dictionary.
    The model can handle single-label and multi-label data sets.
    """

    def __init__(
        self,
        document_embeddings: flair.embeddings.DocumentEmbeddings,
        label_type: str,
        **classifierargs,
    ):
        """
        Initializes a TextClassifier
        :param document_embeddings: embeddings used to embed each data point
        :param label_type: name of the label type this classifier predicts; forward_pass reads the
            gold labels of this type from each data point
        :param classifierargs: keyword arguments forwarded unchanged to DefaultClassifier, e.g.:
            label_dictionary: dictionary of labels you want to predict
            multi_label: auto-detected by default, but you can set this to True to force multi-label
                prediction or False to force single-label prediction
            multi_label_threshold: If multi-label you can set the threshold to make predictions
            beta: Parameter for F-beta score for evaluation and training annealing
            loss_weights: Dictionary of weights for labels for the loss function
                (if any label's weight is unspecified it will default to 1.0)
        """

        super(TextClassifier, self).__init__(**classifierargs)

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings

        self._label_type = label_type

        # One output unit per label. self.label_dictionary is presumably populated by the parent
        # constructor from classifierargs, since it is read right after super().__init__().
        self.decoder = nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary))
        nn.init.xavier_uniform_(self.decoder.weight)

        # auto-spawn on GPU if available
        self.to(flair.device)

    def forward_pass(
        self,
        sentences: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        """
        Embed a batch of data points and score them with the linear decoder.
        :param sentences: a list of data points, or a single data point
        :param return_label_candidates: if True, additionally return the data points and one
            placeholder Label(value=None) per data point
        :return: (scores, labels), where scores is the decoder output with one row per data point
            and labels holds, per data point, the list of gold label values of self.label_type.
            If return_label_candidates is True: (scores, labels, sentences, label_candidates).
        """

        # The type hint allows a single DataPoint, but everything below iterates over
        # `sentences`. Without wrapping, a lone data point would be iterated over its contents
        # (if it is iterable) instead of being treated as one item.
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        # embed sentences
        self.document_embeddings.embed(sentences)

        # make tensor for all embedded sentences in batch (one row per data point)
        embedding_names = self.document_embeddings.get_names()
        text_embedding_tensor = torch.stack(
            [sentence.get_embedding(embedding_names) for sentence in sentences], 0
        ).to(flair.device)

        # send through decoder to get logits
        scores = self.decoder(text_embedding_tensor)

        labels = [[label.value for label in sentence.get_labels(self.label_type)] for sentence in sentences]

        # minimal return is scores and labels
        return_tuple = (scores, labels)

        if return_label_candidates:
            # one empty placeholder label per data point
            label_candidates = [Label(value=None) for _ in sentences]
            return_tuple += (sentences, label_candidates)

        return return_tuple

    def _get_state_dict(self):
        """
        Collect everything needed to rebuild this model.
        Includes the parameter state dict plus the values later passed back to the constructor
        by _init_model_with_state_dict.
        """
        model_state = {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "multi_label": self.multi_label,
            "multi_label_threshold": self.multi_label_threshold,
            "weight_dict": self.weight_dict,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """
        Rebuild a TextClassifier from a dict in the format produced by _get_state_dict.
        Optional keys (presumably absent in models saved by older versions) fall back to defaults:
        weight_dict -> None, label_type -> None, multi_label_threshold -> 0.5.
        "document_embeddings", "label_dictionary", "multi_label" and "state_dict" are required.
        """
        model = TextClassifier(
            document_embeddings=state["document_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state.get("label_type"),
            multi_label=state["multi_label"],
            multi_label_threshold=state.get("multi_label_threshold", 0.5),
            loss_weights=state.get("weight_dict"),
        )
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:
        """
        Resolve a short model name to a local model file.
        Names listed in the map below are resolved through cached_path with the "models" cache
        directory (presumably downloading the file or reusing a cached copy). Any other value is
        returned unchanged, so a path to a model file can be passed directly.
        """
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        # English sentiment model; "sentiment" and "en-sentiment" are aliases for the same file
        sentiment_url = "/".join(
            [hu_path, "sentiment-curated-distilbert", "sentiment-en-mix-distillbert_4.pt"]
        )

        model_map = {
            "de-offensive-language": "/".join(
                [hu_path, "de-offensive-language", "germ-eval-2018-task-1-v0.8.pt"]
            ),
            # English sentiment models
            "sentiment": sentiment_url,
            "en-sentiment": sentiment_url,
            "sentiment-fast": "/".join(
                [hu_path, "sentiment-curated-fasttext-rnn", "sentiment-en-mix-ft-rnn_v8.pt"]
            ),
            # Communicative Functions Model
            "communicative-functions": "/".join(
                [hu_path, "comfunc", "communicative-functions.pt"]
            ),
        }

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

    @property
    def label_type(self):
        """Name of the label type this classifier predicts (as passed to the constructor)."""
        return self._label_type