Coverage for causalspyne/main.py: 92%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2generate DAG and its marginal DAG 

3""" 

4 

5from datetime import datetime 

6 

7try: 

8 from contextlib import chdir 

9except Exception: 

10 from causalspyne.py3_9_10_compatibility import chdir 

11 

12from pathlib import Path 

13 

14import matplotlib.pyplot as plt 

15from numpy.random import default_rng 

16 

17from causalspyne.gen_dag_2level import GenDAG2Level 

18from causalspyne.dag_gen import GenDAG 

19from causalspyne.dag_viewer import DAGView 

20from causalspyne.dag2ancestral import DAG2Ancestral 

21 

22from causalspyne.draw_dags import draw_dags_nx 

23 

24 

def gen_partially_observed(
    degree=2,
    list_confounder2hide=None,
    size_micro_node_dag=4,
    num_macro_nodes=4,
    num_sample=200,
    output_dir="output/",
    rng=None,
    graphviz=False,
    plot=True,
):
    """
    Generate a two-level DAG, hide part of it, and write the results to disk.

    This is the single user-facing entry point of the module. It builds a
    ground-truth DAG with ``GenDAG2Level`` (using ``GenDAG`` as the
    micro-level generator), saves that DAG as a binary CSV in ``output_dir``,
    then delegates sampling, node hiding, ancestral-graph computation and
    optional plotting to ``re_hide``.

    Args:
        degree: ``degree`` argument of the micro-level ``GenDAG`` generator.
        list_confounder2hide: forwarded to ``DAGView.run`` as
            ``list_nodes2hide``. Defaults to ``[0.5, 0.9]`` (a fresh list per
            call). NOTE(review): the floats presumably select confounders by
            relative position -- confirm against ``DAGView``.
        size_micro_node_dag: ``num_nodes`` of each micro-level ``GenDAG``.
        num_macro_nodes: number of macro nodes passed to ``GenDAG2Level``.
        num_sample: number of samples drawn by ``DAGView.run``.
        output_dir: directory receiving every output file; created if absent.
        rng: numpy random ``Generator`` shared by all generators. When None,
            a new ``default_rng()`` is created for this call.
        graphviz: forwarded to the visualization helpers.
        plot: whether to draw and save the DAG / ancestral / sub-DAG figure.

    Returns:
        The ``DAGView`` instance returned by ``re_hide``.
    """
    # Defaults are resolved per call: a literal list or an rng in the
    # signature would be created once at import and shared by every call.
    if list_confounder2hide is None:
        list_confounder2hide = [0.5, 0.9]
    if rng is None:
        rng = default_rng()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    simple_dag_gen = GenDAG(num_nodes=size_micro_node_dag,
                            degree=degree, rng=rng)

    # num_macro_nodes will overwrite behavior
    dag_gen = GenDAG2Level(
        dag_generator=simple_dag_gen,
        num_macro_nodes=num_macro_nodes, rng=rng
    )
    dag = dag_gen.run()
    # Path join works whether or not output_dir ends with a separator.
    dag.to_binary_csv(
        benchpress=False,
        name=str(out_path / f"ground_truth_dag_{timestamp}d.csv"),
    )

    subview = DAGView(dag=dag, rng=rng)
    # Forward the caller's ``plot`` flag (it was previously hard-coded True).
    return re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
                   graphviz, timestamp, plot=plot)

57 

58 

def ordered_ind_col2global_ind(inds_cols, subview_global_inds):
    """
    Translate a causal order from subview column indices to global indices.

    Args:
        inds_cols: iterable of column indices (in predicted causal order)
            referring to the observed subview's data columns.
        subview_global_inds: indexable mapping from a subview column index
            to the corresponding node index in the ground-truth DAG.

    Returns:
        A list with the global node index of each entry of ``inds_cols``,
        preserving the input order.
    """
    global_order = []
    for col in inds_cols:
        global_order.append(subview_global_inds[col])
    return global_order

66 

67 

def _plot_graph_comparison(subview, dag, pred_ancestral_graph_mat,
                           str_node2hide, graphviz):
    """
    Draw the DAG, its ancestral graph and the observed sub-DAG side by side.

    Args:
        subview: the ``DAGView`` whose ``node_names`` label the ancestral
            graph and which draws itself in the third panel.
        dag: the ground-truth DAG, drawn in the first panel.
        pred_ancestral_graph_mat: matrix returned by ``DAG2Ancestral.run``.
        str_node2hide: text appended to the ``"hide_"`` super-title.
        graphviz: forwarded to every drawing call.

    Returns:
        The created matplotlib figure. The caller is responsible for
        saving and closing it.
    """
    fig, (ax_dag, ax_ancestral, ax_sub) = plt.subplots(1, 3)
    fig.suptitle("hide_" + str_node2hide)  # super-title

    dag.visualize(title="DAG", ax=ax_dag, graphviz=graphviz)
    ax_dag.set_title("DAG")

    # NOTE(review): assumes the ancestral matrix rows/cols follow the
    # sorted order of subview.node_names -- confirm against DAG2Ancestral.
    draw_dags_nx(
        pred_ancestral_graph_mat,
        dict_ind2name=dict(enumerate(sorted(subview.node_names))),
        title="ancestral",
        ax=ax_ancestral,
        graphviz=graphviz,
    )
    ax_ancestral.set_title("ancestral")

    subview.visualize(title="subDAG", ax=ax_sub, graphviz=graphviz)
    ax_sub.set_title("subDAG")
    return fig


def re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
            graphviz, timestamp, plot=True):
    """
    Sample a subview with hidden nodes and save data, plots and hidden nodes.

    Runs ``subview`` with confounding and the requested nodes hidden, then
    writes the subview CSV output into ``output_dir``. It computes the
    ancestral graph of ``dag`` marginalized over the hidden nodes and, when
    ``plot`` is true, saves a three-panel comparison figure as PDF and SVG.
    The global indices of the unobserved nodes are written to
    ``hidden_nodes.csv``.

    Args:
        subview: a ``DAGView`` over ``dag``.
        dag: the ground-truth DAG; its ``mat_adjacency`` feeds
            ``DAG2Ancestral``.
        num_sample: number of samples drawn by ``subview.run``.
        list_confounder2hide: forwarded as ``list_nodes2hide``.
        output_dir: directory in which all files are written.
        graphviz: forwarded to the drawing helpers.
        timestamp: string embedded in the figure file names.
        plot: whether to draw and save the comparison figure.

    Returns:
        ``subview``, after it has been run.
    """
    subview.run(
        num_samples=num_sample, confound=True,
        list_nodes2hide=list_confounder2hide
    )
    with chdir(output_dir):
        subview.to_csv()
    str_node2hide = subview.str_node2hide

    dag2ancestral = DAG2Ancestral(dag.mat_adjacency)
    pred_ancestral_graph_mat = dag2ancestral.run(
        subview.list_global_inds_nodes2hide)

    fig = None
    if plot:
        fig = _plot_graph_comparison(subview, dag, pred_ancestral_graph_mat,
                                     str_node2hide, graphviz)

    # The subview CSV was already written above; writing it again here (as
    # the previous version did) only rewrote the same output.
    with chdir(output_dir):
        if fig is not None:
            try:
                fig.savefig(f"graph_compare_{timestamp}dags.pdf", format="pdf")
                fig.savefig(f"graph_compare_{timestamp}dags.svg", format="svg")
            finally:
                # pyplot keeps every figure alive until closed; without this
                # repeated calls accumulate open figures.
                plt.close(fig)
        with open("hidden_nodes.csv", "w") as outfile:
            outfile.write(
                ",".join(str(node) for node in
                         subview._list_global_inds_unobserved)
            )
    return subview