Coverage for src/causalspyne/gen_dag_2level.py: 100%

52 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

22-level DAG generation 

3""" 

4 

5from causalspyne.dag_stack_indexer import DAGStackIndexer 

6from causalspyne.dag_manipulator import DAGManipulator 

7from causalspyne.weight import WeightGenWishart 

8from causalspyne.utils_random import coerce_rng 

9 

10 

class GenDAG2Level:
    """
    Generate a DAG with 2 levels.

    The first level generates a backbone DAG whose nodes are "macro" nodes.
    The second level populates each macro node with its own local DAG.
    Local DAGs are then connected along the backbone arcs.
    """

    def __init__(
        self,
        dag_generator,
        num_macro_nodes,
        num_micro_nodes,
        max_num_local_nodes=4,
        rng=None,
    ):
        """
        Args:
            dag_generator: object whose ``gen_dag(...)`` method is used both
                for the backbone DAG and for every macro node's local DAG.
            num_macro_nodes: number of nodes in the backbone (macro) DAG.
            num_micro_nodes: number of nodes in every local DAG. If ``None``,
                each macro node draws its own size with
                ``rng.integers(2, max_num_local_nodes + 1)`` (assuming
                numpy-Generator semantics: 2..max_num_local_nodes inclusive).
            max_num_local_nodes: upper bound for the random local DAG size;
                only used when ``num_micro_nodes`` is ``None``.
            rng: anything accepted by ``coerce_rng`` (presumably ``None``, a
                seed, or a generator -- confirm in utils_random).
        """
        self.rng = coerce_rng(rng)
        self.dag_generator = dag_generator
        self.num_macro_nodes = num_macro_nodes
        self.num_micro_nodes = num_micro_nodes
        self.max_num_local_nodes = max_num_local_nodes

        # State filled in by run() and its helpers.
        self.global_dag_indexer = None
        self.dag_backbone = None
        self.dict_macro_node2dag = {}
        # NOTE(review): no method of this class assigns dag_refined; it is
        # presumably set by DAGStackIndexer(self) in populate_macro_node,
        # which receives this object -- confirm in dag_stack_indexer.py.
        self.dag_refined = None
        self.dag_manipulator = None

    def populate_macro_node(self):
        """
        Replace every macro node of the backbone with a generated local DAG.

        Local DAGs are stored in ``dict_macro_node2dag``, keyed by macro node
        name. A ``DAGStackIndexer`` is then built over this object and stored
        in ``global_dag_indexer``.

        Raises:
            ValueError: if local DAG sizes must be drawn at random
                (``num_micro_nodes is None``) but ``max_num_local_nodes < 2``,
                which leaves no valid size to draw.
        """
        if self.num_micro_nodes is None and self.max_num_local_nodes < 2:
            raise ValueError(
                "max_num_local_nodes must be >= 2 when num_micro_nodes is "
                f"None, got {self.max_num_local_nodes}"
            )
        # Start from an empty mapping. A repeated run() must not keep local
        # DAGs of macro nodes that are no longer in the current backbone.
        self.dict_macro_node2dag = {}
        for name in self.dag_backbone.list_node_names:
            num_nodes = self.num_micro_nodes
            if num_nodes is None:
                num_nodes = self.rng.integers(2, self.max_num_local_nodes + 1)
            self.dict_macro_node2dag[name] = self.dag_generator.gen_dag(
                num_nodes=num_nodes,
                prefix=name,
                target_num_confounder=2,
            )
        self.global_dag_indexer = DAGStackIndexer(self)

    def interconnection(self):
        """
        Connect the local DAGs along every arc of the backbone DAG.
        """
        for arc in self.dag_backbone.list_arcs:
            self.connect_macro_node_via_local_node(arc)

    def connect_macro_node_via_local_node(self, arc):
        """
        Realize one backbone arc as an arc of the refined DAG.

        One node is sampled from the tail macro node's local DAG, then one
        from the head macro node's local DAG. Both local indices are
        translated to global indices via ``global_dag_indexer``. The arc
        tail -> head is then added to ``dag_refined``.

        Args:
            arc: pair ``(macro_tail_name, macro_head_name)`` taken from the
                backbone's ``list_arcs``.
        """
        macro_arrow_tail, macro_arrow_head = arc
        # Sample tail before head: sample_node may consume randomness, so
        # this order is kept stable for reproducibility.
        _, ind_local_tail = self.dict_macro_node2dag[macro_arrow_tail].sample_node()
        _, ind_local_head = self.dict_macro_node2dag[macro_arrow_head].sample_node()

        ind_macro_tail = self.dag_backbone.get_node_ind(macro_arrow_tail)
        ind_macro_head = self.dag_backbone.get_node_ind(macro_arrow_head)

        ind_global_tail = self.global_dag_indexer.get_global_ind(
            ind_macro_tail, ind_local_tail
        )
        ind_global_head = self.global_dag_indexer.get_global_ind(
            ind_macro_head, ind_local_head
        )

        self.dag_refined.add_arc_ind(ind_global_tail, ind_global_head)

    def inject_additional_confounder(self):
        """
        Make confounders in the refined (global) DAG.

        A ``DAGManipulator`` backed by a ``WeightGenWishart`` weight
        generator is created and stored in ``dag_manipulator``.
        ``mk_confound`` is first applied to the node returned by
        ``dag_refined.get_top_last()``. It is then applied to the node that
        ``dag_refined.climb`` returns for that first node.
        ``dag_refined.num_confounder`` is logged at DEBUG level after each
        step (previously it was printed to stdout).
        """
        # Local import keeps the module's import block unchanged.
        import logging

        logger = logging.getLogger(__name__)

        obj_gen_weight = WeightGenWishart(rng=self.rng)
        self.dag_manipulator = DAGManipulator(
            self.dag_refined, obj_gen_weight, self.rng
        )

        ind_arbitrary = self.dag_refined.get_top_last()
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug("num_confounder=%s", self.dag_refined.num_confounder)

        ind_arbitrary = self.dag_refined.climb(ind_arbitrary)
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug("num_confounder=%s", self.dag_refined.num_confounder)

    def run(self):
        """
        Generate the 2-level DAG.

        The steps are:
        1. Generate the backbone DAG of macro nodes.
        2. Populate each macro node with a local DAG.
        3. Connect the local DAGs along the backbone arcs.
        4. Call ``dag_refined.check()``.
        5. Inject the additional confounders.

        Returns:
            The refined (global) DAG, i.e. ``self.dag_refined``.
        """
        self.dag_backbone = self.dag_generator.gen_dag(
            self.num_macro_nodes, target_num_confounder=2
        )
        self.populate_macro_node()
        self.interconnection()
        self.dag_refined.check()
        self.inject_additional_confounder()
        return self.dag_refined