Coverage for causalspyne/dag_interface.py: 97%

144 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2Class and helper function for DAG operations and validity checks 

3""" 

4 

5import numpy as np 

6from numpy.random import default_rng 

7import pandas as pd 

8 

9from causalspyne.is_dag import is_dag 

10from causalspyne.utils_topological_sort import topological_sort 

11from causalspyne.draw_dags import draw_dags_nx 

12 

13 

def add_prefix(string, prefix="", separator="u"):
    """
    Prepend a prefix to a string, joined by a separator.

    Args:
        string (str): The text to be prefixed.
        prefix (str, optional): The prefix to put in front. When it is
            empty (or otherwise falsy) ``string`` is returned unchanged.
            Defaults to an empty string.
        separator (str, optional): Text placed between ``prefix`` and
            ``string``. Defaults to ``"u"``.

    Returns:
        str: ``prefix + separator + string`` when a prefix is given,
        otherwise ``string`` itself.
    """
    if prefix:
        return separator.join((prefix, string))
    # nothing to prepend: hand the input back untouched
    return string

29 

30 

class MatDAG:
    """
    DAG represented by its (possibly weighted) adjacency matrix.

    The matrix is stored as given in ``mat_adjacency``; a nonzero entry
    denotes an edge. Node names are either supplied by the caller or
    generated lazily as ``"v0", "v1", ...`` (with ``name_prefix`` prepended).
    """

    def __init__(
        self,
        mat_adjacency,
        name_prefix="",
        separator="_",
        list_node_names=None,
        parent_list_node_names=None,
        rng=None,
    ):
        """
        Args:
            mat_adjacency: square adjacency matrix (indexed as
                ``mat_adjacency[i, j]`` and queried via ``.nonzero()`` /
                ``.shape``).
            name_prefix (str): prefix used when node names are generated.
            separator (str): stored on the instance as ``self.separator``.
            list_node_names (list, optional): node names, one per matrix
                index. Generated on first access when omitted.
            parent_list_node_names (list, optional): node names of the graph
                this DAG was cut out of (set by ``subgraph``); used to label
                nodes with their index in the parent graph.
            rng (numpy.random.Generator, optional): random generator used by
                ``sample_node``. When omitted, a fresh ``default_rng(0)`` is
                created for this instance.
        """
        self._obj_gen_weight = None
        self.separator = separator
        self.name_prefix = name_prefix
        self.mat_adjacency = mat_adjacency
        self._list_node_names = list_node_names
        self._list_confounder = None
        self._dict_node_names2ind = {}
        self._parent_list_node_names = parent_list_node_names
        self._init_map()
        self._list_ind_nodes_sorted = None
        # Build the fallback generator here rather than in the signature:
        # a default evaluated at definition time would be one single
        # generator silently shared (and advanced) by every instance.
        self.rng = default_rng(0) if rng is None else rng

    def _init_map(self):
        """
        (re)build the node-name -> matrix-index lookup table
        """
        if self._list_node_names is not None:
            self._dict_node_names2ind = {
                name: i for (i, name) in enumerate(self._list_node_names)
            }

    @property
    def list_confounder(self):
        """
        return list of confounders, i.e. indices of the matrix columns
        holding more than one nonzero entry
        """
        nonzero_counts = np.count_nonzero(self.mat_adjacency, axis=0)
        columns_with_more_than_one = np.where(nonzero_counts > 1)[0]
        return list(columns_with_more_than_one)

    def gen_dict_ind2node_na(self, hierarch_na=False):
        """
        utility function to have {1:"node_name"} dictionary for plotting

        hierarch_na: if True, map each index to the node name itself.
        Otherwise map each index to the (stringified) position of the node
        name in the parent graph's name list when this DAG has one, else to
        its position in this DAG's own name list.
        """
        names = self.list_node_names
        if hierarch_na:
            return dict(enumerate(names))
        if self._parent_list_node_names is not None:
            parent_names = self._parent_list_node_names
            return {
                i: str(parent_names.index(name))
                for (i, name) in enumerate(names)
            }
        # FIXME: if self.list_node_names is list of string,
        # then it will be ignored, mdict will be {0:integer} style
        # always
        # A first-occurrence table gives the same answer as calling
        # names.index(name) for every node, without the quadratic scan.
        first_position = {}
        for i, name in enumerate(names):
            first_position.setdefault(name, i)
        return {i: str(first_position[name]) for (i, name) in enumerate(names)}

    def check(self):
        """
        check if the matrix represent a DAG

        Raises:
            RuntimeError: if either the weighted matrix or its binarised
                (nonzero -> 1) version fails ``is_dag``.
        """
        if not is_dag(self.mat_adjacency):
            raise RuntimeError("not a DAG")
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if not is_dag(binary_adj_mat):
            raise RuntimeError("not a DAG")

    @property
    def num_nodes(self):
        """
        number of nodes in DAG (first dimension of the adjacency matrix)
        """
        return self.mat_adjacency.shape[0]

    def gen_node_names(self):
        """
        generate node names "v0", "v1", ... (prefixed with name_prefix)
        and refresh the name -> index lookup
        """
        # NOTE(review): self.separator is not forwarded, so add_prefix falls
        # back to its own default separator -- confirm this is intended.
        self._list_node_names = [
            add_prefix(string="v" + str(i), prefix=self.name_prefix)
            for i in range(self.num_nodes)
        ]
        self._init_map()

    def gen_node_names_stacked(self, dict_macro_node2dag):
        """
        set the node names to the concatenation of the node names of every
        DAG in dict_macro_node2dag (in the dict's iteration order) and
        refresh the name -> index lookup
        """
        self._list_node_names = []
        for dag in dict_macro_node2dag.values():
            self._list_node_names.extend(dag.list_node_names)
        self._init_map()

    def get_node_ind(self, node_name):
        """
        get the matrix index of the node name

        Raises:
            KeyError: if node_name is not in the lookup table.
        """
        return self._dict_node_names2ind[node_name]

    @property
    def list_node_names(self):
        """
        get the node names in a linear list (generated on first access
        when none were supplied)
        """
        if self._list_node_names is None:
            self.gen_node_names()
        return self._list_node_names

    @property
    def list_arcs(self):
        """
        return the list of edges as (row node name, column node name)
        tuples, one per nonzero matrix entry
        """
        names = self.list_node_names
        ind_rows, ind_cols = self.mat_adjacency.nonzero()
        return [(names[i], names[j]) for (i, j) in zip(ind_rows, ind_cols)]

    def sample_node(self):
        """
        randomly chose a node

        Returns:
            tuple: (name_prefix + node name, matrix index of the node)
        """
        ind = self.rng.integers(0, self.mat_adjacency.shape[0])
        name = self.list_node_names[ind]
        # NOTE(review): generated names already carry name_prefix, so the
        # prefix shows up twice here -- confirm this is intended.
        return self.name_prefix + name, ind

    def add_arc_ind(self, ind_tail, ind_head, weight=None):
        """
        add arc via index of tail and head
        """
        node_tail = self.list_node_names[ind_tail]
        node_head = self.list_node_names[ind_head]
        self.add_arc(node_tail, node_head, weight)

    def add_arc(self, node_tail, node_head, weight=None):
        """
        add edge to adjacency matrix: writes weight (1 when weight is None)
        into entry [index of node_tail, index of node_head]

        Raises:
            KeyError: if either node name is unknown.
        """
        # NOTE(review): get_list_parents_inds reads parents from row
        # [ind_node, :]; confirm that orientation agrees with this write.
        ind_tail = self._dict_node_names2ind[node_tail]
        ind_head = self._dict_node_names2ind[node_head]
        self.mat_adjacency[ind_tail, ind_head] = 1 if weight is None else weight

    def to_binary_csv(self, benchpress=True, name="adj.csv"):
        """
        adjacency matrix to csv format

        The matrix is binarised (nonzero -> 1), transposed when benchpress
        is True, and written to file `name` with the node names as header
        and no index column.
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        if benchpress:
            binary_adj_mat = np.transpose(binary_adj_mat)
        df = pd.DataFrame(binary_adj_mat, columns=self.list_node_names)
        df.to_csv(name, index=False)

    def topological_sort(self):
        """
        topological sort DAG into list of node index; the result is cached
        on the instance and returned
        """
        binary_adj_mat = (self.mat_adjacency != 0).astype(int)
        self._list_ind_nodes_sorted = topological_sort(binary_adj_mat)
        return self._list_ind_nodes_sorted

    @property
    def list_ind_nodes_sorted(self):
        """
        get global node index topologically sorted (computed on first
        access, then served from the cache)
        """
        if self._list_ind_nodes_sorted is None:
            self.topological_sort()
        return self._list_ind_nodes_sorted

    def get_list_parents_inds(self, ind_node):
        """
        get list of parents nodes: column indices of the nonzero entries
        in row ind_node
        """
        # np.nonzero(x)
        # returns (array([0, 1, 2, 2]), array([0, 1, 0, 1]))
        # assume lower triangular matrix as adjacency matrix
        # matrix[i, j]=1 indicate arrow j->i
        submatrix = self.mat_adjacency[ind_node, :]
        vector = submatrix.flatten()
        list_inds = np.nonzero(vector)[0].tolist()
        return list_inds

    def get_weights_from_list_parents(self, ind_node):
        """
        get incoming edge weights: the entries of row ind_node at the
        parent indices returned by get_list_parents_inds
        """
        list_parents_inds = self.get_list_parents_inds(ind_node)
        sub_matrix = self.mat_adjacency[ind_node, list_parents_inds]
        return sub_matrix

    def __str__(self):
        return str(self.mat_adjacency)

    def __repr__(self):
        return str(self.mat_adjacency)

    def subgraph(self, list_ind_unobserved):
        """
        subset adjacency matrix by deleting unobserved variables

        Returns:
            MatDAG: new DAG over the remaining nodes. It keeps their names,
            records this DAG's names as parent names, and shares this
            DAG's rng.
        """
        # delete first axis
        temp_mat_row = np.delete(
            self.mat_adjacency, list_ind_unobserved, axis=0)

        # delete second axis
        mat_adj_subgraph = np.delete(
            temp_mat_row, list_ind_unobserved, axis=1)

        # filter out subgraph node names; a set makes each membership
        # test O(1) instead of scanning the index list for every node
        set_ind_unobserved = set(list_ind_unobserved)
        list_node_names_subgraph = [
            x
            for i, x in enumerate(self.list_node_names)
            if i not in set_ind_unobserved
        ]

        subdag = MatDAG(
            mat_adj_subgraph,
            list_node_names=list_node_names_subgraph,
            parent_list_node_names=self.list_node_names,
            rng=self.rng,
        )
        return subdag

    def visualize(self, title="dag", hierarch_na=False, ax=None,
                  graphviz=False):
        """
        draw dag using networkx

        Forwards the adjacency matrix, the index -> label dictionary from
        gen_dict_ind2node_na(hierarch_na), title, ax and graphviz to
        draw_dags_nx.
        """
        draw_dags_nx(
            self.mat_adjacency,
            dict_ind2name=self.gen_dict_ind2node_na(hierarch_na),
            title=title,
            ax=ax,
            graphviz=graphviz
        )

    @property
    def list_top_names(self):
        """
        return list of node names in topological order
        """
        list_top_names = [self.list_node_names[i] for i in
                          self.list_ind_nodes_sorted]
        return list_top_names

    def global_arbitrary_ind2topind(self, ind_global_arbitrary):
        """
        convert a matrix (global arbitrary) node index into its position
        in the topological order

        Raises:
            ValueError: if the index is not in the sorted list.
        """
        ind_top = self.list_ind_nodes_sorted.index(ind_global_arbitrary)
        return ind_top

    def top_ind2global_arbitrary(self, ind_top):
        """
        convert a position in the topological order into the matrix
        (global arbitrary) node index
        """
        ind_global_arbitrary = self.list_ind_nodes_sorted[ind_top]
        return ind_global_arbitrary

    def get_top_last(self):
        """
        matrix index of the node that comes last in the topological order
        """
        ind_global_arbitrary = self.list_ind_nodes_sorted[-1]
        return ind_global_arbitrary

    def climb(self, ind_arbitrary):
        """
        matrix index of the node one position earlier in the topological
        order than ind_arbitrary, or None if ind_arbitrary comes first
        """
        ind_top = self.global_arbitrary_ind2topind(ind_arbitrary)
        if ind_top - 1 >= 0:
            ind_arbi = self.top_ind2global_arbitrary(ind_top - 1)
            return ind_arbi
        return None

    @property
    def num_confounder(self):
        """
        number of confounders (length of list_confounder)
        """
        return len(self.list_confounder)

    @property
    def list_top_order_sorted_confounder(self):
        """
        positions of the confounders in the topological order, sorted
        ascendingly
        """
        list_top_oder_confounder = [
            self.list_ind_nodes_sorted.index(confounder)
            for confounder in self.list_confounder
        ]
        list_top_oder_confounder_sorted = sorted(list_top_oder_confounder)
        return list_top_oder_confounder_sorted