Coverage for flair/flair/visual/training_curves.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1import logging 

2from collections import defaultdict 

3from pathlib import Path 

4from typing import Union, List 

5 

6import numpy as np 

7import csv 

8 

9import math 

10 

11import matplotlib.pyplot as plt 

12 

13 

# Column positions inside each tab-separated row of 'weights.txt':
# the weight (parameter) name, the index of the tracked entry within
# that weight, and the recorded value. Column 0 is not read here.
WEIGHT_NAME, WEIGHT_NUMBER, WEIGHT_VALUE = 1, 2, 3

log = logging.getLogger("flair")

20 

21 

class Plotter(object):
    """
    Plots training parameters (loss, f-score, and accuracy) and training weights over time.
    Input files are the output files 'loss.tsv' and 'weights.txt' from training either a sequence tagger or text
    classification model.

    Every public ``plot_*`` method writes a PNG into the directory of the input file
    ('training.png', 'weights.png' or 'learning_rate.png') and prints the path it saved to.
    """

    @staticmethod
    def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") -> dict:
        """Read the per-epoch values of one metric for the train, dev and test splits.

        :param file_name: path to a tab-separated file whose first row is a header that may
            contain the columns ``TRAIN_<SCORE>``, ``DEV_<SCORE>`` and ``TEST_<SCORE>``.
        :param score: metric name. It is upper-cased before the header lookup, so "f1" and
            "F1" are equivalent; "loss" selects the ``*_LOSS`` columns.
        :return: ``{"train" | "dev" | "test": {"loss": [], "score": [float, ...]}}``.
            The requested metric is always stored under "score"; the "loss" lists are
            never filled. Cells containing "_" are skipped. A split whose column is absent
            from the header, or an empty file, yields an empty list.
        """
        file_name = Path(file_name)

        training_curves = {
            "train": {"loss": [], "score": []},
            "test": {"loss": [], "score": []},
            "dev": {"loss": [], "score": []},
        }

        score = score.upper()

        with open(file_name, "r", newline="") as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter="\t")

            # the header row tells us which column holds which value
            header = next(tsv_reader, None)
            if header is None:
                # empty file: nothing to extract
                return training_curves

            if f"TEST_{score}" not in header:
                log.warning("-" * 100)
                log.warning(f"WARNING: No {score} found for test split in this data.")
                log.warning(
                    f"Are you sure you want to plot {score} and not another value?"
                )
                log.warning("-" * 100)

            # split -> column index, only for the splits the header actually contains
            columns = {
                split: header.index(f"{split.upper()}_{score}")
                for split in ("train", "dev", "test")
                if f"{split.upper()}_{score}" in header
            }

            # then get all relevant values from the tsv
            for row in tsv_reader:
                for split, column in columns.items():
                    # short/blank rows (e.g. a truncated last line) have no value here
                    if column < len(row) and row[column] != "_":
                        training_curves[split]["score"].append(float(row[column]))

        return training_curves

    @staticmethod
    def _extract_weight_data(file_name: Union[str, Path]) -> dict:
        """Group the recorded weight values from a 'weights.txt' file.

        :param file_name: path to a tab-separated file; the columns read are given by the
            module constants WEIGHT_NAME, WEIGHT_NUMBER and WEIGHT_VALUE.
        :return: nested mapping ``name -> number -> [float, ...]``, where the values are in
            file order. Rows too short to hold all three columns are skipped.
        """
        file_name = Path(file_name)

        weights = defaultdict(lambda: defaultdict(list))
        min_length = max(WEIGHT_NAME, WEIGHT_NUMBER, WEIGHT_VALUE) + 1

        with open(file_name, "r", newline="") as tsv_file:
            for row in csv.reader(tsv_file, delimiter="\t"):
                if len(row) < min_length:
                    continue
                weights[row[WEIGHT_NAME]][row[WEIGHT_NUMBER]].append(
                    float(row[WEIGHT_VALUE])
                )

        return weights

    @staticmethod
    def _extract_learning_rate(file_name: Union[str, Path]) -> tuple:
        """Read aligned (learning rate, training loss) pairs from a tab-separated file.

        :param file_name: path to a file whose header row contains the columns
            ``LEARNING_RATE`` and ``TRAIN_LOSS``.
        :return: ``(lrs, losses)``, two equally long lists of floats. A row contributes
            only when both cells hold a value (not "_"), so ``lrs[i]`` always belongs to
            ``losses[i]``. An empty file yields ``([], [])``.
        :raises ValueError: if the header lacks one of the two columns.
        """
        file_name = Path(file_name)

        lrs = []
        losses = []

        with open(file_name, "r", newline="") as tsv_file:
            tsv_reader = csv.reader(tsv_file, delimiter="\t")
            header = next(tsv_reader, None)
            if header is None:
                return lrs, losses

            lr_column = header.index("LEARNING_RATE")
            loss_column = header.index("TRAIN_LOSS")
            last_column = max(lr_column, loss_column)

            # then get all relevant values from the tsv
            for row in tsv_reader:
                if len(row) <= last_column:
                    continue
                lr, loss = row[lr_column], row[loss_column]
                # appending the two independently would shift every later pair
                # as soon as only one of them is missing
                if lr == "_" or loss == "_":
                    continue
                lrs.append(float(lr))
                losses.append(float(loss))

        return lrs, losses

    def plot_weights(self, file_name: Union[str, Path]):
        """Plot every tracked weight in a two-column grid and save it as 'weights.png'.

        :param file_name: path to 'weights.txt'; the image is written next to it.
        """
        file_name = Path(file_name)

        weights = self._extract_weight_data(file_name)

        total = len(weights)
        columns = 2
        # at least two rows, so plt.subplots always returns a 2-D array of axes
        rows = max(2, int(math.ceil(total / columns)))

        figsize = (4 * columns, 3 * rows)
        fig, axarr = plt.subplots(rows, columns, figsize=figsize)

        # fill the grid row by row: one subplot per weight name, one line per entry
        axes = axarr.flatten()
        for ax, (name, values) in zip(axes, weights.items()):
            ax.set_title(name, fontsize=6)
            for series in values.values():
                ax.plot(np.arange(0, len(series)), series, linewidth=0.35)

        # hide tick marks on every cell, including the unused ones at the end
        for ax in axes:
            ax.set_yticks([])
            ax.set_xticks([])

        # save plots
        fig.subplots_adjust(hspace=0.5)
        fig.tight_layout(pad=1.0)
        path = file_name.parent / "weights.png"
        fig.savefig(path, dpi=300)
        print(
            f"Weights plots are saved in {path}"
        )  # to let user know the path of the save plots
        # close the figure that was actually drawn on, otherwise it leaks per call
        plt.close(fig)

    def plot_training_curves(
        self, file_name: Union[str, Path], plot_values: Union[List[str], None] = None
    ):
        """Plot train/dev/test curves for each requested metric and save 'training.png'.

        :param file_name: path to 'loss.tsv'; the image is written next to it.
        :param plot_values: metric names to plot, one subplot each (matched
            case-insensitively against the ``<SPLIT>_<METRIC>`` header columns).
            Defaults to ``["loss", "F1"]``.
        """
        file_name = Path(file_name)
        if plot_values is None:
            # avoid a shared mutable default argument
            plot_values = ["loss", "F1"]

        fig = plt.figure(figsize=(15, 10))

        for plot_no, plot_value in enumerate(plot_values):

            training_curves = self._extract_evaluation_data(file_name, plot_value)

            plt.subplot(len(plot_values), 1, plot_no + 1)
            for split, label in (("train", "training"), ("dev", "validation"), ("test", "test")):
                values = training_curves[split]["score"]
                if values:
                    plt.plot(
                        np.arange(0, len(values)), values, label=f"{label} {plot_value}"
                    )
            plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
            plt.ylabel(plot_value)
            plt.xlabel("epochs")

        # save plots
        plt.tight_layout(pad=1.0)
        path = file_name.parent / "training.png"
        plt.savefig(path, dpi=300)
        print(
            f"Loss and F1 plots are saved in {path}"
        )  # to let user know the path of the save plots
        plt.show(block=False)  # to have the plots displayed when user run this module
        plt.close(fig)

    def plot_learning_rate(
        self, file_name: Union[str, Path], skip_first: int = 10, skip_last: int = 5
    ):
        """Plot training loss against learning rate and save 'learning_rate.png'.

        :param file_name: file with LEARNING_RATE and TRAIN_LOSS columns; the image is
            written next to it.
        :param skip_first: number of leading points to drop.
        :param skip_last: number of trailing points to drop (ignored when <= 0).
        """
        file_name = Path(file_name)

        lrs, losses = self._extract_learning_rate(file_name)
        end = -skip_last if skip_last > 0 else None
        lrs = lrs[skip_first:end]
        losses = losses[skip_first:end]

        fig, ax = plt.subplots(1, 1)
        ax.plot(lrs, losses)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate")
        ax.set_xscale("log")
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter("%.0e"))

        # save plot
        fig.tight_layout(pad=1.0)
        path = file_name.parent / "learning_rate.png"
        fig.savefig(path, dpi=300)
        print(
            f"Learning_rate plots are saved in {path}"
        )  # to let user know the path of the save plots
        plt.show(block=True)  # to have the plots displayed when user run this module
        plt.close(fig)