Coverage for src/causalspyne/main.py: 91%

54 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2generate DAG and its marginal DAG 

3""" 

4 

5from datetime import datetime 

6try: 

7 from contextlib import chdir 

8except Exception: 

9 from causalspyne.py3_9_10_compatibility import chdir 

10 

11from pathlib import Path 

12 

13import matplotlib.pyplot as plt 

14 

15from causalspyne.gen_dag_2level import GenDAG2Level 

16from causalspyne.dag_gen import GenDAG 

17from causalspyne.dag_viewer import DAGView 

18from causalspyne.dag2ancestral import DAG2Ancestral 

19 

20from causalspyne.draw_dags import draw_dags_nx 

21from causalspyne.utils_random import coerce_rng 

22 

23 

def gen_partially_observed(
    degree=2,
    list_confounder2hide=None,
    size_micro_node_dag=4,
    max_num_local_nodes=4,
    num_macro_nodes=4,
    num_sample=200,
    output_dir="output/",
    rng=None,
    dft_noise="Gaussian",
    graphviz=False,
    plot=True,
):
    """
    Generate a two-level random DAG, hide some nodes and write the results.

    This is the single user-facing entry point of the module. It builds a
    micro-level ``GenDAG`` generator, uses it inside ``GenDAG2Level`` to
    produce the ground-truth DAG, stores that DAG as a binary CSV in
    ``output_dir``, wraps it in a ``DAGView`` and delegates sampling,
    hiding, plotting and file output to ``re_hide``.

    Args:
        degree: ``degree`` argument of the micro-level ``GenDAG``.
        list_confounder2hide: hiding specification forwarded to ``re_hide``
            (and from there to ``DAGView.run``); defaults to ``[0.5, 0.9]``.
        size_micro_node_dag: used both as ``num_nodes`` of ``GenDAG`` and as
            ``num_micro_nodes`` of ``GenDAG2Level``.
        max_num_local_nodes: forwarded to ``GenDAG2Level``.
        num_macro_nodes: forwarded to ``GenDAG2Level``.
        num_sample: number of samples drawn by ``re_hide``.
        output_dir: output directory; created (with parents) if missing.
        rng: anything accepted by ``coerce_rng``; the coerced generator is
            shared by every component so one seed controls the whole run.
        dft_noise: ``dft_noise`` argument of ``DAGView``.
        graphviz: forwarded to the drawing routines in ``re_hide``.
        plot: whether ``re_hide`` draws and saves the comparison figure.

    Returns:
        Whatever ``re_hide`` returns (the ``DAGView`` after sampling).
    """
    if list_confounder2hide is None:
        hide_spec = [0.5, 0.9]
    else:
        hide_spec = list_confounder2hide

    shared_rng = coerce_rng(rng)
    # Trailing underscore lets the stamp be used directly as a file prefix.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    micro_generator = GenDAG(
        num_nodes=size_micro_node_dag, degree=degree, rng=shared_rng
    )
    two_level_generator = GenDAG2Level(
        dag_generator=micro_generator,
        num_macro_nodes=num_macro_nodes,
        num_micro_nodes=size_micro_node_dag,
        max_num_local_nodes=max_num_local_nodes,
        rng=shared_rng,
    )
    ground_truth = two_level_generator.run()

    csv_path = out_path / f"ground_truth_dag_{stamp}d.csv"
    ground_truth.to_binary_csv(benchpress=False, name=csv_path)

    view = DAGView(dag=ground_truth, rng=shared_rng, dft_noise=dft_noise)
    return re_hide(
        view, ground_truth, num_sample, hide_spec, out_path,
        graphviz, stamp, plot=plot,
    )

64 

65 

def ordered_ind_col2global_ind(inds_cols, subview_global_inds):
    """
    Translate a causal order given as column indices into global indices.

    Each element of ``inds_cols`` is a column index into the observed
    sub-view; it is looked up in ``subview_global_inds``, which maps a column
    index to the index of that node in the ground-truth DAG.

    Args:
        inds_cols: iterable of column indices in predicted causal order.
        subview_global_inds: indexable mapping column index -> global index.

    Returns:
        A new list of global indices, in the same order as ``inds_cols``.
    """
    global_order = []
    for col in inds_cols:
        global_order.append(subview_global_inds[col])
    return global_order

73 

74 

def re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
            graphviz, timestamp, plot=True):
    """
    Hide confounders in ``subview``, build the ancestral graph, save results.

    Steps:
      1. ``subview.run`` is called with ``num_samples=num_sample``,
         ``confound=True`` and ``list_nodes2hide=list_confounder2hide``.
      2. ``DAG2Ancestral`` is built from ``dag.mat_adjacency`` and run with
         ``subview.list_global_inds_nodes2hide`` to obtain the ancestral
         graph matrix.
      3. If ``plot`` is true, a 1x3 figure is drawn: full DAG, ancestral
         graph, and observed sub-DAG.
      4. Inside ``output_dir``:
         - ``subview.to_csv()`` is called.
         - The figure, if any, is saved as
           ``graph_compare_<timestamp>dags.pdf`` and ``.svg``.
         - ``subview._list_global_inds_unobserved`` is written comma-separated
           to ``hidden_nodes.csv``.

    Args:
        subview: view object wrapping ``dag`` (a ``DAGView`` when called from
            ``gen_partially_observed``); its ``run`` method is invoked here.
        dag: ground-truth DAG; must expose ``mat_adjacency`` and
            ``visualize``.
        num_sample: number of samples passed to ``subview.run``.
        list_confounder2hide: hiding specification forwarded to
            ``subview.run`` as ``list_nodes2hide``.
        output_dir: directory receiving all output files; created (with
            parents) if it does not exist.
        graphviz: forwarded to every drawing routine.
        timestamp: string embedded in the figure file names.
        plot: whether to draw and save the comparison figure.

    Returns:
        ``subview`` after ``run`` has been called on it.
    """
    subview.run(
        num_samples=num_sample, confound=True,
        list_nodes2hide=list_confounder2hide,
    )
    str_node2hide = subview.str_node2hide

    dag2ancestral = DAG2Ancestral(dag.mat_adjacency)
    list_confounder2hide_global_ind = subview.list_global_inds_nodes2hide
    pred_ancestral_graph_mat = dag2ancestral.run(
        list_confounder2hide_global_ind)

    # chdir() below fails on a missing directory; create it here so the
    # function also works when called directly, not only through
    # gen_partially_observed (which creates it itself).
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    fig = None
    # try/finally releases the figure from pyplot's global registry even
    # when drawing or writing raises; otherwise figures pile up across
    # repeated calls.
    try:
        if plot:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
            fig.suptitle("hide_" + str_node2hide)  # super-title

            # ax1: full ground-truth DAG
            dag.visualize(title="DAG", ax=ax1, graphviz=graphviz)
            ax1.set_title("DAG")

            # ax2: ancestral graph after hiding
            # NOTE(review): index i of the ancestral matrix is labelled with
            # the i-th entry of sorted(subview.node_names); this assumes the
            # lexicographic name order matches the matrix order — confirm.
            draw_dags_nx(
                pred_ancestral_graph_mat,
                dict_ind2name=dict(enumerate(sorted(subview.node_names))),
                title="ancestral",
                ax=ax2,
                graphviz=graphviz,
            )
            ax2.set_title("ancestral")

            # ax3: observed sub-DAG
            subview.visualize(title="subDAG", ax=ax3, graphviz=graphviz)
            ax3.set_title("subDAG")

        with chdir(output_dir):
            subview.to_csv()
            if fig is not None:
                fig.savefig(f"graph_compare_{timestamp}dags.pdf", format="pdf")
                fig.savefig(f"graph_compare_{timestamp}dags.svg", format="svg")
            with open("hidden_nodes.csv", "w", encoding="utf-8") as outfile:
                outfile.write(",".join(
                    str(node)
                    for node in subview._list_global_inds_unobserved
                ))
    finally:
        if fig is not None:
            plt.close(fig)
    return subview