Coverage for causalspyne/weight.py: 64%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1from numpy.random import default_rng 

2 

3from causalspyne.wishart import gen_weight_matrix 

4 

5 

class WeightGenUniform:
    """Generate dense weight matrices with uniformly distributed magnitudes.

    Each entry is drawn from ``rng.uniform(low, high)`` where ``low`` and
    ``high`` are the first two items of ``list_weight_range``; each entry then
    independently has its sign flipped with probability ``prob_neg_weights``.
    """

    def __init__(self, list_weight_range, prob_neg_weights=0.5, rng=None):
        """
        Args:
            list_weight_range: sequence whose items ``[0]`` and ``[1]`` are
                the lower and upper bound passed to ``rng.uniform``.
            prob_neg_weights: probability that an individual entry is negated.
            rng: random generator exposing ``uniform`` and ``choice``
                (e.g. ``numpy.random.Generator``). When omitted, a fresh
                ``default_rng(0)`` is created for this instance.
        """
        self.list_weight_range = list_weight_range
        self.prob_neg_weights = prob_neg_weights
        # A ``default_rng(0)`` default argument would be evaluated once at
        # import time and then shared (and advanced) by every instance built
        # without an explicit rng; create a per-instance generator instead.
        self.rng = default_rng(0) if rng is None else rng

    def gen(self, num_nodes):
        """Generate the weight matrix of a complete (fully connected) graph.

        Args:
            num_nodes: number of nodes; the matrix is ``num_nodes x num_nodes``.

        Returns:
            A ``(num_nodes, num_nodes)`` array of weights, or a plain Python
            scalar when the matrix has a single entry (``num_nodes == 1``).

        Raises:
            ValueError: raised by numpy's ``choice`` if ``prob_neg_weights``
                is outside ``[0, 1]``.
        """
        mat_weight = self.rng.uniform(
            low=self.list_weight_range[0],
            high=self.list_weight_range[1],
            size=[num_nodes, num_nodes],
        )

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # ``p`` must be given by keyword: the third positional parameter of
        # ``choice`` is ``replace``, so passing the probabilities positionally
        # silently ignored ``prob_neg_weights`` (the mask was always 50/50).
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight

32 

33 

class WeightGenWishart(WeightGenUniform):
    """Generate dense weight matrices from ``gen_weight_matrix`` with random signs.

    The base matrix comes from ``gen_weight_matrix(rng, num_nodes)``
    (presumably Wishart-based, per the module name — TODO confirm); each
    entry then independently has its sign flipped with probability
    ``prob_neg_weights``.
    """

    def __init__(self, prob_neg_weights=0.5, rng=None):
        """
        Args:
            prob_neg_weights: probability that an individual entry is negated.
            rng: random generator exposing ``choice`` and accepted by
                ``gen_weight_matrix`` (e.g. ``numpy.random.Generator``). When
                omitted, a fresh ``default_rng(0)`` is created for this
                instance.
        """
        self.prob_neg_weights = prob_neg_weights
        # A ``default_rng(0)`` default argument would be evaluated once at
        # import time and shared (and advanced) by every default-constructed
        # instance; create a per-instance generator instead.
        self.rng = default_rng(0) if rng is None else rng

    def gen(self, num_nodes):
        """Generate the weight matrix of a complete (fully connected) graph.

        Args:
            num_nodes: number of nodes; the matrix is ``num_nodes x num_nodes``.

        Returns:
            The matrix from ``gen_weight_matrix`` with some entries negated
            (assumed to be a ``(num_nodes, num_nodes)`` ndarray — it must
            accept the boolean mask below), or a plain Python scalar when it
            has a single entry.

        Raises:
            ValueError: raised by numpy's ``choice`` if ``prob_neg_weights``
                is outside ``[0, 1]``.
        """
        mat_weight = gen_weight_matrix(self.rng, num_nodes)

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # ``p`` must be given by keyword: the third positional parameter of
        # ``choice`` is ``replace``, so passing the probabilities positionally
        # silently ignored ``prob_neg_weights`` (the mask was always 50/50).
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight