Coverage for src/causalspyne/dag_viewer.py: 84%

100 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2create different views for the same DAG by hiding some variables 

3""" 

4 

5import warnings 

6 

7import numpy as np 

8import pandas as pd 

9 

10from causalspyne.data_gen import DataGen 

11from causalspyne.dag_interface import MatDAG 

12from causalspyne.utils_random import coerce_rng 

13 

14 

def process_list2hide(list_ind_or_percentage, total_num):
    """
    Normalize a specification of variables to hide into a list of indices.

    list_ind_or_percentage can either be a list or a scalar (int or float).
    Every float element ``ele`` is converted to the index
    ``min(int(ele * total_num), total_num - 1)``; every other element is
    taken as an index unchanged. Duplicated indices are removed.

    Args:
        list_ind_or_percentage: list of indices and/or float fractions, or a
            single int/float scalar.
        total_num: number of candidates the indices refer to.

    Returns:
        list of unique indices (empty when the input is empty).

    Raises:
        RuntimeError: if more entries than ``total_num`` are given, or if a
            resulting index is not smaller than ``total_num``.
    """
    # the docstring promises scalar support, but len() needs a sequence
    if isinstance(list_ind_or_percentage, (int, float)):
        list_ind_or_percentage = [list_ind_or_percentage]

    if len(list_ind_or_percentage) > total_num:
        # adjacent f-strings instead of a backslash continuation, which
        # would embed the source indentation into the message
        raise RuntimeError(
            f"there are {total_num} variables to hide, less "
            f"than the length of {list_ind_or_percentage}"
        )

    list_ind = [
        min(int(ele * total_num), total_num - 1)
        if isinstance(ele, float) else ele
        for ele in list_ind_or_percentage
    ]

    list_ind = list(set(list_ind))

    # nothing to hide: avoid max() failing on an empty sequence
    if not list_ind:
        return list_ind

    # valid indices are 0 .. total_num - 1, so total_num itself is too big
    if max(list_ind) >= total_num:
        raise RuntimeError(
            f"max value in {list_ind_or_percentage} is bigger "
            f"than total number of variables {total_num} to hide"
        )

    return list_ind

40 

41 

class DAGView:
    """
    with ground truth DAG intact, only show subgraph

    The full DAG is used to generate data; a view is obtained by hiding a
    subset of its nodes, which yields a sub-DAG and the data array with the
    corresponding columns removed.
    """

    def __init__(self, dag, dft_noise, rng=None):
        """
        Args:
            dag: the ground truth DAG; it is handed to DataGen and queried
                for node names, topological order and confounders.
            dft_noise: forwarded to DataGen as ``dft_noise``.
            rng: passed through ``coerce_rng(rng, seed=0)`` before being
                handed to DataGen.
        """
        rng = coerce_rng(rng, seed=0)
        self._dag = dag
        # there is no need to use a full DAG to represent subdag
        # since sub-dag is not responsible for data generation
        self._sub_dag = None
        self._data_arr = None
        self._subset_data_arr = None
        self._list_global_inds_unobserved = None
        self._list_global_inds_observed = None
        self.data_gen = DataGen(dag=self._dag,
                                dft_noise=dft_noise,
                                rng=rng)
        self._list_nodes2hide = None
        self._success = False

    @property
    def dag(self):
        """
        return the full (ground truth) DAG
        """
        return self._dag

    def run(self, num_samples, list_nodes2hide=None, confound=False):
        """
        generate subgraph adjacency matrix and corresponding data

        Args:
            num_samples: passed to ``self.data_gen.gen``.
            list_nodes2hide: indices (or float fractions) of nodes to hide,
                defaults to ``[0]``.
            confound: if True, the indices refer to the topologically sorted
                confounders, otherwise to the topological order of all nodes.
        """
        self._data_arr = self.data_gen.gen(num_samples)
        if list_nodes2hide is None:
            list_nodes2hide = [0]
        if confound:
            self.hide_confounder(list_nodes2hide)
        else:
            self.hide_top_order(list_nodes2hide)

    def hide_confounder(self, list_toporder_confounder2hide_input):
        """
        given a list of index, hide the confounder according to the
        topological order provided by the input index
        list_toporder_confounder2hide then call self.hide_top_order

        Returns:
            True once the confounders have been hidden.

        Raises:
            RuntimeError: if the DAG has no confounders.
        """
        if not self._dag.list_confounder:
            raise RuntimeError(
                f"there are no confounders in the graph {self._dag} \
                !"
            )

        list_toporder_confounder2hide = process_list2hide(
            list_toporder_confounder2hide_input, len(self._dag.list_confounder)
        )

        list_ind_confounder_sorted = self._dag.list_top_order_sorted_confounder

        list_toporder_confounder_sub = [
            list_ind_confounder_sorted[i]
            for i in list_toporder_confounder2hide
        ]

        self.hide_top_order(list_toporder_confounder_sub)
        return True

    def hide_top_order(self, list_toporder_unobserved):
        """
        hide variables according to a list of global index of topological sort

        Sets the hidden node names, the global indices of the hidden nodes,
        the sub-DAG and the data array with the hidden columns deleted.
        """
        list_toporder_unobserved = process_list2hide(
            list_toporder_unobserved, self._dag.num_nodes
        )

        # subset list
        self._list_nodes2hide = [
            self._dag.list_top_names[i] for i in list_toporder_unobserved
        ]

        # FIXME: change to logger
        print("nodes to hide " + str(self._list_nodes2hide))

        self._list_global_inds_unobserved = [
            self._dag.list_ind_nodes_sorted[ind_top_order]
            for ind_top_order in list_toporder_unobserved
        ]
        # invalidate the cached observed indices: they were derived from the
        # previous set of hidden nodes and would be stale after a re-hide
        self._list_global_inds_observed = None

        self._sub_dag = self._dag.subgraph(self._list_global_inds_unobserved)
        self._subset_data_arr = np.delete(
            self._data_arr, self._list_global_inds_unobserved, axis=1
        )
        self._success = True

    @property
    def data(self):
        """
        return data in numpy array format
        """
        return self._subset_data_arr

    @property
    def mat_adj(self):
        """
        return adj matrix
        """
        return self._sub_dag.mat_adjacency

    def check_if_subview_done(self):
        """
        check if DAG marginalization was successful or not

        Returns:
            True if a subview is available, False (after emitting a warning)
            otherwise, so that callers can actually bail out.
        """
        if not self._success:
            warnings.warn("no subview of DAG available, exit now!")
            return False
        return True

    @property
    def node_names(self):
        """
        return labels "X<i>" for the nodes of the sub-DAG, where i is the
        position of the node name in the sub-DAG's _parent_list_node_names
        """
        # filter out observed variable
        parent_names = self._sub_dag._parent_list_node_names
        return ["X" + str(parent_names.index(name))
                for name in self._sub_dag.list_node_names]

    def to_csv(self, title="data_subdag.csv"):
        """
        sub dataframe to csv

        The string representation of the hidden nodes is inserted between
        the stem and the extension of ``title``. Nothing is written if no
        subview is available.
        """
        if not self.check_if_subview_done():
            return

        # local import keeps this block self-contained
        from os.path import splitext

        # split on the real extension instead of assuming 4 characters
        stem, ext = splitext(title)

        # FIXME: ensure self.node_names are consistent with self.data
        df = pd.DataFrame(self.data, columns=self.node_names)
        df.to_csv(stem + "_" + self.str_node2hide + ext,
                  index=False)

        subdag = MatDAG(self.mat_adj)
        subdag.to_binary_csv()

    @property
    def list_global_inds_nodes2hide(self):
        """
        return global indices of the hidden nodes
        """
        return self._list_global_inds_unobserved

    @property
    def col_inds(self):
        """
        return global indices of the nodes which are not hidden

        Raises:
            RuntimeError: if no nodes have been hidden yet.
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        # exact name membership: a substring test against the joined string
        # would e.g. also drop "X1" when only "X10" is hidden
        hidden_names = set(self._list_nodes2hide)
        subview_global_inds = \
            [self.dag._dict_node_names2ind[name]
             for name in self.dag.list_node_names
             if name not in hidden_names]
        return subview_global_inds

    @property
    def list_global_inds_observed(self):
        """
        return global indices of observed nodes, in the order of
        self._dag.list_ind_nodes_sorted (computed lazily and cached)

        Raises:
            RuntimeError: if no nodes have been hidden yet.
        """
        if self._list_global_inds_observed is None:
            if self._list_global_inds_unobserved is None:
                raise RuntimeError(
                    "global inds for unobserved not initialized yet!")
            unobserved = set(self._list_global_inds_unobserved)
            self._list_global_inds_observed = \
                [item for item in self._dag.list_ind_nodes_sorted
                 if item not in unobserved]
        return self._list_global_inds_observed

    @property
    def str_node2hide(self):
        """
        string representation of nodes to hide(marginalize):
        names joined by "_", then "_ind_", then global indices joined by "_"

        Raises:
            RuntimeError: if no nodes have been hidden yet.
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        _str_node2hide = "_".join(map(str, self._list_nodes2hide))

        _str_node2hide_ind = "_".join(
            map(str, self._list_global_inds_unobserved))

        return "_ind_".join([_str_node2hide, _str_node2hide_ind])

    def visualize(self, **kwargs):
        """
        plot DAG; kwargs are forwarded to the sub-DAG's visualize
        """
        if not self._success:
            warnings.warn("no subview of DAG available")
            return
        self._sub_dag.visualize(**kwargs)