Coverage for causalspyne/dag_viewer.py: 84%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2create different views for the same DAG by hiding some variables 

3""" 

4 

5import warnings 

6 

7import numpy as np 

8from numpy.random import default_rng 

9import pandas as pd 

10 

11from causalspyne.data_gen import DataGen 

12from causalspyne.dag_interface import MatDAG 

13 

14 

def process_list2hide(list_ind_or_percentage, total_num):
    """
    Normalize a specification of variables to hide into integer indices.

    Args:
        list_ind_or_percentage: a single entry or a list-like of entries.
            Each entry is either an int index or a float, which is
            interpreted as a fraction of ``total_num``: it is mapped to
            ``int(ele * total_num)``, clamped to ``total_num - 1``.
        total_num: number of candidate variables the indices refer to.

    Returns:
        list: sorted list of unique indices (empty input gives ``[]``).

    Raises:
        RuntimeError: if more entries are given than there are variables,
            or if an index is out of range (``>= total_num``).
    """
    # a bare scalar (e.g. 0.5 or 3) is treated as a one-element list
    if np.ndim(list_ind_or_percentage) == 0:
        list_ind_or_percentage = [list_ind_or_percentage]

    if len(list_ind_or_percentage) > total_num:
        raise RuntimeError(
            f"cannot hide {len(list_ind_or_percentage)} variables "
            f"{list_ind_or_percentage}: only {total_num} available"
        )

    # sorted() gives a canonical order; a plain set's iteration order is an
    # implementation detail that would otherwise leak into output names
    list_ind = sorted({
        min(int(ele * total_num), total_num - 1)
        if isinstance(ele, float) else ele
        for ele in list_ind_or_percentage
    })

    # valid indices are 0 .. total_num - 1, hence ">=" rather than ">"
    if list_ind and max(list_ind) >= total_num:
        raise RuntimeError(
            f"max value in {list_ind_or_percentage} is out of range for "
            f"{total_num} variables to hide"
        )

    return list_ind

40 

41 

class DAGView:
    """
    View of a ground-truth DAG with some variables hidden (marginalized).

    The ground-truth DAG is kept intact and is used to generate data; the
    view exposes the sub-DAG over the remaining variables together with the
    corresponding columns of the generated data.
    """

    def __init__(self, dag, rng=None):
        """
        Args:
            dag: ground-truth DAG; passed to ``DataGen`` for sampling.
            rng: numpy random generator used for data generation. When
                omitted, a fresh ``default_rng(0)`` is created for this
                instance.
        """
        if rng is None:
            # built per instance: a generator created in the signature would
            # be shared by every DAGView, so the data of one view would depend
            # on how many draws earlier views had made
            rng = default_rng(0)
        self._dag = dag
        # there is no need to use a full DAG to represent the sub-DAG
        # since the sub-DAG is not responsible for data generation
        self._sub_dag = None
        self._data_arr = None
        self._subset_data_arr = None
        self._list_global_inds_unobserved = None
        self._list_global_inds_observed = None
        self.data_gen = DataGen(self._dag, rng=rng)
        self._list_nodes2hide = None
        self._success = False

    @property
    def dag(self):
        """The ground-truth DAG this view was built from."""
        return self._dag

    def run(self, num_samples, list_nodes2hide=None, confound=False):
        """
        Generate data from the full DAG, then hide the requested variables.

        Args:
            num_samples: number of samples passed to ``DataGen.gen``.
            list_nodes2hide: indices/fractions of the variables to hide
                (see ``process_list2hide``); defaults to ``[0]``.
            confound: if True, the entries select among the topologically
                sorted confounders (``hide_confounder``); otherwise among all
                nodes in topological order (``hide_top_order``).
        """
        self._data_arr = self.data_gen.gen(num_samples)
        if list_nodes2hide is None:
            list_nodes2hide = [0]
        if confound:
            self.hide_confounder(list_nodes2hide)
        else:
            self.hide_top_order(list_nodes2hide)

    def hide_confounder(self, list_toporder_confounder2hide_input):
        """
        Hide confounders chosen by position in the sorted confounder list.

        The input is normalized with ``process_list2hide`` against the number
        of confounders. The selected entries of
        ``dag.list_top_order_sorted_confounder`` are then passed to
        ``hide_top_order``.

        Returns:
            True once the view has been built.

        Raises:
            RuntimeError: if the DAG has no confounders, or if the input
                indices are invalid.
        """
        if not self._dag.list_confounder:
            raise RuntimeError(
                f"there are no confounders in the graph {self._dag}!"
            )

        list_toporder_confounder2hide = process_list2hide(
            list_toporder_confounder2hide_input, len(self._dag.list_confounder)
        )

        list_ind_confounder_sorted = self._dag.list_top_order_sorted_confounder

        list_toporder_confounder_sub = [
            list_ind_confounder_sorted[i]
            for i in list_toporder_confounder2hide
        ]

        self.hide_top_order(list_toporder_confounder_sub)
        return True

    def hide_top_order(self, list_toporder_unobserved):
        """
        Hide variables given by their positions in the topological order.

        Builds the sub-DAG and drops the hidden columns from the generated
        data. Requires ``run`` (or an equivalent) to have set
        ``self._data_arr`` first.
        """
        list_toporder_unobserved = process_list2hide(
            list_toporder_unobserved, self._dag.num_nodes
        )

        # names of the hidden nodes (used e.g. in output file names)
        self._list_nodes2hide = [
            self._dag.list_top_names[i] for i in list_toporder_unobserved
        ]

        # FIXME: change to logger
        print("nodes to hide " + str(self._list_nodes2hide))

        # translate topological positions into global (data column) indices
        self._list_global_inds_unobserved = [
            self._dag.list_ind_nodes_sorted[ind_top_order]
            for ind_top_order in list_toporder_unobserved
        ]
        # drop the cached observed indices; otherwise a second call would
        # keep reporting the observed set of the first call
        self._list_global_inds_observed = None

        self._sub_dag = self._dag.subgraph(self._list_global_inds_unobserved)
        self._subset_data_arr = np.delete(
            self._data_arr, self._list_global_inds_unobserved, axis=1
        )
        self._success = True

    @property
    def data(self):
        """
        Return the data of the observed variables as a numpy array.
        """
        return self._subset_data_arr

    @property
    def mat_adj(self):
        """
        Return the adjacency matrix of the sub-DAG.
        """
        return self._sub_dag.mat_adjacency

    def check_if_subview_done(self):
        """
        Check whether the DAG marginalization (sub-view) has been built.

        Emits a warning if it has not.

        Returns:
            bool: True if a sub-view is available, False otherwise.
        """
        if not self._success:
            warnings.warn("no subview of DAG available, exit now!")
            return False
        return True

    @property
    def node_names(self):
        """
        Column names ``"X<k>"`` for the observed variables.

        ``k`` is the position of each sub-DAG node name within the parent
        DAG's node-name list (``_parent_list_node_names``).
        """
        _node_names = self._sub_dag.list_node_names
        parent_names = self._sub_dag._parent_list_node_names
        return ["X" + str(parent_names.index(name)) for name in _node_names]

    def to_csv(self, title="data_subdag.csv"):
        """
        Write the observed data and the sub-DAG adjacency matrix to disk.

        The data file name is ``title`` with ``"_" + str_node2hide``
        inserted before the file extension. The adjacency matrix is written
        by ``MatDAG(self.mat_adj).to_binary_csv()`` using that method's own
        defaults. If no sub-view exists yet, a warning is emitted and
        nothing is written.
        """
        import os

        if not self.check_if_subview_done():
            return

        # FIXME: ensure self.node_names are consistent with self.data
        df = pd.DataFrame(self.data, columns=self.node_names)
        # splitext instead of title[:-4] / title[-4:], so that titles whose
        # extension is not exactly 4 characters are handled correctly
        stem, ext = os.path.splitext(title)
        df.to_csv(stem + "_" + self.str_node2hide + ext, index=False)

        subdag = MatDAG(self.mat_adj)
        subdag.to_binary_csv()

    @property
    def list_global_inds_nodes2hide(self):
        """Global indices of the hidden nodes (None before hiding)."""
        return self._list_global_inds_unobserved

    @property
    def col_inds(self):
        """
        Global indices of the observed (non-hidden) nodes.

        Nodes are listed in the order of ``dag.list_node_names``.

        Raises:
            RuntimeError: if no nodes to hide have been set yet.
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        # compare whole names: a substring test against the joined string
        # (e.g. "X1" in "X10_ind_1") would wrongly drop observed nodes
        hidden = set(self._list_nodes2hide)
        return [
            self.dag._dict_node_names2ind[name]
            for name in self.dag.list_node_names
            if name not in hidden
        ]

    @property
    def list_global_inds_observed(self):
        """
        Global indices of the observed nodes, in ``list_ind_nodes_sorted``
        order. Computed lazily and cached until the next ``hide_top_order``.

        Raises:
            RuntimeError: if no variables have been hidden yet.
        """
        if self._list_global_inds_observed is None:
            if self._list_global_inds_unobserved is None:
                raise RuntimeError(
                    "global inds for unobserved not initialized yet!")
            self._list_global_inds_observed = \
                [item for item in self._dag.list_ind_nodes_sorted
                 if item not in self._list_global_inds_unobserved]
        return self._list_global_inds_observed

    @property
    def str_node2hide(self):
        """
        String representation of the nodes to hide (marginalize).

        The format is ``"<names joined by _>_ind_<global indices joined by _>"``.
        """
        self.check_if_subview_done()
        if self._list_nodes2hide is None:
            raise RuntimeError("self._list_node2hide is None!")
        _str_node2hide = "_".join(map(str, self._list_nodes2hide))

        _str_node2hide_ind = "_".join(map(str, self._list_global_inds_unobserved))

        return "_ind_".join([_str_node2hide, _str_node2hide_ind])

    def visualize(self, **kwargs):
        """
        Plot the sub-DAG; keyword arguments are forwarded to its visualize.
        """
        if not self._success:
            warnings.warn("no subview of DAG available")
            return
        self._sub_dag.visualize(**kwargs)