Coverage for causalspyne/dag_interface.py: 97%
144 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
"""
Adjacency-matrix DAG class (MatDAG) with graph operations and checks.
"""
import numpy as np
from numpy.random import default_rng
import pandas as pd

from causalspyne.is_dag import is_dag
from causalspyne.utils_topological_sort import topological_sort
from causalspyne.draw_dags import draw_dags_nx
def add_prefix(string, prefix="", separator="u"):
    """
    Prepend ``prefix`` to ``string``, joined by ``separator``.

    Args:
        string (str): text to be prefixed.
        prefix (str, optional): text placed in front of ``string``. When it
            is falsy (e.g. the default empty string) ``string`` is returned
            unchanged.
        separator (str, optional): text inserted between ``prefix`` and
            ``string``. Defaults to ``"u"``.

    Returns:
        str: ``prefix + separator + string``, or ``string`` when no prefix
        is given.
    """
    if prefix:
        return separator.join((prefix, string))
    return string
class MatDAG:
    """
    DAG represented by a (possibly weighted) adjacency matrix.

    Nonzero entries of ``mat_adjacency`` are arcs and the entry value is the
    arc weight. Node names are kept in a list aligned with the matrix
    indices, together with a name -> index dictionary for reverse lookup.
    """

    def __init__(
        self,
        mat_adjacency,
        name_prefix="",
        separator="_",
        list_node_names=None,
        parent_list_node_names=None,
        rng=None,
    ):
        """
        Args:
            mat_adjacency: 2-D numpy array (assumed square). It is stored by
                reference, so ``add_arc`` mutates the caller's array.
            name_prefix (str): prefix used when node names are generated.
            separator (str): stored on the instance; not used by the
                methods of this class.
            list_node_names (list, optional): node names aligned with the
                matrix indices; generated lazily ("v0", "v1", ...) when
                omitted.
            parent_list_node_names (list, optional): node names of the DAG
                this one was derived from (set by ``subgraph``).
            rng (numpy.random.Generator, optional): generator used by
                ``sample_node``. Defaults to a fresh ``default_rng(0)``
                created for this instance.
        """
        self._obj_gen_weight = None
        self.separator = separator
        self.name_prefix = name_prefix
        self.mat_adjacency = mat_adjacency
        self._list_node_names = list_node_names
        self._list_confounder = None
        self._dict_node_names2ind = {}
        self._parent_list_node_names = parent_list_node_names
        self._init_map()
        self._list_ind_nodes_sorted = None
        # A Generator used as a default argument value is built once at
        # import time and shared by every instance; build one per instance.
        self.rng = default_rng(0) if rng is None else rng

    def _init_map(self):
        """(Re)build the name -> matrix index dictionary from node names."""
        if self._list_node_names is not None:
            self._dict_node_names2ind = {
                name: i for (i, name) in enumerate(self._list_node_names)
            }

    @staticmethod
    def _first_index_map(names):
        """Map each name to the position of its first occurrence."""
        positions = {}
        for pos, name in enumerate(names):
            positions.setdefault(name, pos)
        return positions

    @property
    def list_confounder(self):
        """
        Return indices of the columns of the adjacency matrix that contain
        more than one nonzero entry (treated as confounders).
        """
        nonzero_counts = np.count_nonzero(self.mat_adjacency, axis=0)
        columns_with_more_than_one = np.where(nonzero_counts > 1)[0]
        return list(columns_with_more_than_one)

    def gen_dict_ind2node_na(self, hierarch_na=False):
        """
        Build an {index: label} dictionary for plotting.

        Args:
            hierarch_na (bool): if True, label every node with its full name
                (e.g. hierarchical macro-node/micro-node names).

        Returns:
            dict: when ``hierarch_na`` is False, each label is the
            stringified position of the node's name in the parent DAG's name
            list if one is known, otherwise in this DAG's own name list
            (so "0", "1", ... for unique names).
        """
        if hierarch_na:
            return dict(enumerate(self.list_node_names))
        if self._parent_list_node_names is not None:
            reference = self._parent_list_node_names
        else:
            reference = self.list_node_names
        # One dictionary lookup per node instead of list.index(): O(n), not
        # O(n^2); first occurrence wins, matching list.index semantics.
        positions = self._first_index_map(reference)
        return {
            i: str(positions[name])
            for (i, name) in enumerate(self.list_node_names)
        }

    def check(self):
        """
        Check that the matrix represents a DAG, both as given (weighted) and
        after binarization.

        Raises:
            RuntimeError: if either check fails.
        """
        if not is_dag(self.mat_adjacency):
            raise RuntimeError("not a DAG")
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if not is_dag(binary_adj_mat):
            raise RuntimeError("not a DAG")

    @property
    def num_nodes(self):
        """Number of nodes, i.e. the number of rows of the matrix."""
        return self.mat_adjacency.shape[0]

    def gen_node_names(self):
        """
        Generate default node names "v0", "v1", ... (with ``name_prefix``
        prepended via ``add_prefix``) and rebuild the name -> index map.
        """
        # NOTE(review): add_prefix is called with its default separator, not
        # self.separator, so a prefix "A" gives "Auv0"; confirm whether
        # self.separator was meant to be passed here.
        self._list_node_names = [
            add_prefix(string="v" + str(i), prefix=self.name_prefix)
            for i in range(self.num_nodes)
        ]
        self._init_map()

    def gen_node_names_stacked(self, dict_macro_node2dag):
        """
        Use the concatenated node names of several DAGs as this DAG's names.

        Args:
            dict_macro_node2dag (dict): mapping whose values expose
                ``list_node_names``; names are appended in the dictionary's
                iteration order.
        """
        self._list_node_names = []
        for dag in dict_macro_node2dag.values():
            self._list_node_names.extend(dag.list_node_names)
        self._init_map()

    def get_node_ind(self, node_name):
        """
        Return the matrix index of ``node_name``.

        Default names are generated first if none exist yet.

        Raises:
            KeyError: if the name is unknown.
        """
        _ = self.list_node_names  # ensures the name -> index map exists
        return self._dict_node_names2ind[node_name]

    @property
    def list_node_names(self):
        """Node names as a list aligned with matrix indices (lazy)."""
        if self._list_node_names is None:
            self.gen_node_names()
        return self._list_node_names

    @property
    def list_arcs(self):
        """
        Return one (row node name, column node name) tuple per nonzero
        entry of the adjacency matrix, in row-major order.
        """
        rows, cols = self.mat_adjacency.nonzero()
        names = self.list_node_names
        return [(names[i], names[j]) for i, j in zip(rows, cols)]

    def sample_node(self):
        """
        Randomly choose a node using ``self.rng``.

        Returns:
            tuple: (``name_prefix`` + node name, node index).
        """
        # NOTE(review): generated names already include name_prefix (see
        # gen_node_names), so the prefix may appear twice here; confirm.
        ind = self.rng.integers(0, self.mat_adjacency.shape[0])
        name = self.list_node_names[ind]
        return self.name_prefix + name, ind

    def add_arc_ind(self, ind_tail, ind_head, weight=None):
        """Add an arc given the matrix indices of its tail and head."""
        node_tail = self.list_node_names[ind_tail]
        node_head = self.list_node_names[ind_head]
        self.add_arc(node_tail, node_head, weight)

    def add_arc(self, node_tail, node_head, weight=None):
        """
        Write an arc into the adjacency matrix at ``[tail, head]``.

        Args:
            node_tail: name of the tail node (row index).
            node_head: name of the head node (column index).
            weight (optional): value to store; 1 when omitted.
        """
        # NOTE(review): this writes mat[tail, head], whereas
        # get_list_parents_inds reads parents from row ``ind_node`` (i.e.
        # treats mat[i, j] as j -> i); confirm which orientation callers use.
        ind_tail = self.get_node_ind(node_tail)
        ind_head = self.get_node_ind(node_head)
        self.mat_adjacency[ind_tail, ind_head] = 1 if weight is None else weight
        # The cached topological order is stale once the graph changes.
        self._list_ind_nodes_sorted = None

    def to_binary_csv(self, benchpress=True, name="adj.csv"):
        """
        Write the binarized adjacency matrix to a CSV file, using node names
        as the header row.

        Args:
            benchpress (bool): if True, transpose the matrix before writing.
            name (str): output file path.
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if benchpress:
            binary_adj_mat = np.transpose(binary_adj_mat)
        df = pd.DataFrame(binary_adj_mat, columns=self.list_node_names)
        df.to_csv(name, index=False)

    def topological_sort(self):
        """
        Topologically sort the DAG into a list of node indices, cache the
        result and return it.
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        self._list_ind_nodes_sorted = topological_sort(binary_adj_mat)
        return self._list_ind_nodes_sorted

    @property
    def list_ind_nodes_sorted(self):
        """Node indices in topological order (computed lazily, cached)."""
        if self._list_ind_nodes_sorted is None:
            self.topological_sort()
        return self._list_ind_nodes_sorted

    def get_list_parents_inds(self, ind_node):
        """
        Return the column indices of the nonzero entries in row
        ``ind_node``, i.e. the parents under the mat[i, j] = j -> i
        convention.
        """
        # np.nonzero(x)
        # returns (array([0, 1, 2, 2]), array([0, 1, 0, 1]))
        # assume lower triangular matrix as adjacency matrix
        # matrix[i, j]=1 indicate arrow j->i
        submatrix = self.mat_adjacency[ind_node, :]
        vector = submatrix.flatten()
        list_inds = np.nonzero(vector)[0].tolist()
        return list_inds

    def get_weights_from_list_parents(self, ind_node):
        """Return the weights in row ``ind_node`` at the parent columns."""
        list_parents_inds = self.get_list_parents_inds(ind_node)
        sub_matrix = self.mat_adjacency[ind_node, list_parents_inds]
        return sub_matrix

    def __str__(self):
        return str(self.mat_adjacency)

    def __repr__(self):
        return str(self.mat_adjacency)

    def subgraph(self, list_ind_unobserved):
        """
        Return a new MatDAG with the given node indices removed (both their
        rows and columns). The new DAG shares this DAG's ``rng`` and records
        this DAG's node names as its parent names.
        """
        # delete first axis
        temp_mat_row = np.delete(
            self.mat_adjacency, list_ind_unobserved, axis=0)

        # delete second axis
        mat_adj_subgraph = np.delete(
            temp_mat_row, list_ind_unobserved, axis=1)

        # filter out subgraph node names (set: O(1) membership per node)
        unobserved = set(list_ind_unobserved)
        list_node_names_subgraph = [
            x
            for i, x in enumerate(self.list_node_names)
            if i not in unobserved
        ]

        subdag = MatDAG(
            mat_adj_subgraph,
            list_node_names=list_node_names_subgraph,
            parent_list_node_names=self.list_node_names,
            rng=self.rng,
        )
        return subdag

    def visualize(self, title="dag", hierarch_na=False, ax=None,
                  graphviz=False):
        """Draw the DAG with ``draw_dags_nx``."""
        draw_dags_nx(
            self.mat_adjacency,
            dict_ind2name=self.gen_dict_ind2node_na(hierarch_na),
            title=title,
            ax=ax,
            graphviz=graphviz
        )

    @property
    def list_top_names(self):
        """Node names in topological order."""
        list_top_names = [self.list_node_names[i] for i in
                          self.list_ind_nodes_sorted]
        return list_top_names

    def global_arbitrary_ind2topind(self, ind_global_arbitrary):
        """Map a matrix index to its position in the topological order."""
        ind_top = self.list_ind_nodes_sorted.index(ind_global_arbitrary)
        return ind_top

    def top_ind2global_arbitrary(self, ind_top):
        """Map a position in the topological order to its matrix index."""
        ind_global_arbitrary = self.list_ind_nodes_sorted[ind_top]
        return ind_global_arbitrary

    def get_top_last(self):
        """Matrix index of the last node in topological order."""
        ind_global_arbitrary = self.list_ind_nodes_sorted[-1]
        return ind_global_arbitrary

    def climb(self, ind_arbitrary):
        """
        Return the matrix index of the node immediately before
        ``ind_arbitrary`` in topological order, or None if it is first.
        """
        ind_top = self.global_arbitrary_ind2topind(ind_arbitrary)
        if ind_top - 1 >= 0:
            ind_arbi = self.top_ind2global_arbitrary(ind_top - 1)
            return ind_arbi
        return None

    @property
    def num_confounder(self):
        """Number of entries in ``list_confounder``."""
        return len(self.list_confounder)

    @property
    def list_top_order_sorted_confounder(self):
        """Topological positions of the confounders, sorted ascending."""
        return sorted(
            self.global_arbitrary_ind2topind(confounder)
            for confounder in self.list_confounder
        )