Coverage for flair/flair/samplers.py: 33%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2from collections import defaultdict
4from torch.utils.data.sampler import Sampler
5import random, torch
7from flair.data import FlairDataset
9log = logging.getLogger("flair")
class FlairSampler(Sampler):
    """Abstract base class for all flair samplers.

    Subclasses are constructed without a dataset; the dataset is attached
    later via :meth:`set_dataset` before iteration begins.
    """

    def set_dataset(self, data_source):
        """Attach the dataset to sample from.

        :param data_source: dataset to sample from (anything with ``len()``)
        """
        self.data_source = data_source
        self.num_samples = len(self.data_source)

    def __len__(self):
        """Return the number of samples, i.e. the size of the attached dataset."""
        return self.num_samples
class ImbalancedClassificationDatasetSampler(FlairSampler):
    """Use this to upsample rare classes and downsample common classes in your unbalanced classification dataset.

    Indices are drawn with replacement, each data point weighted by the
    inverse frequency of its first label, so rare classes are sampled
    roughly as often as common ones.
    """

    def __init__(self):
        super(ImbalancedClassificationDatasetSampler, self).__init__(None)

    def set_dataset(self, data_source: FlairDataset):
        """
        Initialize by passing a classification dataset whose data points carry
        a non-empty ``labels`` list (e.g. a text classification dataset).

        :param data_source: classification dataset to sample from
        """
        self.data_source = data_source
        self.num_samples = len(self.data_source)
        self.indices = list(range(len(data_source)))

        # first determine the distribution of classes in the dataset
        label_count = defaultdict(int)
        for sentence in data_source:
            for label in sentence.labels:
                label_count[label.value] += 1

        # weight for each sample: inverse frequency of its first label.
        # `offset` is a smoothing constant for extremely rare classes;
        # at 0 the weights are plain inverse frequencies.
        offset = 0
        weights = [
            1.0 / (offset + label_count[data_source[idx].labels[0].value])
            for idx in self.indices
        ]

        self.weights = torch.DoubleTensor(weights)

    def __iter__(self):
        """Yield ``num_samples`` indices drawn with replacement using the class-balancing weights."""
        return (
            self.indices[i]
            for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
        )
class ChunkSampler(FlairSampler):
    """Splits data into blocks and randomizes them before sampling. This causes some order of the data to be preserved,
    while still shuffling the data.
    """

    def __init__(self, block_size=5, plus_window=5):
        super(ChunkSampler, self).__init__(None)
        self.block_size = block_size
        self.plus_window = plus_window
        self.data_source = None

    def __iter__(self):
        total = len(self.data_source)

        # the effective chunk length varies per epoch within [block_size, block_size + plus_window]
        blocksize = self.block_size + random.randint(0, self.plus_window)

        log.info(
            f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
        )

        # carve the index range into consecutive chunks of `blocksize` indices
        chunks = [
            list(range(start, min(start + blocksize, total)))
            for start in range(0, total, blocksize)
        ]
        # randomize chunk order while keeping each chunk's interior order intact
        random.shuffle(chunks)
        # flatten the shuffled chunks back into a single index sequence
        return iter([index for chunk in chunks for index in chunk])
class ExpandingChunkSampler(FlairSampler):
    """Splits data into blocks and randomizes them before sampling. Block size grows with each epoch.
    This causes some order of the data to be preserved, while still shuffling the data.
    """

    def __init__(self, step=3):
        """Initialize the sampler.

        :param step: number of epochs between block-size increments; the block
            size starts at 1 and grows by 1 every `step` calls to `__iter__`.
        """
        super(ExpandingChunkSampler, self).__init__(None)
        self.block_size = 1
        self.epoch_count = 0
        self.step = step

    def __iter__(self):
        """Yield all dataset indices shuffled at block granularity, growing the block size every `step` epochs."""
        self.epoch_count += 1

        data = list(range(len(self.data_source)))

        log.info(f"Chunk sampling with blocksize = {self.block_size}")

        # Create blocks
        blocks = [
            data[i : i + self.block_size] for i in range(0, len(data), self.block_size)
        ]
        # shuffle the blocks
        random.shuffle(blocks)
        # concatenate the shuffled blocks
        data[:] = [b for bs in blocks for b in bs]

        # every `step` epochs, enlarge the blocks to preserve more data order
        if self.epoch_count % self.step == 0:
            self.block_size += 1

        return iter(data)