Coverage for src/causalspyne/gen_dag_2level.py: 100%
52 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
22-level DAG generation
3"""
5from causalspyne.dag_stack_indexer import DAGStackIndexer
6from causalspyne.dag_manipulator import DAGManipulator
7from causalspyne.weight import WeightGenWishart
8from causalspyne.utils_random import coerce_rng
class GenDAG2Level:
    """
    Generate a DAG with 2 levels: the first level generates macro nodes (the
    backbone DAG), the second level populates each macro node with its own
    local DAG. Every backbone arc is then realized as an arc between one
    sampled local node of the tail macro node and one sampled local node of
    the head macro node.
    """

    def __init__(
        self,
        dag_generator,
        num_macro_nodes,
        num_micro_nodes,
        max_num_local_nodes=4,
        rng=None,
        *,
        min_num_local_nodes=2,
        target_num_confounder=2,
    ):
        """
        :param dag_generator: object whose ``gen_dag`` method builds both the
            backbone DAG and every local DAG
        :param num_macro_nodes: number of nodes of the backbone DAG
        :param num_micro_nodes: number of nodes of every local DAG; if None,
            each local DAG size is drawn at random from
            [min_num_local_nodes, max_num_local_nodes] (both inclusive)
        :param max_num_local_nodes: inclusive upper bound of the random local
            DAG size (only used when num_micro_nodes is None)
        :param rng: random source, normalized through ``coerce_rng``
        :param min_num_local_nodes: inclusive lower bound of the random local
            DAG size (only used when num_micro_nodes is None)
        :param target_num_confounder: forwarded as ``target_num_confounder``
            to every ``gen_dag`` call (backbone and local DAGs)
        """
        rng = coerce_rng(rng)
        self.dag_generator = dag_generator
        self.num_macro_nodes = num_macro_nodes
        self.num_micro_nodes = num_micro_nodes
        self.max_num_local_nodes = max_num_local_nodes
        self.min_num_local_nodes = min_num_local_nodes
        self.target_num_confounder = target_num_confounder

        self.global_dag_indexer = None
        self.dag_backbone = None
        # macro node name -> local DAG generated for that macro node
        self.dict_macro_node2dag = {}
        # NOTE(review): dag_refined is never assigned inside this class;
        # presumably DAGStackIndexer(self) sets it -- TODO confirm.
        self.dag_refined = None
        self.rng = rng
        self.dag_manipulator = None

    def _num_local_nodes(self):
        """
        Return the number of nodes for one local DAG.

        :raises ValueError: if the size has to be drawn at random but
            max_num_local_nodes < min_num_local_nodes
        """
        if self.num_micro_nodes is not None:
            return self.num_micro_nodes
        if self.max_num_local_nodes < self.min_num_local_nodes:
            raise ValueError(
                f"max_num_local_nodes ({self.max_num_local_nodes}) must be "
                f">= min_num_local_nodes ({self.min_num_local_nodes})"
            )
        # assumes rng is a numpy Generator, whose integers() excludes `high`
        # by default; the +1 makes max_num_local_nodes reachable
        return self.rng.integers(
            self.min_num_local_nodes, self.max_num_local_nodes + 1
        )

    def populate_macro_node(self):
        """
        Replace every macro node of the backbone by a freshly generated local
        DAG (named with the macro node name as prefix), then build the global
        indexer over all local DAGs.
        """
        # start from an empty mapping so that a repeated run() cannot keep
        # local DAGs of macro nodes that are no longer in the backbone
        self.dict_macro_node2dag = {}
        for name in self.dag_backbone.list_node_names:
            self.dict_macro_node2dag[name] = self.dag_generator.gen_dag(
                num_nodes=self._num_local_nodes(),
                prefix=name,
                target_num_confounder=self.target_num_confounder,
            )
        self.global_dag_indexer = DAGStackIndexer(self)

    def interconnection(self):
        """
        Connect macro nodes with edges: every arc of the backbone DAG is
        realized in the refined DAG via local nodes.
        """
        for arc in self.dag_backbone.list_arcs:
            self.connect_macro_node_via_local_node(arc)

    def connect_macro_node_via_local_node(self, arc):
        """
        Realize the backbone arc (tail, head) as one arc in the refined DAG.

        One local node is sampled from the tail macro node's local DAG and
        one from the head macro node's local DAG; their local indices are
        translated to global indices and the arc is added between them.

        :param arc: pair (tail macro node name, head macro node name)
        """
        macro_arrow_tail, macro_arrow_head = arc
        _, ind_local_tail = self.dict_macro_node2dag[macro_arrow_tail].sample_node()
        _, ind_local_head = self.dict_macro_node2dag[macro_arrow_head].sample_node()

        ind_macro_tail = self.dag_backbone.get_node_ind(macro_arrow_tail)
        ind_macro_head = self.dag_backbone.get_node_ind(macro_arrow_head)

        ind_global_tail = self.global_dag_indexer.get_global_ind(
            ind_macro_tail, ind_local_tail
        )
        ind_global_head = self.global_dag_indexer.get_global_ind(
            ind_macro_head, ind_local_head
        )

        self.dag_refined.add_arc_ind(ind_global_tail, ind_global_head)

    def inject_additional_confounder(self):
        """
        Make additional confounders in the refined (global) DAG.

        A DAGManipulator using Wishart-generated weights is applied twice:
        first at the node returned by ``dag_refined.get_top_last()``, then at
        the node returned by ``dag_refined.climb`` from that first node.
        The confounder count after each step is logged at DEBUG level.
        """
        # local import keeps this block self-contained
        import logging

        logger = logging.getLogger(__name__)

        obj_gen_weight = WeightGenWishart(rng=self.rng)
        self.dag_manipulator = DAGManipulator(
            self.dag_refined, obj_gen_weight, self.rng
        )

        # NOTE(review): semantics of get_top_last/climb are defined in the
        # DAG class, not visible here -- TODO confirm which nodes they pick.
        ind_arbitrary = self.dag_refined.get_top_last()
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug(
            "num_confounder after first injection: %s",
            self.dag_refined.num_confounder,
        )
        ind_arbitrary = self.dag_refined.climb(ind_arbitrary)
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug(
            "num_confounder after second injection: %s",
            self.dag_refined.num_confounder,
        )

    def run(self):
        """
        Run the whole generation and return the refined DAG.

        Steps: generate the backbone DAG of macro nodes, populate each macro
        node with a local DAG, realize backbone arcs between local nodes,
        call ``dag_refined.check()`` (presumably a validity check -- TODO
        confirm), then inject additional confounders.
        """
        self.dag_backbone = self.dag_generator.gen_dag(
            self.num_macro_nodes,
            target_num_confounder=self.target_num_confounder,
        )
        self.populate_macro_node()
        self.interconnection()
        self.dag_refined.check()
        self.inject_additional_confounder()
        return self.dag_refined