Coverage for causalspyne/weight.py: 64%
25 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1from numpy.random import default_rng
3from causalspyne.wishart import gen_weight_matrix
class WeightGenUniform:
    """Generate dense edge-weight matrices with uniformly drawn magnitudes.

    Each entry is drawn from ``rng.uniform(low, high)`` and is then negated
    independently with probability ``prob_neg_weights``.
    """

    def __init__(self, list_weight_range, prob_neg_weights=0.5, rng=None):
        """
        Args:
            list_weight_range: ``[low, high]`` bounds forwarded to
                ``Generator.uniform`` (``low`` inclusive, ``high`` exclusive).
            prob_neg_weights: probability that any single entry is negated.
            rng: numpy random ``Generator``. If None, a generator seeded with
                0 is created for this instance.
        """
        self.list_weight_range = list_weight_range
        self.prob_neg_weights = prob_neg_weights
        # Build the default generator per instance rather than as a default
        # argument: a default argument is evaluated once at import time, so
        # all default-constructed instances would share (and advance) a
        # single generator.
        self.rng = default_rng(0) if rng is None else rng

    def gen(self, num_nodes):
        """Return random edge weights for a complete graph on ``num_nodes``.

        Every entry of the ``(num_nodes, num_nodes)`` matrix is filled; any
        sparsity pattern is presumably applied by the caller.

        Returns:
            A ``(num_nodes, num_nodes)`` float ndarray, or a plain Python
            float when the matrix has exactly one entry (``num_nodes == 1``).
        """
        mat_weight = self.rng.uniform(
            low=self.list_weight_range[0],
            high=self.list_weight_range[1],
            size=[num_nodes, num_nodes],
        )

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # `size` and `p` must be passed by keyword: the third positional
        # parameter of Generator.choice is `replace`, so passing the
        # probabilities positionally silently ignored prob_neg_weights.
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight
class WeightGenWishart(WeightGenUniform):
    """Weight generator whose magnitudes come from ``gen_weight_matrix``.

    Shares the ``gen(num_nodes)`` interface of WeightGenUniform but does not
    call the base ``__init__``: no ``list_weight_range`` is stored, because
    the magnitudes are produced by ``gen_weight_matrix`` instead of a
    uniform draw. Entries are negated independently with probability
    ``prob_neg_weights``.
    """

    def __init__(self, prob_neg_weights=0.5, rng=None):
        """
        Args:
            prob_neg_weights: probability that any single entry is negated.
            rng: numpy random ``Generator``. If None, a generator seeded with
                0 is created for this instance.
        """
        self.prob_neg_weights = prob_neg_weights
        # Build the default generator per instance rather than as a default
        # argument, so default-constructed instances do not share state.
        self.rng = default_rng(0) if rng is None else rng

    def gen(self, num_nodes):
        """Return random edge weights for a complete graph on ``num_nodes``.

        Returns:
            The matrix from ``gen_weight_matrix`` with random sign flips, or a
            plain Python float when it has exactly one entry.
        """
        # NOTE(review): assumes gen_weight_matrix returns a mutable
        # (num_nodes, num_nodes) float ndarray; the mask below relies on it.
        mat_weight = gen_weight_matrix(self.rng, num_nodes)

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # `size` and `p` must be passed by keyword: the third positional
        # parameter of Generator.choice is `replace`, so passing the
        # probabilities positionally silently ignored prob_neg_weights.
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=(self.prob_neg_weights, 1 - self.prob_neg_weights),
        )
        mat_weight[random_mask] *= -1
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight