Coverage for src/causalspyne/dag_interface.py: 97%

145 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2class method for DAG operations and check 

3""" 

4 

5import numpy as np 

6import pandas as pd 

7 

8from causalspyne.is_dag import is_dag 

9from causalspyne.utils_topological_sort import topological_sort 

10from causalspyne.draw_dags import draw_dags_nx 

11from causalspyne.utils_random import coerce_rng 

12 

13 

def add_prefix(string, prefix="", separator="u"):
    """
    Prepend a prefix to a string, glued together with a separator.

    Args:
        string (str): The text to be prefixed.
        prefix (str, optional): Text placed in front of ``string``.
            When it is falsy (e.g. empty), ``string`` is returned as is.
        separator (str, optional): Text inserted between ``prefix`` and
            ``string``. Defaults to ``"u"``.

    Returns:
        str: ``prefix + separator + string`` if a prefix is given,
        otherwise ``string`` unchanged.
    """
    if prefix:
        pieces = (prefix, string)
        return separator.join(pieces)
    return string

29 

30 

class MatDAG:
    """
    DAG represented by its adjacency matrix ``mat_adjacency``.

    Non-zero entries of the matrix are arcs; the stored value is the arc
    weight. Node names are either supplied by the caller or generated
    lazily as ``"v0", "v1", ...`` (optionally prefixed with
    ``name_prefix``).
    """

    def __init__(
        self,
        mat_adjacency,
        name_prefix="",
        separator="_",
        list_node_names=None,
        parent_list_node_names=None,
        rng=None,
    ):
        """
        Store the adjacency matrix together with optional naming info.

        Args:
            mat_adjacency: square adjacency matrix; it is indexed as
                ``mat[i, j]`` and queried via ``shape`` and ``nonzero()``.
            name_prefix (str): prefix used when node names are generated.
            separator (str): stored on the instance; not used by any
                method of this class.
            list_node_names (list, optional): explicit node names. When
                None, names are generated on first access.
            parent_list_node_names (list, optional): node names of the
                graph this DAG was derived from (set by ``subgraph``);
                used to label nodes with their index in that parent graph.
            rng: random generator specification, normalised through
                ``coerce_rng(rng, seed=0)``.
        """
        rng = coerce_rng(rng, seed=0)
        self._obj_gen_weight = None
        self.separator = separator
        self.name_prefix = name_prefix
        self.mat_adjacency = mat_adjacency
        self._list_node_names = list_node_names
        self._list_confounder = None
        self._dict_node_names2ind = {}
        self._parent_list_node_names = parent_list_node_names
        self._init_map()
        # cache for the topological order, filled by topological_sort()
        self._list_ind_nodes_sorted = None
        self.rng = rng

    def _init_map(self):
        """
        (re)build the node-name -> matrix-index dictionary, if names exist
        """
        if self._list_node_names is not None:
            self._dict_node_names2ind = {
                name: i for (i, name) in enumerate(self._list_node_names)
            }

    def _ensure_node_names(self):
        """
        generate default node names (and the name -> index map) if no
        names have been set yet, so that lookups by name do not hit an
        empty map
        """
        if self._list_node_names is None:
            self.gen_node_names()

    @property
    def list_confounder(self):
        """
        return list of confounders, i.e. the indices of the columns of the
        adjacency matrix that hold more than one non-zero entry
        """
        nonzero_counts = np.count_nonzero(self.mat_adjacency, axis=0)
        columns_with_more_than_one = np.where(nonzero_counts > 1)[0]
        return list(columns_with_more_than_one)

    def gen_dict_ind2node_na(self, hierarch_na=False):
        """
        utility function to have {1:"node_name"} dictionary for plotting
        hierarch_na: if True, use the hierarchical macro-node-micro-node
        name as label; otherwise label each node with (the string of) its
        index in the parent graph when one is known, else with its index
        in this graph
        """
        if hierarch_na:
            mdict = {i: name for (i, name) in enumerate(self.list_node_names)}
        elif self._parent_list_node_names is not None:
            mdict = {
                i: str(self._parent_list_node_names.index(name))
                for (i, name) in enumerate(self.list_node_names)
            }
        else:
            # FIXME: if self.list_node_names is list of string,
            # then it will be ignored, mdict will be {0:integer} style
            # always
            mdict = {
                i: str(self.list_node_names.index(name))
                for (i, name) in enumerate(self.list_node_names)
            }
        return mdict

    def check(self):
        """
        check if the matrix represent a DAG

        Raises:
            RuntimeError: if either the weighted matrix or its binarised
                (non-zero -> 1) version fails ``is_dag``.
        """
        if not is_dag(self.mat_adjacency):
            raise RuntimeError("not a DAG")
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if not is_dag(binary_adj_mat):
            raise RuntimeError("not a DAG")

    @property
    def num_nodes(self):
        """
        number of nodes in DAG (first dimension of the adjacency matrix)
        """
        return self.mat_adjacency.shape[0]

    def gen_node_names(self):
        """
        generate the default node names "v0", "v1", ... (with
        ``name_prefix`` prepended through ``add_prefix``) and refresh the
        name -> index map
        """
        self._list_node_names = [
            add_prefix(string="v" + str(i), prefix=self.name_prefix)
            for i in range(self.num_nodes)
        ]
        self._init_map()

    def gen_node_names_stacked(self, dict_macro_node2dag):
        """
        set the node names to the concatenation of the node names of all
        DAGs in ``dict_macro_node2dag`` (in the dictionary's iteration
        order) and refresh the name -> index map
        """
        self._list_node_names = []
        for _, dag in dict_macro_node2dag.items():
            self._list_node_names.extend(dag.list_node_names)
        self._init_map()

    def get_node_ind(self, node_name):
        """
        get the matrix index of the node name

        Raises:
            KeyError: if ``node_name`` is not a known node name.
        """
        # names are generated lazily; make sure the map is populated
        self._ensure_node_names()
        return self._dict_node_names2ind[node_name]

    @property
    def list_node_names(self):
        """
        get the node names in a linear list (generated on first access if
        none were supplied)
        """
        self._ensure_node_names()
        return self._list_node_names

    @property
    def list_arcs(self):
        """
        return the list of edges as (name of row node, name of column
        node) pairs, one per non-zero entry of the adjacency matrix
        """
        list_i_j = list(zip(*self.mat_adjacency.nonzero()))
        list_arcs = [
            (self.list_node_names[tuple(ij)[0]],
             self.list_node_names[tuple(ij)[1]])
            for ij in list_i_j
        ]
        return list_arcs

    def sample_node(self):
        """
        randomly chose a node

        Returns:
            tuple: (``name_prefix`` + node name, node index)
        """
        ind = self.rng.integers(0, self.mat_adjacency.shape[0])
        name = self.list_node_names[ind]
        # NOTE(review): auto-generated names already carry name_prefix
        # (see gen_node_names), so the prefix may appear twice here --
        # confirm this is intended before changing it.
        return self.name_prefix + name, ind

    def add_arc_ind(self, ind_tail, ind_head, weight=None):
        """
        add arc via index of tail and head
        """
        node_tail = self.list_node_names[ind_tail]
        node_head = self.list_node_names[ind_head]
        self.add_arc(node_tail, node_head, weight)

    def add_arc(self, node_tail, node_head, weight=None):
        """
        add edge to adjacency matrix

        Args:
            node_tail: name of the tail node (row index of the entry).
            node_head: name of the head node (column index of the entry).
            weight: value written into the matrix; 1 when None.

        Raises:
            KeyError: if either name is not a known node name.
        """
        # names are generated lazily; make sure the map is populated
        self._ensure_node_names()
        ind_tail = self._dict_node_names2ind[node_tail]
        ind_head = self._dict_node_names2ind[node_head]
        if weight is None:
            self.mat_adjacency[ind_tail, ind_head] = 1
        else:
            self.mat_adjacency[ind_tail, ind_head] = weight
        # the graph changed, so a previously cached topological order may
        # no longer be valid: drop it and let it be recomputed on demand
        self._list_ind_nodes_sorted = None

    def to_binary_csv(self, benchpress=True, name="adj.csv"):
        """
        adjacency matrix to csv format

        Args:
            benchpress (bool): if True, write the transposed matrix.
            name (str): path of the csv file to write.
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if benchpress:
            binary_adj_mat = np.transpose(binary_adj_mat)
        df = pd.DataFrame(binary_adj_mat, columns=self.list_node_names)
        df.to_csv(name, index=False)

    def topological_sort(self):
        """
        topological sort DAG into list of node index; the result is also
        cached on the instance
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        self._list_ind_nodes_sorted = topological_sort(binary_adj_mat)
        return self._list_ind_nodes_sorted

    @property
    def list_ind_nodes_sorted(self):
        """
        get global node index topologically sorted (computed on first
        access, then served from the cache)
        """
        if self._list_ind_nodes_sorted is None:
            self.topological_sort()
        return self._list_ind_nodes_sorted

    def get_list_parents_inds(self, ind_node):
        """
        get list of parents nodes: the column indices of the non-zero
        entries in row ``ind_node``
        """
        # np.nonzero(x)
        # returns (array([0, 1, 2, 2]), array([0, 1, 0, 1]))
        # assume lower triangular matrix as adjacency matrix
        # matrix[i, j]=1 indicate arrow j->i
        # NOTE(review): add_arc writes mat[tail, head], which reads as the
        # opposite orientation -- confirm which convention callers rely on.
        submatrix = self.mat_adjacency[ind_node, :]
        vector = submatrix.flatten()
        list_inds = np.nonzero(vector)[0].tolist()
        return list_inds

    def get_weights_from_list_parents(self, ind_node):
        """
        get incoming edge weights: the entries of row ``ind_node`` at the
        positions returned by ``get_list_parents_inds``
        """
        list_parents_inds = self.get_list_parents_inds(ind_node)
        sub_matrix = self.mat_adjacency[ind_node, list_parents_inds]
        return sub_matrix

    def __str__(self):
        return str(self.mat_adjacency)

    def __repr__(self):
        return str(self.mat_adjacency)

    def subgraph(self, list_ind_unobserved):
        """
        subset adjacency matrix by deleting unobserved variables

        Returns:
            MatDAG: new DAG over the remaining nodes; it keeps this DAG's
            node names as its parent names and shares this DAG's rng.
        """
        # delete first axis
        temp_mat_row = np.delete(
            self.mat_adjacency, list_ind_unobserved, axis=0)

        # delete second axis
        mat_adj_subgraph = np.delete(
            temp_mat_row, list_ind_unobserved, axis=1)

        # filter out subgraph node names
        list_node_names_subgraph = [
            x
            for i, x in enumerate(self.list_node_names)
            if i not in list_ind_unobserved
        ]

        subdag = MatDAG(
            mat_adj_subgraph,
            list_node_names=list_node_names_subgraph,
            parent_list_node_names=self.list_node_names,
            rng=self.rng,
        )
        return subdag

    def visualize(self, title="dag", hierarch_na=False, ax=None,
                  graphviz=False):
        """
        draw dag using networkx (delegates to ``draw_dags_nx``)
        """
        draw_dags_nx(
            self.mat_adjacency,
            dict_ind2name=self.gen_dict_ind2node_na(hierarch_na),
            title=title,
            ax=ax,
            graphviz=graphviz
        )

    @property
    def list_top_names(self):
        """
        return list of node names in topological order
        """
        list_top_names = [self.list_node_names[i] for i in
                          self.list_ind_nodes_sorted]
        return list_top_names

    def global_arbitrary_ind2topind(self, ind_global_arbitrary):
        """
        map a matrix (global) node index to its position in the
        topological order
        """
        ind_top = self.list_ind_nodes_sorted.index(ind_global_arbitrary)
        return ind_top

    def top_ind2global_arbitrary(self, ind_top):
        """
        map a position in the topological order to the matrix (global)
        node index
        """
        ind_global_arbitrary = self.list_ind_nodes_sorted[ind_top]
        return ind_global_arbitrary

    def get_top_last(self):
        """
        matrix (global) index of the last node in the topological order
        """
        ind_global_arbitrary = self.list_ind_nodes_sorted[-1]
        return ind_global_arbitrary

    def climb(self, ind_arbitrary):
        """
        return the matrix (global) index of the node that sits one
        position earlier in the topological order, or None if
        ``ind_arbitrary`` is already first
        """
        ind_top = self.global_arbitrary_ind2topind(ind_arbitrary)
        if ind_top - 1 >= 0:
            ind_arbi = self.top_ind2global_arbitrary(ind_top - 1)
            return ind_arbi
        return None

    @property
    def num_confounder(self):
        """
        number of confounders
        """
        return len(self.list_confounder)

    @property
    def list_top_order_sorted_confounder(self):
        """
        positions of the confounders in the topological order, sorted
        ascending
        """
        list_top_oder_confounder = [
            self.list_ind_nodes_sorted.index(confounder)
            for confounder in self.list_confounder
        ]
        list_top_oder_confounder_sorted = sorted(list_top_oder_confounder)
        return list_top_oder_confounder_sorted