Coverage for causalspyne/dag_viewer.py: 84%
99 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
2create different views for the same DAG by hiding some variables
3"""
5import warnings
7import numpy as np
8from numpy.random import default_rng
9import pandas as pd
11from causalspyne.data_gen import DataGen
12from causalspyne.dag_interface import MatDAG
def process_list2hide(list_ind_or_percentage, total_num):
    """
    Normalize a specification of variables to hide into a list of indices.

    list_ind_or_percentage can either be a list or a scalar. Each entry is
    either an index, which is used as is, or a float, which is read as a
    fraction of total_num and converted to int(fraction * total_num), capped
    at total_num - 1.

    Returns the sorted list of unique indices (empty if nothing was asked
    to be hidden).

    Raises RuntimeError if more entries than total_num are given, or if any
    resulting index is not smaller than total_num.
    """
    # Scalar input is promised by the contract above, but len() would fail
    # on a bare number, so wrap it into a one-element list first.
    if isinstance(list_ind_or_percentage, (int, float)):
        list_ind_or_percentage = [list_ind_or_percentage]

    if len(list_ind_or_percentage) > total_num:
        # Adjacent f-strings instead of a backslash continuation, which
        # would embed the source indentation into the message.
        raise RuntimeError(
            f"there are {total_num} variables to hide, less "
            f"than the length of {list_ind_or_percentage}"
        )

    list_ind = [
        min(int(ele * total_num), total_num - 1)
        if isinstance(ele, float) else ele
        for ele in list_ind_or_percentage
    ]

    # Drop duplicates; sort so the result does not depend on set ordering.
    list_ind = sorted(set(list_ind))

    # Valid indices are 0 .. total_num - 1, hence ">=" (an index equal to
    # total_num is already out of range). The emptiness guard keeps max()
    # from failing when there is nothing to hide.
    if list_ind and max(list_ind) >= total_num:
        raise RuntimeError(
            f"max value in {list_ind_or_percentage} is out of range "
            f"for the total number of variables {total_num} to hide"
        )

    return list_ind
class DAGView:
    """
    with ground truth DAG intact, only show subgraph

    The view generates data from the full DAG through DataGen, then hides
    a chosen set of variables: the corresponding columns are removed from
    the data and a sub-DAG is requested from the ground truth DAG.
    """

    def __init__(self, dag, rng=None):
        """
        dag: ground truth DAG, kept untouched by the view
        rng: random generator forwarded to DataGen; when None a new
             default_rng(0) is created for this instance
        """
        # A default of default_rng(0) in the signature would be evaluated
        # once and shared (with its advancing state) by every instance.
        if rng is None:
            rng = default_rng(0)
        self._dag = dag
        # there is no need to use a full DAG to represent subdag
        # since sub-dag is not responsible for data generation
        self._sub_dag = None
        self._data_arr = None
        self._subset_data_arr = None
        self._list_global_inds_unobserved = None
        self._list_global_inds_observed = None
        self.data_gen = DataGen(self._dag, rng=rng)
        self._list_nodes2hide = None
        self._success = False

    @property
    def dag(self):
        """
        return the ground truth DAG
        """
        return self._dag

    def run(self, num_samples, list_nodes2hide=None, confound=False):
        """
        generate subgraph adjacency matrix and corresponding data

        num_samples: forwarded to the data generator
        list_nodes2hide: indices or fractions, see process_list2hide;
            defaults to [0]
        confound: if True the entries index into the sorted confounders,
            otherwise into the topological order of all nodes
        """
        self._data_arr = self.data_gen.gen(num_samples)
        if list_nodes2hide is None:
            list_nodes2hide = [0]
        if confound:
            self.hide_confounder(list_nodes2hide)
        else:
            self.hide_top_order(list_nodes2hide)

    def hide_confounder(self, list_toporder_confounder2hide_input):
        """
        given a list of index, hide the confounder according to the
        topological order provided by the input index
        list_toporder_confounder2hide then call self.hide_top_order

        Raises RuntimeError if the DAG reports no confounders.
        Returns True.
        """
        if not self._dag.list_confounder:
            # Adjacent f-string pieces avoid embedding source indentation
            # in the message, as a backslash continuation would.
            raise RuntimeError(
                f"there are no confounders in the graph {self._dag}!"
            )

        list_toporder_confounder2hide = process_list2hide(
            list_toporder_confounder2hide_input,
            len(self._dag.list_confounder)
        )

        list_ind_confounder_sorted = \
            self._dag.list_top_order_sorted_confounder

        list_toporder_confounder_sub = [
            list_ind_confounder_sorted[i]
            for i in list_toporder_confounder2hide
        ]

        self.hide_top_order(list_toporder_confounder_sub)
        return True

    def hide_top_order(self, list_toporder_unobserved):
        """
        hide variables according to a list of global index of topological
        sort
        """
        list_toporder_unobserved = process_list2hide(
            list_toporder_unobserved, self._dag.num_nodes
        )

        # subset list
        self._list_nodes2hide = [
            self._dag.list_top_names[i] for i in list_toporder_unobserved
        ]

        # FIXME: change to logger
        print("nodes to hide " + str(self._list_nodes2hide))

        self._list_global_inds_unobserved = [
            self._dag.list_ind_nodes_sorted[ind_top_order]
            for ind_top_order in list_toporder_unobserved
        ]
        # The observed indices are cached lazily; drop the cache so a
        # repeated call with other nodes does not serve stale indices.
        self._list_global_inds_observed = None

        self._sub_dag = self._dag.subgraph(self._list_global_inds_unobserved)
        self._subset_data_arr = np.delete(
            self._data_arr, self._list_global_inds_unobserved, axis=1
        )
        self._success = True

    @property
    def data(self):
        """
        return data in numpy array format
        """
        return self._subset_data_arr

    @property
    def mat_adj(self):
        """
        return adj matrix
        """
        return self._sub_dag.mat_adjacency

    def check_if_subview_done(self):
        """
        check if DAG marginalization was successful or not

        Emits a warning when no subview is available. Returns True when a
        subview exists and False otherwise, so callers can stop early.
        """
        if not self._success:
            warnings.warn("no subview of DAG available, exit now!")
            return False
        return True

    @property
    def node_names(self):
        """
        names of the sub-DAG nodes in the form "X<i>", where i is the
        position of the node name in the sub-DAG's parent name list
        """
        parent_names = self._sub_dag._parent_list_node_names
        return [
            "X" + str(parent_names.index(name))
            for name in self._sub_dag.list_node_names
        ]

    def to_csv(self, title="data_subdag.csv"):
        """
        sub dataframe to csv

        The string representation of the hidden nodes is inserted between
        the stem and the extension of title. Does nothing beyond the
        warning when no subview is available.
        """
        from os.path import splitext

        if not self.check_if_subview_done():
            return

        # splitext instead of fixed [-4:] slicing so that extensions of
        # any length are handled
        stem, ext = splitext(title)

        # FIXME: ensure self.node_names are consistent with self.data
        df = pd.DataFrame(self.data, columns=self.node_names)
        df.to_csv(stem + "_" + self.str_node2hide + ext, index=False)

        subdag = MatDAG(self.mat_adj)
        subdag.to_binary_csv()

    @property
    def list_global_inds_nodes2hide(self):
        """
        global indices of the hidden nodes (None before any hiding)
        """
        return self._list_global_inds_unobserved

    @property
    def col_inds(self):
        """
        global indices, looked up by node name, of the nodes that are not
        hidden

        Raises RuntimeError if no nodes have been hidden yet.
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        # Compare whole names. A substring test against the joined string
        # of hidden names would also drop e.g. "X1" when "X10" is hidden.
        hidden_names = {str(name) for name in self._list_nodes2hide}
        return [
            self.dag._dict_node_names2ind[name]
            for name in self.dag.list_node_names
            if str(name) not in hidden_names
        ]

    @property
    def list_global_inds_observed(self):
        """
        global indices of the observed nodes, computed on first access

        Raises RuntimeError if no nodes have been hidden yet.
        """
        if self._list_global_inds_observed is None:
            if self._list_global_inds_unobserved is None:
                raise RuntimeError(
                    "global inds for unobserved not initialized yet!")
            unobserved = set(self._list_global_inds_unobserved)
            self._list_global_inds_observed = [
                item for item in self._dag.list_ind_nodes_sorted
                if item not in unobserved
            ]
        return self._list_global_inds_observed

    @property
    def str_node2hide(self):
        """
        string representation of nodes to hide(marginalize):
        names joined by "_", then "_ind_", then global indices joined
        by "_"
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        _str_node2hide = "_".join(map(str, self._list_nodes2hide))

        _str_node2hide_ind = "_".join(
            map(str, self._list_global_inds_unobserved))

        return "_ind_".join([_str_node2hide, _str_node2hide_ind])

    def visualize(self, **kwargs):
        """
        plot DAG; keyword arguments are forwarded to the sub-DAG
        """
        if not self._success:
            warnings.warn("no subview of DAG available")
            return
        self._sub_dag.visualize(**kwargs)