Coverage for flair/flair/visual/training_curves.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2from collections import defaultdict
3from pathlib import Path
4from typing import Union, List
6import numpy as np
7import csv
9import math
11import matplotlib.pyplot as plt
# Column indices of the fields read from each tab-separated row of 'weights.txt'
# (see Plotter._extract_weight_data, which indexes every row with them).
WEIGHT_NAME = 1
WEIGHT_NUMBER = 2
WEIGHT_VALUE = 3

# Logger obtained under the name "flair"; used in this module to emit warnings.
log = logging.getLogger("flair")
class Plotter(object):
    """
    Plots training parameters (loss, f-score, and accuracy) and training weights over time.
    Input files are the output files 'loss.tsv' and 'weights.txt' from training either a sequence tagger or text
    classification model.
    """

    @staticmethod
    def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") -> dict:
        """
        Read the per-split values of one metric from a tab-separated log file.

        :param file_name: path of the tsv file; its first row is taken as the header
        :param score: metric name; it is upper-cased and looked up as the columns
            TRAIN_<score>, DEV_<score> and TEST_<score>
        :return: dict with the keys "train", "test" and "dev", each holding
            {"loss": [], "score": [...]}; only the "score" lists are filled here.
            A split whose column is missing keeps an empty list.
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        training_curves = {
            "train": {"loss": [], "score": []},
            "test": {"loss": [], "score": []},
            "dev": {"loss": [], "score": []},
        }

        with open(file_name, "r") as tsv_file:
            reader = csv.reader(tsv_file, delimiter="\t")

            # determine the column index of the requested score for train, dev and test split;
            # an empty file has no header row and is handled like a header without columns
            header = next(reader, None) or []

            score = score.upper()

            if f"TEST_{score}" not in header:
                log.warning("-" * 100)
                log.warning(f"WARNING: No {score} found for test split in this data.")
                log.warning(
                    f"Are you sure you want to plot {score} and not another value?"
                )
                log.warning("-" * 100)

            columns = {}
            for split in ("train", "dev", "test"):
                column_name = f"{split.upper()}_{score}"
                columns[split] = (
                    header.index(column_name) if column_name in header else None
                )

            # then get all relevant values from the tsv
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to read there
                    continue
                for split, column in columns.items():
                    # "_" marks a value that was not recorded for this row
                    if column is not None and row[column] != "_":
                        training_curves[split]["score"].append(float(row[column]))

        return training_curves

    @staticmethod
    def _extract_weight_data(file_name: Union[str, Path]) -> dict:
        """
        Read weight values from a tab-separated file.

        Every non-blank row is indexed with WEIGHT_NAME, WEIGHT_NUMBER and WEIGHT_VALUE.

        :param file_name: path of the weights file
        :return: nested mapping name -> number -> list of float values, in file order
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        weights = defaultdict(lambda: defaultdict(list))

        with open(file_name, "r") as tsv_file:
            for row in csv.reader(tsv_file, delimiter="\t"):
                if not row:
                    # skip blank lines instead of failing with an IndexError
                    continue
                name = row[WEIGHT_NAME]
                param = row[WEIGHT_NUMBER]
                weights[name][param].append(float(row[WEIGHT_VALUE]))

        return weights

    @staticmethod
    def _extract_learning_rate(file_name: Union[str, Path]):
        """
        Read the LEARNING_RATE and TRAIN_LOSS columns from a tab-separated file.

        :param file_name: path of the tsv file; its first row is taken as the header
        :return: tuple (learning rates, losses) as two lists of floats; cells equal to "_" are skipped
        :raises ValueError: if the file is empty or one of the two columns is missing
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        lrs = []
        losses = []

        with open(file_name, "r") as tsv_file:
            reader = csv.reader(tsv_file, delimiter="\t")
            header = next(reader, None)
            if header is None:
                # same exception type that a missing column produces below
                raise ValueError(f"No header row found in {file_name}")
            lr_column = header.index("LEARNING_RATE")
            loss_column = header.index("TRAIN_LOSS")

            # then get all relevant values from the tsv
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to read there
                    continue
                if row[loss_column] != "_":
                    losses.append(float(row[loss_column]))
                if row[lr_column] != "_":
                    lrs.append(float(row[lr_column]))

        return lrs, losses

    def plot_weights(self, file_name: Union[str, Path]):
        """
        Plot every named weight group of a weights file into a grid of small line charts.

        The image is written as 'weights.png' into the directory of `file_name`.

        :param file_name: path of the weights file read by `_extract_weight_data`
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        weights = self._extract_weight_data(file_name)

        total = len(weights)
        columns = 2
        # at least two rows, so that plt.subplots hands back a 2-D array of axes
        rows = max(2, int(math.ceil(total / columns)))

        figsize = (4 * columns, 3 * rows)

        # A single figure only: an additional plt.figure() would stay empty, and closing
        # that one instead of the subplot figure would leave the plotted figure open.
        fig, axarr = plt.subplots(rows, columns, figsize=figsize)
        axes = list(axarr.flat)  # row-major order: left to right, then top to bottom

        for ax, (name, values) in zip(axes, weights.items()):
            ax.set_title(name, fontsize=6)
            for v in values.values():
                ax.plot(np.arange(0, len(v)), v, linewidth=0.35)

        # ticks are removed on every cell, including the ones left without data
        for ax in axes:
            ax.set_yticks([])
            ax.set_xticks([])

        # save plots
        fig.subplots_adjust(hspace=0.5)
        fig.tight_layout(pad=1.0)
        path = file_name.parent / "weights.png"
        fig.savefig(path, dpi=300)
        print(
            f"Weights plots are saved in {path}"
        )  # to let user know the path of the save plots
        plt.close(fig)

    def plot_training_curves(
        self, file_name: Union[str, Path], plot_values: List[str] = ["loss", "F1"]
    ):
        """
        Plot one sub-chart per requested value with its training, validation and test curves.

        The image is written as 'training.png' into the directory of `file_name`.

        :param file_name: path of the tsv file read by `_extract_evaluation_data`
        :param plot_values: names of the values to plot, one sub-chart each; the default
            list is only read, never modified
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        fig = plt.figure(figsize=(15, 10))

        # (key in the extracted data, label prefix used in the legend)
        splits = (("train", "training"), ("dev", "validation"), ("test", "test"))

        for plot_no, plot_value in enumerate(plot_values):

            training_curves = self._extract_evaluation_data(file_name, plot_value)

            plt.subplot(len(plot_values), 1, plot_no + 1)
            for split, label in splits:
                scores = training_curves[split]["score"]
                if scores:
                    x = np.arange(0, len(scores))
                    plt.plot(x, scores, label=f"{label} {plot_value}")
            plt.legend(bbox_to_anchor=(1.04, 0), loc="lower left", borderaxespad=0)
            plt.ylabel(plot_value)
            plt.xlabel("epochs")

        # save plots
        plt.tight_layout(pad=1.0)
        path = file_name.parent / "training.png"
        plt.savefig(path, dpi=300)
        print(
            f"Loss and F1 plots are saved in {path}"
        )  # to let user know the path of the save plots
        plt.show(block=False)  # to have the plots displayed when user run this module
        plt.close(fig)

    def plot_learning_rate(
        self, file_name: Union[str, Path], skip_first: int = 10, skip_last: int = 5
    ):
        """
        Plot the training loss against the learning rate on a logarithmic x-axis.

        The image is written as 'learning_rate.png' into the directory of `file_name`.

        :param file_name: path of the tsv file read by `_extract_learning_rate`
        :param skip_first: number of leading values left out of the plot
        :param skip_last: number of trailing values left out of the plot; values <= 0 keep the tail
        """
        if isinstance(file_name, str):
            file_name = Path(file_name)

        lrs, losses = self._extract_learning_rate(file_name)

        # a stop of -0 would produce an empty slice, hence None when nothing is cut at the end
        window = slice(skip_first, -skip_last if skip_last > 0 else None)
        lrs = lrs[window]
        losses = losses[window]

        fig, ax = plt.subplots(1, 1)
        ax.plot(lrs, losses)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate")
        ax.set_xscale("log")
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter("%.0e"))

        # save plot
        plt.tight_layout(pad=1.0)
        path = file_name.parent / "learning_rate.png"
        plt.savefig(path, dpi=300)
        print(
            f"Learning_rate plots are saved in {path}"
        )  # to let user know the path of the save plots
        plt.show(block=True)  # to have the plots displayed when user run this module
        plt.close(fig)