# flair/samplers.py

import logging
import random
from collections import defaultdict

import torch
from torch.utils.data.sampler import Sampler

from flair.data import FlairDataset

log = logging.getLogger("flair")


class FlairSampler(Sampler):
    def set_dataset(self, data_source):
        """Initialize the sampler with the dataset to sample from.

        :param data_source: dataset to sample from
        """
        self.data_source = data_source
        self.num_samples = len(self.data_source)

    def __len__(self):
        return self.num_samples
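
# Usage sketch (hedged): the concrete samplers below are constructed first and
# receive the dataset afterwards via set_dataset(). `dataset` stands for a
# hypothetical FlairDataset instance; the plain torch DataLoader is shown only
# to illustrate the Sampler contract and would still need a flair-aware
# collate function in practice.
#
#     sampler = ChunkSampler(block_size=5, plus_window=5)
#     sampler.set_dataset(dataset)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)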


class ImbalancedClassificationDatasetSampler(FlairSampler):
    """Use this to upsample rare classes and downsample common classes in an imbalanced classification dataset."""

    def __init__(self):
        super(ImbalancedClassificationDatasetSampler, self).__init__(None)

    def set_dataset(self, data_source: FlairDataset):
        """Initialize by passing a classification dataset with labels, i.e. a dataset whose sentences carry at least one label.

        :param data_source: dataset to sample from
        """
        self.data_source = data_source
        self.num_samples = len(self.data_source)
        self.indices = list(range(len(data_source)))

        # first determine the distribution of classes in the dataset
        label_count = defaultdict(int)
        for sentence in data_source:
            for label in sentence.labels:
                label_count[label.value] += 1

        # weight each sample inversely to the frequency of its first label;
        # raise offset above 0 to soften the reweighting
        offset = 0
        weights = [
            1.0 / (offset + label_count[data_source[idx].labels[0].value])
            for idx in self.indices
        ]

        self.weights = torch.DoubleTensor(weights)

    def __iter__(self):
        return (
            self.indices[i]
            for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
        )
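
# Worked example (hedged): the inverse-frequency weights equalize total class
# mass. With 90 sentences labeled "a" (weight 1/90 each) and 10 labeled "b"
# (weight 1/10 each), each class sums to a total weight of 1.0, so
# torch.multinomial draws the two classes about equally often.
#
#     sampler = ImbalancedClassificationDatasetSampler()
#     sampler.set_dataset(labeled_dataset)  # `labeled_dataset` is hypothetical
#     rebalanced_indices = list(iter(sampler))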


class ChunkSampler(FlairSampler):
    """Splits data into blocks and randomizes their order before sampling.
    This preserves some of the original data order while still shuffling.
    """

    def __init__(self, block_size=5, plus_window=5):
        super(ChunkSampler, self).__init__(None)
        self.block_size = block_size
        self.plus_window = plus_window
        self.data_source = None

    def __iter__(self):
        data = list(range(len(self.data_source)))

        blocksize = self.block_size + random.randint(0, self.plus_window)

        log.info(
            f"Chunk sampling with blocksize = {blocksize} ({self.block_size} + {self.plus_window})"
        )

        # create blocks
        blocks = [data[i : i + blocksize] for i in range(0, len(data), blocksize)]
        # shuffle the blocks
        random.shuffle(blocks)
        # concatenate the shuffled blocks
        data[:] = [b for bs in blocks for b in bs]
        return iter(data)
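
# Worked example (hedged): each epoch draws one blocksize from
# [block_size, block_size + plus_window]. For 12 items and blocksize 5 the
# blocks are [0..4], [5..9], [10..11]; shuffling the blocks may yield the
# order 5 6 7 8 9 10 11 0 1 2 3 4 -- order *within* each block is preserved.
#
#     sampler = ChunkSampler(block_size=5, plus_window=5)
#     sampler.set_dataset(dataset)  # `dataset` is hypothetical
#     epoch_order = list(iter(sampler))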


class ExpandingChunkSampler(FlairSampler):
    """Splits data into blocks and randomizes their order before sampling. Block size grows with each epoch.
    This preserves some of the original data order while still shuffling.
    """

    def __init__(self, step=3):
        """Initialize by passing a step parameter.

        :param step: increase the block size by one every `step` epochs
        """
        super(ExpandingChunkSampler, self).__init__(None)
        self.block_size = 1
        self.epoch_count = 0
        self.step = step

    def __iter__(self):
        self.epoch_count += 1

        data = list(range(len(self.data_source)))

        log.info(f"Chunk sampling with blocksize = {self.block_size}")

        # create blocks
        blocks = [
            data[i : i + self.block_size] for i in range(0, len(data), self.block_size)
        ]
        # shuffle the blocks
        random.shuffle(blocks)
        # concatenate the shuffled blocks
        data[:] = [b for bs in blocks for b in bs]

        # grow the block size by one every `step` epochs
        if self.epoch_count % self.step == 0:
            self.block_size += 1

        return iter(data)
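
# Worked example (hedged): with the default step=3, epochs 1-3 use block_size
# 1 (a full shuffle), epochs 4-6 use block_size 2, and so on -- each call to
# __iter__ counts as one epoch, so later epochs keep progressively longer runs
# of the original order together.
#
#     sampler = ExpandingChunkSampler(step=3)
#     sampler.set_dataset(dataset)  # `dataset` is hypothetical
#     for epoch in range(10):
#         epoch_order = list(iter(sampler))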