Coverage for causalspyne/main.py: 92%
51 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
2generate DAG and its marginal DAG
3"""
5from datetime import datetime
7try:
8 from contextlib import chdir
9except Exception:
10 from causalspyne.py3_9_10_compatibility import chdir
12from pathlib import Path
14import matplotlib.pyplot as plt
15from numpy.random import default_rng
17from causalspyne.gen_dag_2level import GenDAG2Level
18from causalspyne.dag_gen import GenDAG
19from causalspyne.dag_viewer import DAGView
20from causalspyne.dag2ancestral import DAG2Ancestral
22from causalspyne.draw_dags import draw_dags_nx
def gen_partially_observed(
    degree=2,
    list_confounder2hide=None,
    size_micro_node_dag=4,
    num_macro_nodes=4,
    num_sample=200,
    output_dir="output/",
    rng=None,
    graphviz=False,
    plot=True,
):
    """
    Sole function acting as the user interface.

    Generate a two-level DAG and write its adjacency to a timestamped CSV in
    ``output_dir``. Then hand it to ``re_hide``, which samples data with some
    confounders hidden and writes the resulting artifacts.

    Args:
        degree: ``degree`` forwarded to the micro-level ``GenDAG`` generator.
        list_confounder2hide: forwarded to ``DAGView.run`` as
            ``list_nodes2hide``. ``None`` (the default) means ``[0.5, 0.9]``.
            NOTE(review): these look like relative positions or fractions
            rather than node indices; confirm against ``DAGView.run``.
        size_micro_node_dag: ``num_nodes`` of each micro-level DAG.
        num_macro_nodes: number of macro nodes for ``GenDAG2Level``.
        num_sample: number of samples drawn by the sub-view.
        output_dir: directory for all outputs; created if missing.
        rng: numpy ``Generator`` shared by all generators. ``None`` creates a
            fresh ``default_rng()`` per call. The old default was created once
            at import time and silently shared across calls.
        graphviz: forwarded to the visualization helpers.
        plot: whether to draw and save the comparison figure.

    Returns:
        The ``DAGView`` instance returned by ``re_hide``.
    """
    if list_confounder2hide is None:
        # Fresh list per call: avoids a shared mutable default argument.
        list_confounder2hide = [0.5, 0.9]
    if rng is None:
        rng = default_rng()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    simple_dag_gen = GenDAG(num_nodes=size_micro_node_dag,
                            degree=degree, rng=rng)

    # num_macro_nodes is passed explicitly. Per the original note it
    # overrides the generator's own behavior; confirm in GenDAG2Level.
    dag_gen = GenDAG2Level(
        dag_generator=simple_dag_gen,
        num_macro_nodes=num_macro_nodes, rng=rng
    )
    dag = dag_gen.run()

    # Path join works whether or not output_dir ends with a separator.
    dag.to_binary_csv(
        benchpress=False,
        name=str(out_path / f"ground_truth_dag_{timestamp}d.csv"),
    )

    subview = DAGView(dag=dag, rng=rng)
    # Forward the caller's `plot` choice; it was previously hard-coded True.
    return re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
                   graphviz, timestamp, plot=plot)
def ordered_ind_col2global_ind(inds_cols, subview_global_inds):
    """
    Map a causal order given as column indices onto ground-truth DAG indices.

    Args:
        inds_cols: iterable of column indices in predicted causal order.
        subview_global_inds: indexable mapping from column index to the
            global node index in the ground-truth DAG.

    Returns:
        A list of global indices, in the same order as ``inds_cols``.
    """
    global_order = []
    for col in inds_cols:
        global_order.append(subview_global_inds[col])
    return global_order
def _plot_comparison(subview, dag, pred_ancestral_graph_mat, str_node2hide,
                     graphviz):
    """
    Draw the full DAG, the ancestral graph and the observed sub-DAG side by
    side, and return the resulting matplotlib figure.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
    fig.suptitle("hide_" + str_node2hide)  # super-title

    dag.visualize(title="DAG", ax=ax1, graphviz=graphviz)
    ax1.set_title("DAG")

    # NOTE(review): assumes rows/cols of pred_ancestral_graph_mat follow the
    # sorted order of the observed node names; confirm in DAG2Ancestral.
    draw_dags_nx(
        pred_ancestral_graph_mat,
        dict_ind2name=dict(enumerate(sorted(subview.node_names))),
        title="ancestral",
        ax=ax2,
        graphviz=graphviz,
    )
    ax2.set_title("ancestral")

    subview.visualize(title="subDAG", ax=ax3, graphviz=graphviz)
    ax3.set_title("subDAG")
    return fig


def re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
            graphviz, timestamp, plot=True):
    """
    Sample data from ``subview`` with confounders hidden, compute the
    ancestral graph of ``dag`` over the hidden nodes, and write the outputs.

    Files written inside ``output_dir`` (via ``chdir``):
      - whatever ``subview.to_csv()`` writes;
      - ``graph_compare_{timestamp}dags.pdf`` and ``.svg`` when ``plot``;
      - ``hidden_nodes.csv``: comma-separated global indices of the
        unobserved nodes.

    Args:
        subview: ``DAGView`` over ``dag``.
        dag: the ground-truth DAG (must expose ``mat_adjacency`` and
            ``visualize``).
        num_sample: number of samples passed to ``subview.run``.
        list_confounder2hide: forwarded to ``subview.run`` as
            ``list_nodes2hide``.
        output_dir: directory receiving all output files.
        graphviz: forwarded to the visualization helpers.
        timestamp: string prefix embedded in the figure file names.
        plot: whether to draw and save the comparison figure.

    Returns:
        ``subview``, after ``run`` has been called on it.
    """
    subview.run(
        num_samples=num_sample, confound=True,
        list_nodes2hide=list_confounder2hide
    )
    with chdir(output_dir):
        # Sole CSV export. The original repeated this call later on the same
        # unchanged subview in the same directory, which was redundant.
        subview.to_csv()
        str_node2hide = subview.str_node2hide

    dag2ancestral = DAG2Ancestral(dag.mat_adjacency)
    pred_ancestral_graph_mat = dag2ancestral.run(
        subview.list_global_inds_nodes2hide)

    fig = None
    if plot:
        fig = _plot_comparison(subview, dag, pred_ancestral_graph_mat,
                               str_node2hide, graphviz)

    with chdir(output_dir):
        if fig is not None:
            fig.savefig(f"graph_compare_{timestamp}dags.pdf", format="pdf")
            fig.savefig(f"graph_compare_{timestamp}dags.svg", format="svg")
        with open("hidden_nodes.csv", "w", encoding="utf-8") as outfile:
            outfile.write(
                ",".join(str(node) for node in
                         subview._list_global_inds_unobserved)
            )
    return subview