Coverage for src/causalspyne/dag_interface.py: 97%
145 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
2class method for DAG operations and check
3"""
5import numpy as np
6import pandas as pd
8from causalspyne.is_dag import is_dag
9from causalspyne.utils_topological_sort import topological_sort
10from causalspyne.draw_dags import draw_dags_nx
11from causalspyne.utils_random import coerce_rng
def add_prefix(string, prefix="", separator="u"):
    """
    Prepend ``prefix`` to ``string``, joined by ``separator``.

    A falsy prefix (empty string or ``None``) leaves ``string`` untouched.

    Args:
        string (str): The original string.
        prefix (str, optional): The prefix to add. Defaults to an empty
            string.
        separator (str, optional): Text placed between the prefix and the
            string. Defaults to ``"u"``.

    Returns:
        str: ``prefix + separator + string`` when a prefix is given,
        otherwise ``string`` itself.
    """
    if prefix:
        return separator.join((prefix, string))
    return string
class MatDAG:
    """
    DAG represented by its adjacency matrix.

    Nodes can be addressed by their integer index into the matrix or by
    name. When no explicit list of names is supplied, names are generated
    lazily on first access (see ``gen_node_names``).
    """

    def __init__(
        self,
        mat_adjacency,
        name_prefix="",
        separator="_",
        list_node_names=None,
        parent_list_node_names=None,
        rng=None,
    ):
        """
        Args:
            mat_adjacency: square adjacency matrix; non-zero entries are
                treated as arcs (optionally weighted).
            name_prefix (str): prefix used when node names are generated.
            separator (str): stored on the instance as ``self.separator``.
            list_node_names (list, optional): explicit node names, one per
                matrix index. Generated on demand when ``None``.
            parent_list_node_names (list, optional): node names of the DAG
                this one was cut out of (set by ``subgraph``); used by
                ``gen_dict_ind2node_na`` to label nodes with parent
                indices.
            rng: random source, normalised through
                ``coerce_rng(rng, seed=0)``; ``sample_node`` calls
                ``.integers`` on the result.
        """
        rng = coerce_rng(rng, seed=0)
        self._obj_gen_weight = None
        self.separator = separator
        self.name_prefix = name_prefix
        self.mat_adjacency = mat_adjacency
        self._list_node_names = list_node_names
        self._list_confounder = None
        self._dict_node_names2ind = {}
        self._parent_list_node_names = parent_list_node_names
        self._init_map()
        # cache for the topological order, filled lazily
        self._list_ind_nodes_sorted = None
        self.rng = rng

    def _init_map(self):
        """
        (Re)build the name -> matrix index lookup from the current names.
        Does nothing while no names are available.
        """
        if self._list_node_names is not None:
            self._dict_node_names2ind = {
                name: i for (i, name) in enumerate(self._list_node_names)
            }

    def _binary_adj_mat(self):
        """
        return the 0/1 integer version of the adjacency matrix
        (1 wherever the weighted matrix is non-zero)
        """
        return (self.mat_adjacency != 0).astype(int)

    @property
    def list_confounder(self):
        """
        return list of confounders: indices of the columns of the
        adjacency matrix holding more than one non-zero entry
        """
        nonzero_counts = np.count_nonzero(self.mat_adjacency, axis=0)
        columns_with_more_than_one = np.where(nonzero_counts > 1)[0]
        return list(columns_with_more_than_one)

    def gen_dict_ind2node_na(self, hierarch_na=False):
        """
        utility function to have {1:"node_name"} dictionary for plotting

        hierarch_na: if True, map each index to the node name itself
        (hierarchical macro-node-micro-node format). Otherwise each index
        maps to a stringified position: the position of the node name in
        the parent DAG's name list when this DAG has one, else the
        position in this DAG's own name list.
        """
        names = self.list_node_names
        if hierarch_na:
            return dict(enumerate(names))
        if self._parent_list_node_names is not None:
            parent_names = self._parent_list_node_names
            return {
                i: str(parent_names.index(name))
                for (i, name) in enumerate(names)
            }
        # FIXME: if self.list_node_names is list of string,
        # then it will be ignored, mdict will be {0:integer} style
        # always
        # First-occurrence positions, equivalent to list.index(name) for
        # every name but computed in a single pass.
        first_position = {}
        for i, name in enumerate(names):
            first_position.setdefault(name, i)
        return {i: str(first_position[name]) for (i, name) in enumerate(names)}

    def check(self):
        """
        check if the matrix represent a DAG

        Both the weighted matrix and its binary version are tested.

        Raises:
            RuntimeError: if either test fails.
        """
        if not is_dag(self.mat_adjacency):
            raise RuntimeError("not a DAG")
        if not is_dag(self._binary_adj_mat()):
            raise RuntimeError("not a DAG")

    @property
    def num_nodes(self):
        """
        number of nodes in DAG (first dimension of the adjacency matrix)
        """
        return self.mat_adjacency.shape[0]

    def gen_node_names(self):
        """
        generate the default node names "v0", "v1", ... (each passed
        through ``add_prefix`` with ``self.name_prefix``) and refresh the
        name -> index lookup
        """
        self._list_node_names = [
            add_prefix(string="v" + str(i), prefix=self.name_prefix)
            for i in range(self.num_nodes)
        ]
        self._init_map()

    def gen_node_names_stacked(self, dict_macro_node2dag):
        """
        build the node names by concatenating, in dictionary order, the
        ``list_node_names`` of every DAG in ``dict_macro_node2dag``, then
        refresh the name -> index lookup
        """
        self._list_node_names = []
        for dag in dict_macro_node2dag.values():
            self._list_node_names.extend(dag.list_node_names)
        self._init_map()

    def get_node_ind(self, node_name):
        """
        get the matrix index of the node name

        Raises:
            KeyError: if ``node_name`` is not a known node.
        """
        if self._list_node_names is None:
            # Names are generated lazily; without this the lookup table
            # would still be empty and every name would raise KeyError.
            self.gen_node_names()
        return self._dict_node_names2ind[node_name]

    @property
    def list_node_names(self):
        """
        get the node names in a linear list, generating default names on
        first access when none were supplied
        """
        if self._list_node_names is None:
            self.gen_node_names()
        return self._list_node_names

    @property
    def list_arcs(self):
        """
        return the list of edges as (name_row, name_col) tuples, one per
        non-zero entry [row, col] of the adjacency matrix
        """
        names = self.list_node_names
        return [
            (names[i], names[j])
            for (i, j) in zip(*self.mat_adjacency.nonzero())
        ]

    def sample_node(self):
        """
        randomly chose a node

        Returns:
            tuple: (``self.name_prefix`` + node name, node index)
        """
        ind = self.rng.integers(0, self.mat_adjacency.shape[0])
        name = self.list_node_names[ind]
        # NOTE(review): names produced by gen_node_names already carry
        # name_prefix, so the prefix may appear twice here -- confirm
        # whether callers rely on this.
        return self.name_prefix + name, ind

    def add_arc_ind(self, ind_tail, ind_head, weight=None):
        """
        add arc via index of tail and head
        """
        node_tail = self.list_node_names[ind_tail]
        node_head = self.list_node_names[ind_head]
        self.add_arc(node_tail, node_head, weight)

    def add_arc(self, node_tail, node_head, weight=None):
        """
        add edge to adjacency matrix

        Sets entry [tail, head] to ``weight`` (1 when ``weight`` is None)
        and drops the cached topological order, which may no longer match
        the modified matrix.

        Raises:
            KeyError: if either node name is unknown.
        """
        # NOTE(review): this writes [tail, head] while
        # get_list_parents_inds documents matrix[i, j] as arrow j->i;
        # confirm the intended orientation with callers.
        ind_tail = self.get_node_ind(node_tail)
        ind_head = self.get_node_ind(node_head)
        self.mat_adjacency[ind_tail, ind_head] = 1 if weight is None else weight
        # force list_ind_nodes_sorted to be recomputed on next access
        self._list_ind_nodes_sorted = None

    def to_binary_csv(self, benchpress=True, name="adj.csv"):
        """
        adjacency matrix to csv format

        benchpress: if True, the binary matrix is transposed before it is
        written.
        name: path of the csv file; node names are used as column headers
        and no index column is written.
        """
        binary_adj_mat = self._binary_adj_mat()
        if benchpress:
            binary_adj_mat = np.transpose(binary_adj_mat)
        df = pd.DataFrame(binary_adj_mat, columns=self.list_node_names)
        df.to_csv(name, index=False)

    def topological_sort(self):
        """
        topological sort DAG into list of node index; the result is cached
        on the instance and returned
        """
        self._list_ind_nodes_sorted = topological_sort(self._binary_adj_mat())
        return self._list_ind_nodes_sorted

    @property
    def list_ind_nodes_sorted(self):
        """
        get global node index topologically sorted (computed on first
        access, then served from the cache)
        """
        if self._list_ind_nodes_sorted is None:
            self.topological_sort()
        return self._list_ind_nodes_sorted

    def get_list_parents_inds(self, ind_node):
        """
        get list of parents nodes: column indices of the non-zero entries
        in row ``ind_node``
        """
        # np.nonzero(x)
        # returns (array([0, 1, 2, 2]), array([0, 1, 0, 1]))
        # assume lower triangular matrix as adjacency matrix
        # matrix[i, j]=1 indicate arrow j->i
        row = self.mat_adjacency[ind_node, :].flatten()
        return np.nonzero(row)[0].tolist()

    def get_weights_from_list_parents(self, ind_node):
        """
        get incoming edge weights: the entries of row ``ind_node`` at the
        indices returned by ``get_list_parents_inds``
        """
        list_parents_inds = self.get_list_parents_inds(ind_node)
        return self.mat_adjacency[ind_node, list_parents_inds]

    def __str__(self):
        return str(self.mat_adjacency)

    def __repr__(self):
        return str(self.mat_adjacency)

    def subgraph(self, list_ind_unobserved):
        """
        subset adjacency matrix by deleting unobserved variables

        Returns a new MatDAG without the rows and columns listed in
        ``list_ind_unobserved``. The new DAG keeps the remaining node
        names, records this DAG's names as its parent names and shares
        this DAG's rng.
        """
        # delete first axis, then second axis
        temp_mat_row = np.delete(
            self.mat_adjacency, list_ind_unobserved, axis=0)
        mat_adj_subgraph = np.delete(
            temp_mat_row, list_ind_unobserved, axis=1)

        # filter out subgraph node names (set gives O(1) membership tests)
        unobserved = set(list_ind_unobserved)
        list_node_names_subgraph = [
            name
            for (i, name) in enumerate(self.list_node_names)
            if i not in unobserved
        ]

        return MatDAG(
            mat_adj_subgraph,
            list_node_names=list_node_names_subgraph,
            parent_list_node_names=self.list_node_names,
            rng=self.rng,
        )

    def visualize(self, title="dag", hierarch_na=False, ax=None,
                  graphviz=False):
        """
        draw dag using networkx (delegates to ``draw_dags_nx``), labelling
        nodes with ``gen_dict_ind2node_na(hierarch_na)``
        """
        draw_dags_nx(
            self.mat_adjacency,
            dict_ind2name=self.gen_dict_ind2node_na(hierarch_na),
            title=title,
            ax=ax,
            graphviz=graphviz
        )

    @property
    def list_top_names(self):
        """
        return list of node names in topological order
        """
        names = self.list_node_names
        return [names[i] for i in self.list_ind_nodes_sorted]

    def global_arbitrary_ind2topind(self, ind_global_arbitrary):
        """
        convert a matrix (global) node index into its position in the
        topological order
        """
        return self.list_ind_nodes_sorted.index(ind_global_arbitrary)

    def top_ind2global_arbitrary(self, ind_top):
        """
        convert a position in the topological order into the matrix
        (global) node index
        """
        return self.list_ind_nodes_sorted[ind_top]

    def get_top_last(self):
        """
        return the matrix index of the last node in topological order
        """
        return self.list_ind_nodes_sorted[-1]

    def climb(self, ind_arbitrary):
        """
        return the matrix index of the node placed immediately before
        ``ind_arbitrary`` in the topological order, or None when
        ``ind_arbitrary`` is already first
        """
        ind_top = self.global_arbitrary_ind2topind(ind_arbitrary)
        if ind_top >= 1:
            return self.top_ind2global_arbitrary(ind_top - 1)
        return None

    @property
    def num_confounder(self):
        """
        number of confounders (length of ``list_confounder``)
        """
        return len(self.list_confounder)

    @property
    def list_top_order_sorted_confounder(self):
        """
        return the topological-order positions of all confounders, sorted
        in ascending order
        """
        order = self.list_ind_nodes_sorted
        return sorted(order.index(confounder)
                      for confounder in self.list_confounder)