Coverage for src/causalspyne/weight.py: 64%
25 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1from causalspyne.wishart import gen_weight_matrix
2from causalspyne.utils_random import coerce_rng
class WeightGenUniform:
    """Generate a dense weight matrix with uniformly drawn magnitudes.

    Every entry is drawn with ``rng.uniform`` between
    ``list_weight_range[0]`` and ``list_weight_range[1]``; afterwards each
    entry is independently negated with probability ``prob_neg_weights``.
    """

    def __init__(self, list_weight_range, prob_neg_weights=0.5, rng=None):
        """Store the sampling configuration.

        Args:
            list_weight_range: two-element sequence ``[low, high]`` forwarded
                to ``rng.uniform`` as the bounds of the weight magnitudes.
            prob_neg_weights: probability that an individual weight has its
                sign flipped.
            rng: random source handed to ``coerce_rng(rng, seed=0)``.
                NOTE(review): presumably ``None`` yields a generator seeded
                with 0 -- confirm against ``causalspyne.utils_random``.
        """
        self.list_weight_range = list_weight_range
        self.prob_neg_weights = prob_neg_weights
        self.rng = coerce_rng(rng, seed=0)

    def gen(self, num_nodes):
        """Generate weights for a complete (fully connected) graph.

        Args:
            num_nodes: number of nodes; the result has one weight for every
                ordered pair of nodes.

        Returns:
            A ``num_nodes x num_nodes`` array of weights, or a plain Python
            scalar when the matrix holds exactly one element.

        Raises:
            ValueError: raised by ``rng.choice`` when ``prob_neg_weights`` is
                not a valid probability.
        """
        low, high = self.list_weight_range[0], self.list_weight_range[1]
        shape = (num_nodes, num_nodes)
        mat_weight = self.rng.uniform(low=low, high=high, size=shape)

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # The probabilities MUST be passed via the ``p`` keyword: the third
        # positional parameter of ``choice`` is ``replace``, so a positional
        # tuple is only evaluated for truthiness and ``prob_neg_weights``
        # would be silently ignored (signs flipped with probability 0.5).
        random_mask = self.rng.choice(
            [True, False],
            size=shape,
            p=[self.prob_neg_weights, 1 - self.prob_neg_weights],
        )
        mat_weight[random_mask] *= -1

        # A 1x1 matrix is collapsed to a scalar for the caller's convenience.
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight
class WeightGenWishart(WeightGenUniform):
    """Generate a dense weight matrix via ``gen_weight_matrix``.

    The matrix returned by ``gen_weight_matrix(rng, num_nodes)`` is used as
    the weights; afterwards each entry is independently negated with
    probability ``prob_neg_weights``.
    NOTE(review): the class name suggests Wishart-distributed weights --
    confirm against ``causalspyne.wishart.gen_weight_matrix``.
    """

    def __init__(self, prob_neg_weights=0.5, rng=None):
        """Store the sampling configuration.

        The parent initializer is intentionally not invoked: it requires a
        ``list_weight_range`` that this generator does not use.

        Args:
            prob_neg_weights: probability that an individual weight has its
                sign flipped.
            rng: random source handed to ``coerce_rng(rng, seed=0)``.
                NOTE(review): presumably ``None`` yields a generator seeded
                with 0 -- confirm against ``causalspyne.utils_random``.
        """
        self.prob_neg_weights = prob_neg_weights
        self.rng = coerce_rng(rng, seed=0)

    def gen(self, num_nodes):
        """Generate weights for a complete (fully connected) graph.

        Args:
            num_nodes: number of nodes; forwarded to ``gen_weight_matrix``
                and used as both dimensions of the sign mask.

        Returns:
            The weight matrix with randomly flipped signs, or a plain Python
            scalar when the matrix holds exactly one element.

        Raises:
            ValueError: raised by ``rng.choice`` when ``prob_neg_weights`` is
                not a valid probability.
        """
        mat_weight = gen_weight_matrix(self.rng, num_nodes)

        # set some edges randomly to negative: e.g. x_i = 2x_j - 3x_k
        # The probabilities MUST be passed via the ``p`` keyword: the third
        # positional parameter of ``choice`` is ``replace``, so a positional
        # tuple is only evaluated for truthiness and ``prob_neg_weights``
        # would be silently ignored (signs flipped with probability 0.5).
        random_mask = self.rng.choice(
            [True, False],
            size=(num_nodes, num_nodes),
            p=[self.prob_neg_weights, 1 - self.prob_neg_weights],
        )
        mat_weight[random_mask] *= -1

        # A 1x1 matrix is collapsed to a scalar for the caller's convenience.
        if mat_weight.size == 1:
            return mat_weight.item()
        return mat_weight