Coverage for sacred/sacred/randomness.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

31 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import random 

5 

6import sacred.optional as opt 

7from sacred.settings import SETTINGS 

8from sacred.utils import module_is_in_cache 

9 

# (low, high) bounds handed to the integer generators in ``get_seed``.
SEEDRANGE = (1, int(1e9))


def get_seed(rnd=None):
    """Draw a new integer seed, from ``rnd`` or from the global ``random`` module.

    Args:
        rnd: Optional random state. When ``None`` the module-level ``random``
            generator is used. Objects exposing an ``integers`` method (numpy
            ``Generator``) are drawn from via ``integers``; any other object
            must provide ``randint`` (e.g. numpy ``RandomState`` or
            ``random.Random``).

    Returns:
        int: A seed drawn from ``SEEDRANGE``. Note that ``random.randint``
        includes the upper bound, while numpy's ``integers``/``randint``
        exclude it, so the exact range differs slightly by generator type.
    """
    if rnd is None:
        return random.randint(*SEEDRANGE)
    # Feature-detect the numpy Generator API rather than catching every
    # exception: a genuine error raised inside ``integers`` must surface
    # instead of being masked by a fallback to ``randint``, which a
    # Generator does not have.
    integers = getattr(rnd, "integers", None)
    if integers is not None:
        # dtype=int makes numpy return a plain Python int for scalar draws.
        return integers(*SEEDRANGE, dtype=int)
    return rnd.randint(*SEEDRANGE)

21 

22 

def create_rnd(seed):
    """Create a new random state seeded with ``seed``.

    Args:
        seed (int): Seed for the new generator.

    Returns:
        ``opt.np.random.RandomState(seed)`` when numpy is available and
        ``SETTINGS.CONFIG.NUMPY_RANDOM_LEGACY_API`` is truthy,
        ``opt.np.random.default_rng(seed)`` when numpy is available otherwise,
        and ``random.Random(seed)`` when numpy is not available.

    Raises:
        AssertionError: If ``seed`` is not an ``int``. Raised explicitly
            (not via an ``assert`` statement) so the check still runs under
            ``python -O``, while keeping the exception type callers see.
    """
    if not isinstance(seed, int):
        raise AssertionError(
            "Seed has to be integer but was {} {}".format(repr(seed), type(seed))
        )
    if not opt.has_numpy:
        return random.Random(seed)
    if SETTINGS.CONFIG.NUMPY_RANDOM_LEGACY_API:
        return opt.np.random.RandomState(seed)
    return opt.np.random.default_rng(seed)

34 

35 

def set_global_seed(seed):
    """Seed the process-wide random number generators with ``seed``.

    The stdlib ``random`` module is always seeded. Numpy's global RNG is
    seeded when ``opt.has_numpy`` is true. TensorFlow and PyTorch are seeded
    only when ``module_is_in_cache`` reports them (presumably: only if they
    were already imported, so this function does not trigger their import —
    confirm against ``sacred.utils``). For PyTorch, all CUDA devices are
    additionally seeded when CUDA is available.

    Args:
        seed: Seed value forwarded unchanged to every library's seeding call.
    """
    random.seed(seed)

    if opt.has_numpy:
        opt.np.random.seed(seed)

    if module_is_in_cache("tensorflow"):
        # NOTE(review): relies on opt.get_tensorflow() returning a module that
        # still exposes ``set_random_seed`` (TF1 API / compat.v1) — confirm.
        opt.get_tensorflow().set_random_seed(seed)

    if module_is_in_cache("torch"):
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)