# flair/models/pairwise_classification_model.py

from typing import Union, List

import torch

import flair.embeddings
import flair.nn
from flair.data import Label, DataPoint, Sentence, DataPair


class TextPairClassifier(flair.nn.DefaultClassifier):
    """
    Text Pair Classification Model for tasks such as Recognizing Textual Entailment, built upon TextClassifier.
    The model takes document embeddings and puts the resulting text representation(s) through a linear layer to
    get the actual class label. There are two ways to embed the DataPairs: either embed both DataPoints
    separately and concatenate the resulting vectors ("embed_separately=True"), or concatenate the DataPoints
    and embed the resulting text as a whole ("embed_separately=False").
    """

    def __init__(
        self,
        document_embeddings: flair.embeddings.DocumentEmbeddings,
        label_type: str,
        embed_separately: bool = False,
        **classifierargs,
    ):
        """
        Initializes a TextPairClassifier
        :param document_embeddings: embeddings used to embed each data point
        :param label_type: name of the label to predict
        :param embed_separately: if True, embed both DataPoints of each DataPair separately and concatenate
            the resulting vectors; if False, concatenate the texts and embed them as one
        :param classifierargs: further arguments passed on to flair.nn.DefaultClassifier, such as the
            label_dictionary, multi_label, multi_label_threshold and loss_weights
        """
        super().__init__(**classifierargs)

        self.document_embeddings: flair.embeddings.DocumentEmbeddings = document_embeddings

        self._label_type = label_type

        self.embed_separately = embed_separately

        if self.embed_separately:
            # the linear layer needs twice the embedding length as input size,
            # since we concatenate the embeddings of the two DataPoints in each DataPair
            self.decoder = torch.nn.Linear(
                2 * self.document_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)
        else:
            # one joint representation for both sentences
            self.decoder = torch.nn.Linear(
                self.document_embeddings.embedding_length, len(self.label_dictionary)
            )

            # set the separator used to concatenate the two sentences
            self.sep = " "
            if isinstance(self.document_embeddings, flair.embeddings.document.TransformerDocumentEmbeddings):
                if self.document_embeddings.tokenizer.sep_token:
                    self.sep = " " + str(self.document_embeddings.tokenizer.sep_token) + " "
                else:
                    self.sep = " [SEP] "

        torch.nn.init.xavier_uniform_(self.decoder.weight)

        # auto-spawn on GPU if available
        self.to(flair.device)
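    # Illustrative example (assumed inputs, not from the original code): with a BERT-style
    # tokenizer whose sep_token is "[SEP]", the embed_separately=False mode turns the
    # DataPair ("A man sleeps.", "A person rests.") into the single tokenized string
    # "A man sleeps . [SEP] A person rests ." before embedding it.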

    @property
    def label_type(self):
        return self._label_type

    def forward_pass(
        self,
        datapairs: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        if isinstance(datapairs, DataPair):
            datapairs = [datapairs]

        embedding_names = self.document_embeddings.get_names()

        if self.embed_separately:  # embed both sentences separately, concatenate the resulting vectors
            first_elements = [pair.first for pair in datapairs]
            second_elements = [pair.second for pair in datapairs]

            self.document_embeddings.embed(first_elements)
            self.document_embeddings.embed(second_elements)

            text_embedding_list = [
                torch.cat([a.get_embedding(embedding_names), b.get_embedding(embedding_names)], 0).unsqueeze(0)
                for (a, b) in zip(first_elements, second_elements)
            ]

        else:  # concatenate the sentences and embed them together
            concatenated_sentences = [
                Sentence(
                    pair.first.to_tokenized_string() + self.sep + pair.second.to_tokenized_string(),
                    use_tokenizer=False,
                )
                for pair in datapairs
            ]

            self.document_embeddings.embed(concatenated_sentences)

            text_embedding_list = [
                sentence.get_embedding(embedding_names).unsqueeze(0) for sentence in concatenated_sentences
            ]

        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(flair.device)

        # linear layer
        scores = self.decoder(text_embedding_tensor)

        labels = []
        for pair in datapairs:
            labels.append([label.value for label in pair.get_labels(self.label_type)])

        # minimal return is scores and labels
        return_tuple = (scores, labels)

        if return_label_candidates:
            label_candidates = [Label(value=None) for _ in datapairs]
            return_tuple += (datapairs, label_candidates)

        return return_tuple
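    # Shape sketch (illustrative, assuming a batch of B pairs and embedding length E):
    #   embed_separately=True  -> text_embedding_tensor has shape (B, 2 * E)
    #   embed_separately=False -> text_embedding_tensor has shape (B, E)
    # and scores has shape (B, len(self.label_dictionary)) in both cases.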

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "document_embeddings": self.document_embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "multi_label": self.multi_label,
            "multi_label_threshold": self.multi_label_threshold,
            "loss_weights": self.loss_weights,
            "embed_separately": self.embed_separately,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        model = TextPairClassifier(
            document_embeddings=state["document_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            multi_label=state["multi_label"],
            multi_label_threshold=state.get("multi_label_threshold", 0.5),
            loss_weights=state["loss_weights"],
            embed_separately=state["embed_separately"],
        )
        model.load_state_dict(state["state_dict"])
        return model
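

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes that the **classifierargs forwarded to flair.nn.DefaultClassifier
# accept a label_dictionary, and the checkpoint name "distilbert-base-uncased"
# is an arbitrary example choice.
if __name__ == "__main__":
    from flair.data import Dictionary
    from flair.embeddings import TransformerDocumentEmbeddings

    # build a label dictionary for a hypothetical textual-entailment task
    label_dictionary = Dictionary(add_unk=False)
    for value in ["entailment", "contradiction", "neutral"]:
        label_dictionary.add_item(value)

    classifier = TextPairClassifier(
        document_embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased"),
        label_type="entailment",
        label_dictionary=label_dictionary,
        embed_separately=False,  # concatenate both texts and embed them once
    )

    # score a single pair; forward_pass returns (scores, labels) by default
    pair = DataPair(Sentence("A man is eating."), Sentence("Someone is having a meal."))
    scores, labels = classifier.forward_pass(pair)
    print(scores.shape)  # expected: torch.Size([1, 3])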