Coverage for src/causalspyne/main.py: 91%
54 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
2generate DAG and its marginal DAG
3"""
5from datetime import datetime
6try:
7 from contextlib import chdir
8except Exception:
9 from causalspyne.py3_9_10_compatibility import chdir
11from pathlib import Path
13import matplotlib.pyplot as plt
15from causalspyne.gen_dag_2level import GenDAG2Level
16from causalspyne.dag_gen import GenDAG
17from causalspyne.dag_viewer import DAGView
18from causalspyne.dag2ancestral import DAG2Ancestral
20from causalspyne.draw_dags import draw_dags_nx
21from causalspyne.utils_random import coerce_rng
def gen_partially_observed(
    degree=2,
    list_confounder2hide=None,
    size_micro_node_dag=4,
    max_num_local_nodes=4,
    num_macro_nodes=4,
    num_sample=200,
    output_dir="output/",
    rng=None,
    dft_noise="Gaussian",
    graphviz=False,
    plot=True,
):
    """
    Generate a two-level random DAG, hide part of it, and save the results.

    This is the sole user-facing entry point of this module. A ``GenDAG``
    with ``size_micro_node_dag`` nodes serves as the inner generator of a
    ``GenDAG2Level``. The resulting ground-truth DAG is:

    1. written to ``output_dir`` as a binary CSV;
    2. wrapped in a ``DAGView``;
    3. handed to ``re_hide``, which samples data and writes the remaining
       artefacts.

    Parameters
    ----------
    degree : int, default 2
        Forwarded to ``GenDAG``.
    list_confounder2hide : list, optional
        Forwarded to ``DAGView.run`` (as ``list_nodes2hide``) via
        ``re_hide``. Defaults to a fresh ``[0.5, 0.9]`` on every call.
        NOTE(review): the float values presumably select confounders by
        relative position; confirm in ``DAGView``.
    size_micro_node_dag : int, default 4
        Node count of the inner ``GenDAG``. Also passed as
        ``num_micro_nodes`` to ``GenDAG2Level``.
    max_num_local_nodes : int, default 4
        Forwarded to ``GenDAG2Level``.
    num_macro_nodes : int, default 4
        Forwarded to ``GenDAG2Level``.
    num_sample : int, default 200
        Number of samples drawn inside ``re_hide``.
    output_dir : str or path-like, default "output/"
        Created (with parents) if it does not exist. All files go here.
    rng : optional
        Anything ``coerce_rng`` accepts. The coerced generator is shared by
        both DAG generators and by the ``DAGView``.
    dft_noise : str, default "Gaussian"
        Forwarded to ``DAGView``.
    graphviz : bool, default False
        Forwarded to the drawing routines used by ``re_hide``.
    plot : bool, default True
        Whether ``re_hide`` saves the comparison figure.

    Returns
    -------
    DAGView
        The view returned by ``re_hide``.
    """
    hidden_spec = ([0.5, 0.9] if list_confounder2hide is None
                   else list_confounder2hide)
    rng = coerce_rng(rng)
    # Second-resolution stamp embedded in every file name of this run.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Keep the construction/run order unchanged: all components share `rng`.
    micro_gen = GenDAG(num_nodes=size_micro_node_dag, degree=degree, rng=rng)
    two_level_gen = GenDAG2Level(
        dag_generator=micro_gen,
        num_macro_nodes=num_macro_nodes,
        num_micro_nodes=size_micro_node_dag,
        max_num_local_nodes=max_num_local_nodes,
        rng=rng,
    )
    ground_truth = two_level_gen.run()

    # NOTE(review): the trailing "d" after the stamp looks like a truncated
    # "dag". It is kept as-is because downstream tooling may match this name.
    csv_path = out_path / f"ground_truth_dag_{stamp}d.csv"
    ground_truth.to_binary_csv(benchpress=False, name=csv_path)

    view = DAGView(dag=ground_truth, rng=rng, dft_noise=dft_noise)
    return re_hide(view, ground_truth, num_sample, hidden_spec, out_path,
                   graphviz, stamp, plot=plot)
def ordered_ind_col2global_ind(inds_cols, subview_global_inds):
    """
    Translate a causal order of column indices into global DAG node indices.

    Parameters
    ----------
    inds_cols : iterable of int
        Predicted causal order. Each entry is a column index into the
        sub-view's data.
    subview_global_inds : sequence
        Lookup table: position ``k`` holds the global (ground-truth DAG)
        index of column ``k``.

    Returns
    -------
    list
        The global indices, in the same order as ``inds_cols``.
    """
    global_order = []
    for col in inds_cols:
        global_order.append(subview_global_inds[col])
    return global_order
def _plot_dag_comparison(dag, subview, pred_ancestral_graph_mat,
                         str_node2hide, graphviz):
    """
    Draw the DAG, the ancestral graph and the sub-DAG side by side.

    Parameters
    ----------
    dag :
        Ground-truth DAG; its ``visualize`` method draws the first panel.
    subview :
        The DAG view; supplies ``node_names`` and draws the third panel.
    pred_ancestral_graph_mat :
        Matrix drawn in the middle panel with ``draw_dags_nx``.
    str_node2hide : str
        Appended to ``"hide_"`` to form the figure's super-title.
    graphviz : bool
        Forwarded to every drawing call.

    Returns
    -------
    The created matplotlib figure. The caller owns it and must close it.
    If any drawing step raises, the figure is closed before the exception
    propagates, so it is not left registered in pyplot.
    """
    fig, (ax_dag, ax_anc, ax_sub) = plt.subplots(1, 3)
    try:
        fig.suptitle("hide_" + str_node2hide)  # super-title

        dag.visualize(title="DAG", ax=ax_dag, graphviz=graphviz)
        ax_dag.set_title("DAG")

        # Node i of the ancestral-graph matrix is labelled with the i-th of
        # the subview's node names in sorted order.
        # NOTE(review): this assumes the matrix rows follow that sorted
        # order; confirm against DAG2Ancestral.run.
        draw_dags_nx(
            pred_ancestral_graph_mat,
            dict_ind2name=dict(enumerate(sorted(subview.node_names))),
            title="ancestral",
            ax=ax_anc,
            graphviz=graphviz,
        )
        ax_anc.set_title("ancestral")

        subview.visualize(title="subDAG", ax=ax_sub, graphviz=graphviz)
        ax_sub.set_title("subDAG")
    except BaseException:
        plt.close(fig)
        raise
    return fig


def re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
            graphviz, timestamp, plot=True):
    """
    Sample data with some nodes hidden and write the resulting artefacts.

    Parameters
    ----------
    subview :
        View over ``dag``. ``run`` is called on it with ``confound=True``.
    dag :
        Ground-truth DAG. Its ``mat_adjacency`` is fed to
        ``DAG2Ancestral``, and it is drawn when ``plot`` is true.
    num_sample : int
        Passed to ``subview.run`` as ``num_samples``.
    list_confounder2hide : list
        Passed to ``subview.run`` as ``list_nodes2hide``.
    output_dir : str or path-like
        Directory that all files are written into. The function temporarily
        ``chdir``s into it.
    graphviz : bool
        Forwarded to the drawing routines.
    timestamp : str
        Embedded in the figure file names.
    plot : bool, default True
        When true, a three-panel figure (DAG, ancestral graph, sub-DAG) is
        saved as ``graph_compare_<timestamp>dags.pdf`` and ``.svg``.

    Returns
    -------
    The same ``subview`` object, after ``run`` has been called on it.

    Files written in ``output_dir``:

    - whatever ``subview.to_csv()`` writes;
    - ``hidden_nodes.csv``, containing the entries of
      ``subview._list_global_inds_unobserved`` joined by commas;
    - the figure files, when ``plot`` is true.
    """
    subview.run(
        num_samples=num_sample, confound=True,
        list_nodes2hide=list_confounder2hide,
    )
    str_node2hide = subview.str_node2hide

    dag2ancestral = DAG2Ancestral(dag.mat_adjacency)
    pred_ancestral_graph_mat = dag2ancestral.run(
        subview.list_global_inds_nodes2hide)

    fig = None
    if plot:
        fig = _plot_dag_comparison(dag, subview, pred_ancestral_graph_mat,
                                   str_node2hide, graphviz)

    # Release the figure even if writing any file fails. A pyplot figure
    # stays registered (and keeps its memory) until it is explicitly closed.
    try:
        with chdir(output_dir):
            subview.to_csv()
            if fig is not None:
                fig.savefig(f"graph_compare_{timestamp}dags.pdf", format="pdf")
                fig.savefig(f"graph_compare_{timestamp}dags.svg", format="svg")
            with open("hidden_nodes.csv", "w", encoding="utf-8") as outfile:
                outfile.write(
                    ",".join(str(node) for node in
                             subview._list_global_inds_unobserved)
                )
    finally:
        if fig is not None:
            plt.close(fig)
    return subview