Coverage for causalspyne/gen_dag_2level.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

22-level DAG generation 

3""" 

4 

5from numpy.random import default_rng 

6 

7from causalspyne.dag_stack_indexer import DAGStackIndexer 

8from causalspyne.dag_manipulator import DAGManipulator 

9from causalspyne.weight import WeightGenWishart 

10 

11 

class GenDAG2Level:
    """
    Generate a DAG with 2 levels: the first level generates a backbone DAG of
    macro nodes, the second level populates each macro node with its own
    local DAG, and the backbone arcs are then realised as arcs between
    sampled local nodes.
    """

    def __init__(
        self, dag_generator, num_macro_nodes, max_num_local_nodes=4, rng=None
    ):
        """
        Args:
            dag_generator: object exposing ``gen_dag(num_nodes, prefix=...)``;
                used both for the macro-level backbone and for every local DAG.
            num_macro_nodes: number of nodes in the macro-level backbone DAG.
            max_num_local_nodes: inclusive upper bound on the number of local
                nodes drawn for each macro node (the lower bound is 2, so this
                must be at least 2 once any macro node is populated).
            rng: numpy ``Generator`` used for all random draws. When omitted,
                a fresh unseeded generator is created for this instance.
        """
        self.dag_generator = dag_generator
        self.num_macro_nodes = num_macro_nodes
        self.max_num_local_nodes = max_num_local_nodes

        self.global_dag_indexer = None
        self.dag_backbone = None
        # macro-node name -> local DAG generated for that macro node
        self.dict_macro_node2dag = {}
        # NOTE(review): never assigned in this class; presumably set as a side
        # effect of DAGStackIndexer(self) in populate_macro_node — confirm.
        self.dag_refined = None
        # Resolve the default here: a ``rng=default_rng()`` default would be
        # evaluated once at class definition and shared by every instance.
        self.rng = default_rng() if rng is None else rng
        self.dag_manipulator = None

    def populate_macro_node(self):
        """
        Replace every macro node of the backbone with a local DAG.

        For each backbone node name, draw a node count uniformly from
        ``[2, max_num_local_nodes]`` and generate a local DAG whose node names
        are prefixed with the macro node's name. Afterwards build the
        global indexer over all local DAGs.

        Raises:
            ValueError: from ``rng.integers`` when ``max_num_local_nodes < 2``
                and the backbone has at least one node.
        """
        for name in self.dag_backbone.list_node_names:
            # integers() excludes the upper bound, hence the +1
            num_nodes = self.rng.integers(2, self.max_num_local_nodes + 1)
            self.dict_macro_node2dag[name] = self.dag_generator.gen_dag(
                num_nodes=num_nodes, prefix=name
            )
        self.global_dag_indexer = DAGStackIndexer(self)

    def interconnection(self):
        """
        Realise every arc of the macro-level backbone as an arc between
        local nodes of the refined DAG.
        """
        for arc in self.dag_backbone.list_arcs:
            self.connect_macro_node_via_local_node(arc)

    def connect_macro_node_via_local_node(self, arc):
        """
        Connect the macro-DAG arc (tail, head) via one local node on each side.

        One local node is sampled from the tail macro node's DAG and one from
        the head macro node's DAG. Their (macro index, local index) pairs are
        mapped to global indices and an arc tail -> head is added to
        ``dag_refined``.

        Args:
            arc: pair ``(macro_arrow_tail, macro_arrow_head)`` of backbone node
                names.
        """
        macro_arrow_tail, macro_arrow_head = arc
        # sample_node() returns a pair; only the local index is needed here
        _, ind_local_tail = self.dict_macro_node2dag[macro_arrow_tail].sample_node()
        _, ind_local_head = self.dict_macro_node2dag[macro_arrow_head].sample_node()

        ind_macro_tail = self.dag_backbone.get_node_ind(macro_arrow_tail)
        ind_macro_head = self.dag_backbone.get_node_ind(macro_arrow_head)

        ind_global_tail = self.global_dag_indexer.get_global_ind(
            ind_macro_tail, ind_local_tail
        )
        ind_global_head = self.global_dag_indexer.get_global_ind(
            ind_macro_head, ind_local_head
        )

        self.dag_refined.add_arc_ind(ind_global_tail, ind_global_head)

    def inject_additional_confounder(self):
        """
        Add confounders to the refined DAG.

        ``mk_confound`` is applied twice: first to the node returned by
        ``dag_refined.get_top_last()``, then to the node obtained by
        ``dag_refined.climb`` from it. The confounder count is logged at
        DEBUG level after each step.
        """
        # Local import keeps the module's import block untouched.
        import logging

        logger = logging.getLogger(__name__)

        obj_gen_weight = WeightGenWishart(rng=self.rng)
        self.dag_manipulator = DAGManipulator(
            self.dag_refined, obj_gen_weight, self.rng
        )

        ind_arbitrary = self.dag_refined.get_top_last()
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug("num_confounder: %s", self.dag_refined.num_confounder)
        ind_arbitrary = self.dag_refined.climb(ind_arbitrary)
        self.dag_manipulator.mk_confound(ind_arbitrary)
        logger.debug("num_confounder: %s", self.dag_refined.num_confounder)

    def run(self):
        """
        Run the full two-level generation pipeline.

        The steps are: generate the macro backbone, populate each macro node,
        interconnect them, check the refined DAG, and inject additional
        confounders.

        Returns:
            The refined DAG (``self.dag_refined``).
        """
        self.dag_backbone = self.dag_generator.gen_dag(self.num_macro_nodes)
        self.populate_macro_node()
        self.interconnection()
        self.dag_refined.check()
        self.inject_additional_confounder()
        return self.dag_refined