Coverage for src/causalspyne/dag_viewer.py: 84%
100 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
2create different views for the same DAG by hiding some variables
3"""
import os
import warnings

import numpy as np
import pandas as pd

from causalspyne.data_gen import DataGen
from causalspyne.dag_interface import MatDAG
from causalspyne.utils_random import coerce_rng
def process_list2hide(list_ind_or_percentage, total_num):
    """
    Normalize a specification of variables to hide into sorted unique indices.

    Args:
        list_ind_or_percentage: a list of scalars, or a single scalar.
            Integers are used as indices as-is; floats are treated as a
            fraction of ``total_num`` and mapped to
            ``min(int(ele * total_num), total_num - 1)``.
        total_num: number of candidate variables the indices refer to.

    Returns:
        list: de-duplicated indices in ascending order.

    Raises:
        RuntimeError: if more entries are given than ``total_num``, or if a
            resulting index is not smaller than ``total_num``.
        ValueError: if no entry is given.
    """
    # a bare scalar (as the docstring allows) is promoted to a one-element
    # list; calling len() on it directly would raise TypeError
    if isinstance(list_ind_or_percentage,
                  (int, float, np.integer, np.floating)):
        list_ind_or_percentage = [list_ind_or_percentage]

    if len(list_ind_or_percentage) > total_num:
        raise RuntimeError(
            f"there are {total_num} variables to hide, less "
            f"than the length of {list_ind_or_percentage}"
        )
    if len(list_ind_or_percentage) == 0:
        # ValueError keeps the exception type max([]) used to raise here
        raise ValueError("no index or percentage of variables to hide given")

    list_ind = [
        min(int(ele * total_num), total_num - 1)
        if isinstance(ele, (float, np.floating)) else ele
        for ele in list_ind_or_percentage
    ]

    # sorted() makes the order deterministic; a bare set has no guaranteed
    # iteration order
    list_ind = sorted(set(list_ind))

    # valid indices are 0 .. total_num - 1, so equality is already out of
    # range (the previous '>' check let index == total_num through)
    if list_ind[-1] >= total_num:
        raise RuntimeError(
            f"max index {list_ind[-1]} from {list_ind_or_percentage} is out "
            f"of range for a total number of {total_num} variables to hide"
        )
    return list_ind
class DAGView:
    """
    View of a DAG with some variables hidden (marginalized).

    The ground-truth DAG stays intact and is used for data generation; the
    view holds the sub-DAG and the data columns of the observed variables.
    """

    def __init__(self, dag, dft_noise, rng=None):
        """
        Args:
            dag: ground-truth DAG, used for data generation.
            dft_noise: passed through to ``DataGen``.
            rng: passed through ``coerce_rng(rng, seed=0)`` before use.
        """
        rng = coerce_rng(rng, seed=0)
        self._dag = dag
        # there is no need to use a full DAG to represent the sub-DAG
        # since the sub-DAG is not responsible for data generation
        self._sub_dag = None
        self._data_arr = None
        self._subset_data_arr = None
        self._list_global_inds_unobserved = None
        # lazily computed cache for list_global_inds_observed
        self._list_global_inds_observed = None
        self.data_gen = DataGen(dag=self._dag, dft_noise=dft_noise, rng=rng)
        self._list_nodes2hide = None
        self._success = False

    @property
    def dag(self):
        return self._dag

    def run(self, num_samples, list_nodes2hide=None, confound=False):
        """
        Generate data from the full DAG, then build the sub-DAG adjacency
        matrix and the corresponding subset of the data.

        Args:
            num_samples: number of samples passed to ``DataGen.gen``.
            list_nodes2hide: indices (int) or fractions (float) selecting the
                variables to hide; defaults to ``[0]``.
            confound: if True, the selection refers to the topologically
                sorted confounders; otherwise to the topological order of
                all nodes.
        """
        self._data_arr = self.data_gen.gen(num_samples)
        if list_nodes2hide is None:
            list_nodes2hide = [0]
        if confound:
            self.hide_confounder(list_nodes2hide)
        else:
            self.hide_top_order(list_nodes2hide)

    def hide_confounder(self, list_toporder_confounder2hide_input):
        """
        Hide confounders selected by their position in
        ``dag.list_top_order_sorted_confounder``, then call
        ``hide_top_order``.

        Returns:
            True once the view has been built.

        Raises:
            RuntimeError: if the DAG has no confounders.
        """
        if not self._dag.list_confounder:
            raise RuntimeError(
                f"there are no confounders in the graph {self._dag}!"
            )

        list_toporder_confounder2hide = process_list2hide(
            list_toporder_confounder2hide_input,
            len(self._dag.list_confounder),
        )

        list_ind_confounder_sorted = self._dag.list_top_order_sorted_confounder
        # the selected entries are handed to hide_top_order, i.e. used as
        # topological-order indices of the full DAG
        list_toporder_confounder_sub = [
            list_ind_confounder_sorted[i] for i in list_toporder_confounder2hide
        ]

        self.hide_top_order(list_toporder_confounder_sub)
        return True

    def hide_top_order(self, list_toporder_unobserved):
        """
        Hide variables given by their index in the DAG's topological order.

        Builds the sub-DAG via ``dag.subgraph`` and removes the hidden
        columns from the generated data, so ``self._data_arr`` must already
        be set (``run`` does this).
        """
        list_toporder_unobserved = process_list2hide(
            list_toporder_unobserved, self._dag.num_nodes
        )

        self._list_nodes2hide = [
            self._dag.list_top_names[i] for i in list_toporder_unobserved
        ]

        # FIXME: change to logger
        print("nodes to hide " + str(self._list_nodes2hide))

        self._list_global_inds_unobserved = [
            self._dag.list_ind_nodes_sorted[ind_top_order]
            for ind_top_order in list_toporder_unobserved
        ]
        # the cached complement belongs to a previous selection; drop it so
        # list_global_inds_observed is recomputed for this one
        self._list_global_inds_observed = None

        self._sub_dag = self._dag.subgraph(self._list_global_inds_unobserved)
        self._subset_data_arr = np.delete(
            self._data_arr, self._list_global_inds_unobserved, axis=1
        )
        self._success = True

    @property
    def data(self):
        """
        Return the observed data (hidden columns removed) as a numpy array.
        """
        return self._subset_data_arr

    @property
    def mat_adj(self):
        """
        Return the adjacency matrix of the sub-DAG.
        """
        return self._sub_dag.mat_adjacency

    def check_if_subview_done(self):
        """
        Check whether the DAG marginalization succeeded; warn if not.

        Returns:
            bool: True if a sub-view has been built, False otherwise.
        """
        if not self._success:
            warnings.warn("no subview of DAG available, exit now!")
        return self._success

    @property
    def node_names(self):
        """
        Column names of the observed variables, in sub-DAG node order.

        Each name is ``"X"`` followed by the node's position in the sub-DAG's
        ``_parent_list_node_names``.
        """
        parent_names = self._sub_dag._parent_list_node_names
        return ["X" + str(parent_names.index(name))
                for name in self._sub_dag.list_node_names]

    def to_csv(self, title="data_subdag.csv"):
        """
        Write the observed data to CSV and the sub-DAG via
        ``MatDAG.to_binary_csv``.

        The data file name is ``title`` with ``"_" + self.str_node2hide``
        inserted before its extension. If no sub-view has been built, only
        a warning is emitted and nothing is written.
        """
        if not self.check_if_subview_done():
            return

        # NOTE(review): assumes the column order of self.data (global index
        # order kept by np.delete) matches self.node_names (sub-DAG order)
        df = pd.DataFrame(self.data, columns=self.node_names)
        # splitext handles extensions of any length (not only ".csv")
        stem, ext = os.path.splitext(title)
        df.to_csv(stem + "_" + self.str_node2hide + ext, index=False)

        subdag = MatDAG(self.mat_adj)
        subdag.to_binary_csv()

    @property
    def list_global_inds_nodes2hide(self):
        return self._list_global_inds_unobserved

    def _checked_nodes2hide(self):
        """Return the names of the hidden nodes; raise if none were set."""
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        return self._list_nodes2hide

    @property
    def col_inds(self):
        """
        Global indices (via ``dag._dict_node_names2ind``) of the observed
        nodes, in ``dag.list_node_names`` order.
        """
        # exact name membership; the former substring test against
        # str_node2hide also dropped e.g. "X1" when "X10" was hidden
        hidden = set(self._checked_nodes2hide())
        return [self.dag._dict_node_names2ind[name]
                for name in self.dag.list_node_names
                if name not in hidden]

    @property
    def list_global_inds_observed(self):
        """
        Global indices of the observed nodes, in
        ``dag.list_ind_nodes_sorted`` order (cached).

        Raises:
            RuntimeError: if no variables have been hidden yet.
        """
        if self._list_global_inds_observed is None:
            if self._list_global_inds_unobserved is None:
                raise RuntimeError(
                    "global inds for unobserved not initialized yet!")
            unobserved = set(self._list_global_inds_unobserved)
            self._list_global_inds_observed = [
                item for item in self._dag.list_ind_nodes_sorted
                if item not in unobserved]
        return self._list_global_inds_observed

    @property
    def str_node2hide(self):
        """
        String representation of the hidden (marginalized) nodes: their
        names joined by ``_``, then ``_ind_``, then their global indices
        joined by ``_``.
        """
        names = "_".join(map(str, self._checked_nodes2hide()))
        inds = "_".join(map(str, self._list_global_inds_unobserved))
        return names + "_ind_" + inds

    def visualize(self, **kwargs):
        """
        Plot the sub-DAG; only warns if no sub-view has been built.
        """
        if not self._success:
            warnings.warn("no subview of DAG available")
            return
        self._sub_dag.visualize(**kwargs)