Coverage for src/causalspyne/dag_gen.py: 100%
41 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
2concrete class to generate simple DAGs
4"""
6import warnings
8from causalspyne.erdo_renyi_plp import Erdos_Renyi_PLP
9from causalspyne.dag_interface import MatDAG
10from causalspyne.weight import WeightGenWishart
11from causalspyne.dag_manipulator import DAGManipulator
12from causalspyne.utils_random import coerce_rng
class GenDAG:
    """Generate random weighted DAGs and inject confounders into them.

    A binary skeleton is drawn with an ``Erdos_Renyi_PLP`` strategy, its
    edges are weighted with the matrix returned by ``obj_gen_weight.gen``,
    and the result is wrapped into a ``MatDAG``.
    """

    def __init__(self, num_nodes, degree, obj_gen_weight=None,
                 rng=None):
        """
        num_nodes: default number of nodes used by ``gen_dag`` when it is
            called without an explicit ``num_nodes``.
        degree: expected degree for each node
        obj_gen_weight: object whose ``gen(num_nodes)`` returns the weight
            matrix; defaults to ``WeightGenWishart(rng=rng)``.
        rng: random source, normalized through ``coerce_rng``; the result is
            shared by the skeleton strategy, the default weight generator,
            and (in ``gen_dag``) the DAG and its manipulator.
        """
        rng = coerce_rng(rng)
        self.num_nodes = num_nodes
        self.degree = degree
        self.strategy_gen_dag = Erdos_Renyi_PLP(rng)
        self.obj_gen_weight = obj_gen_weight
        if obj_gen_weight is None:
            self.obj_gen_weight = WeightGenWishart(rng=rng)
        # Set by gen_dag; bound to the most recently generated DAG.
        self.dag_manipulator = None
        self.rng = rng

    def gen_dag(self, num_nodes=None, prefix="", *, target_num_confounder):
        """
        generate DAG and wrap it around with interface

        num_nodes: number of nodes; falls back to ``self.num_nodes``.
        prefix: forwarded to ``MatDAG`` as ``name_prefix``.
        target_num_confounder: desired number of confounders. Confounders are
            added with ``DAGManipulator.mk_confound``, starting from
            ``dag.get_top_last()`` and moving with ``dag.climb``, for at most
            ``dag.num_nodes`` attempts.

        Returns the generated ``MatDAG``. Emits a ``UserWarning`` when fewer
        than ``target_num_confounder`` confounders were obtained, unless
        ``dag.num_nodes - target_num_confounder <= 1``.
        """
        if num_nodes is None:
            num_nodes = self.num_nodes
        mat_skeleton = self.strategy_gen_dag(num_nodes, self.degree)
        mat_mask = (mat_skeleton != 0).astype(float)
        mat_weight = self.obj_gen_weight.gen(num_nodes)
        # Hadamard (element-wise) product: keep weights only where the
        # skeleton has an edge.
        mat_weighted_adjacency = mat_mask * mat_weight

        dag = MatDAG(mat_weighted_adjacency, name_prefix=prefix, rng=self.rng)
        self.dag_manipulator = DAGManipulator(dag,
                                              self.obj_gen_weight, self.rng)
        ind_arbitrary = dag.get_top_last()
        counter = 0  # number of mk_confound attempts reported as failed
        # NOTE(review): mk_confound is called before the target is checked,
        # so at least one attempt is made even when the target is already
        # satisfied (e.g. target_num_confounder == 0) -- confirm intended.
        for _ in range(dag.num_nodes):
            flag_success = self.dag_manipulator.mk_confound(
                ind_arbitrary_confound_input=ind_arbitrary)
            if not flag_success:
                counter += 1
            if dag.num_confounder >= target_num_confounder:
                break
            # FIXME: it can be the new ind_arbitrary has been tried out already
            ind_arbitrary = dag.climb(ind_arbitrary)
            if ind_arbitrary is None:
                break
        num_confounder = len(dag.list_confounder)
        if num_confounder < target_num_confounder and \
                dag.num_nodes - target_num_confounder > 1:
            # Implicit concatenation keeps the message independent of the
            # source indentation (a backslash-newline inside a literal would
            # embed the next line's leading whitespace into the text).
            # stacklevel=2 attributes the warning to the caller of gen_dag.
            warnings.warn(
                f"\n failed to ensure {target_num_confounder} confounders for "
                f"adjacency matrix \n{dag.mat_adjacency}, "
                f"\n after {counter} failed trials, "
                f"\n{num_confounder} confounders only",
                stacklevel=2)
        return dag