Coverage for causalspyne/gen_dag_2level.py: 100%
48 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
Two-level DAG generation: a backbone DAG of macro nodes, each macro node then replaced by its own local DAG.
3"""
5from numpy.random import default_rng
7from causalspyne.dag_stack_indexer import DAGStackIndexer
8from causalspyne.dag_manipulator import DAGManipulator
9from causalspyne.weight import WeightGenWishart
class GenDAG2Level:
    """
    Generate a DAG with 2 levels.

    The first level generates a backbone DAG of macro nodes. The second level
    populates each macro node with its own local DAG. The local DAGs are then
    wired together along the backbone arcs, and extra confounders are
    injected into the resulting refined DAG.
    """

    def __init__(
        self, dag_generator, num_macro_nodes, max_num_local_nodes=4, rng=None
    ):
        """
        :param dag_generator: object exposing ``gen_dag(num_nodes, prefix=...)``;
            used both for the macro backbone and for every local DAG
        :param num_macro_nodes: number of nodes in the macro (backbone) DAG
        :param max_num_local_nodes: inclusive upper bound on the number of
            local nodes drawn for each macro node (the lower bound is 2)
        :param rng: numpy random ``Generator``; when omitted, a fresh
            ``default_rng()`` is created for this instance
        """
        self.dag_generator = dag_generator
        self.num_macro_nodes = num_macro_nodes
        self.max_num_local_nodes = max_num_local_nodes

        self.global_dag_indexer = None
        self.dag_backbone = None
        # macro node name -> local DAG generated for that macro node
        self.dict_macro_node2dag = {}
        # NOTE(review): never assigned inside this class; presumably set by
        # DAGStackIndexer(self) in populate_macro_node -- confirm.
        self.dag_refined = None
        # Writing ``rng=default_rng()`` in the signature would evaluate it
        # once at class-definition time and share one generator between all
        # instances; create a fresh one per instance instead.
        self.rng = default_rng() if rng is None else rng
        self.dag_manipulator = None

    def populate_macro_node(self):
        """
        Replace every macro node of the backbone with a local DAG.

        For each macro node name, draw a local size from
        ``[2, max_num_local_nodes]`` (inclusive). Then ask ``dag_generator``
        for a DAG of that size, passing the macro node name as ``prefix``.
        Finally build the global indexer over all local DAGs.

        :raises ValueError: if ``max_num_local_nodes`` is smaller than 2
        """
        # rng.integers(2, high) requires high > 2, i.e. at least 2 local nodes
        if self.max_num_local_nodes < 2:
            raise ValueError(
                "max_num_local_nodes must be at least 2, got "
                f"{self.max_num_local_nodes}"
            )
        # iterate each macro node
        for name in self.dag_backbone.list_node_names:
            num_nodes = self.rng.integers(2, self.max_num_local_nodes + 1)
            self.dict_macro_node2dag[name] = self.dag_generator.gen_dag(
                num_nodes=num_nodes, prefix=name
            )
        self.global_dag_indexer = DAGStackIndexer(self)

    def interconnection(self):
        """
        Connect macro nodes with edges.

        Every arc of the macro backbone becomes one arc between a local node
        of the tail macro node and a local node of the head macro node.
        """
        # iterate over the Macro-DAG edges
        for arc in self.dag_backbone.list_arcs:
            self.connect_macro_node_via_local_node(arc)

    def connect_macro_node_via_local_node(self, arc):
        """
        Realise macro-DAG edge ``(tail, head)`` as an arc between local nodes.

        One local node is sampled from each endpoint's local DAG. Each
        (macro index, local index) pair is mapped to a global index, and the
        arc is added to ``dag_refined``.

        :param arc: ``(macro_tail_name, macro_head_name)`` pair from
            ``dag_backbone.list_arcs``
        """
        macro_arrow_tail, macro_arrow_head = arc
        # sample_node returns a pair; only the second item (local index) is used
        _, ind_local_tail = self.dict_macro_node2dag[macro_arrow_tail].sample_node()
        _, ind_local_head = self.dict_macro_node2dag[macro_arrow_head].sample_node()

        ind_macro_tail = self.dag_backbone.get_node_ind(macro_arrow_tail)
        ind_macro_head = self.dag_backbone.get_node_ind(macro_arrow_head)

        ind_global_tail = self.global_dag_indexer.get_global_ind(
            ind_macro_tail, ind_local_tail
        )
        ind_global_head = self.global_dag_indexer.get_global_ind(
            ind_macro_head, ind_local_head
        )

        self.dag_refined.add_arc_ind(ind_global_tail, ind_global_head)

    def inject_additional_confounder(self):
        """
        Make confounders in the refined (global) DAG.

        A ``DAGManipulator`` with Wishart weight generation is built over
        ``dag_refined``. ``mk_confound`` is then applied twice: first to the
        node returned by ``get_top_last()``, then to the node returned by
        ``climb()`` from it.
        """
        import logging  # local import: keeps the module import block untouched

        logger = logging.getLogger(__name__)

        obj_gen_weight = WeightGenWishart(rng=self.rng)
        self.dag_manipulator = DAGManipulator(self.dag_refined,
                                              obj_gen_weight, self.rng)

        # NOTE(review): semantics of get_top_last/climb live in the DAG class;
        # presumably a last node in topological order and one of its ancestors.
        ind_arbitrary = self.dag_refined.get_top_last()
        self.dag_manipulator.mk_confound(ind_arbitrary)
        # was a leftover debug print(); routed through logging instead
        logger.debug("num_confounder: %s", self.dag_refined.num_confounder)
        ind_arbitrary = self.dag_refined.climb(ind_arbitrary)
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug("num_confounder: %s", self.dag_refined.num_confounder)

    def run(self):
        """
        Run the full two-level generation pipeline.

        :return: the refined DAG after interconnection, ``check()`` and
            confounder injection
        """
        # generate dag_backbone DAG with only macro nodes
        self.dag_backbone = self.dag_generator.gen_dag(self.num_macro_nodes)
        self.populate_macro_node()
        self.interconnection()
        self.dag_refined.check()
        self.inject_additional_confounder()
        return self.dag_refined