Coverage for src/causalspyne/weight.py: 64%

25 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1from causalspyne.wishart import gen_weight_matrix 

2from causalspyne.utils_random import coerce_rng 

3 

4 

class WeightGenUniform:
    """Generate a dense weight matrix with uniformly drawn magnitudes.

    Every entry is drawn from ``list_weight_range`` and then has its sign
    flipped with probability ``prob_neg_weights``.
    """

    def __init__(self, list_weight_range, prob_neg_weights=0.5, rng=None):
        """Store the sampling configuration.

        Args:
            list_weight_range: indexable pair; element 0 is used as the
                lower bound and element 1 as the upper bound of the draw.
            prob_neg_weights: probability that an individual weight is
                multiplied by -1.
            rng: forwarded to ``coerce_rng`` together with ``seed=0``.
                NOTE(review): presumably a NumPy-style random generator,
                a seed, or None -- confirm against ``coerce_rng``.
        """
        self.list_weight_range = list_weight_range
        self.prob_neg_weights = prob_neg_weights
        self.rng = coerce_rng(rng, seed=0)

    def gen(self, num_nodes):
        """
        generate complete graph, fully connected

        Args:
            num_nodes: number of nodes; the weight matrix has shape
                ``[num_nodes, num_nodes]``.

        Returns:
            The signed weight matrix, or a plain scalar (via ``.item()``)
            when the matrix holds exactly one element.
        """
        mat_weight = self.rng.uniform(
            low=self.list_weight_range[0],
            high=self.list_weight_range[1],
            size=[num_nodes, num_nodes],
        )

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # `size` and `p` must be passed by keyword: the third positional
        # parameter of `choice` is `replace`, so a positional probability
        # tuple is silently treated as a truthy flag and the mask degrades
        # to a 50/50 draw regardless of `prob_neg_weights`.
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight

31 

32 

class WeightGenWishart(WeightGenUniform):
    """Generate a dense weight matrix via ``gen_weight_matrix``.

    Entries produced by ``gen_weight_matrix`` have their sign flipped with
    probability ``prob_neg_weights``.

    Note: this subclass does not call ``WeightGenUniform.__init__`` and
    never sets ``list_weight_range``; only ``prob_neg_weights`` and ``rng``
    are initialised.
    """

    def __init__(self, prob_neg_weights=0.5, rng=None):
        """Store the sampling configuration.

        Args:
            prob_neg_weights: probability that an individual weight is
                multiplied by -1.
            rng: forwarded to ``coerce_rng`` together with ``seed=0``.
                NOTE(review): presumably a NumPy-style random generator,
                a seed, or None -- confirm against ``coerce_rng``.
        """
        self.prob_neg_weights = prob_neg_weights
        self.rng = coerce_rng(rng, seed=0)

    def gen(self, num_nodes):
        """
        generate complete graph, fully connected

        Args:
            num_nodes: number of nodes, passed to ``gen_weight_matrix``
                and used for the ``(num_nodes, num_nodes)`` sign mask.

        Returns:
            The signed weight matrix, or a plain scalar (via ``.item()``)
            when the matrix holds exactly one element.
        """
        # NOTE(review): assumes gen_weight_matrix returns a mutable
        # (num_nodes, num_nodes) array -- confirm in causalspyne.wishart.
        mat_weight = gen_weight_matrix(self.rng, num_nodes)

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # `size` and `p` must be passed by keyword: the third positional
        # parameter of `choice` is `replace`, so a positional probability
        # tuple is silently treated as a truthy flag and the mask degrades
        # to a 50/50 draw regardless of `prob_neg_weights`.
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight