Coverage for src/causalspyne/dag2ancestral.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2Turn a DAG into an ancestral graph, given a list of variables to hide.

3""" 

4 

5import itertools 

6import copy 

7import numpy as np 

8from causalspyne.utils_closure import ancestor_matrix 

9 

10 

def pairwise_combinations(lst):
    """
    Return an iterator over all unordered pairs of elements of ``lst``.

    Pairs are 2-tuples in ``itertools.combinations`` order, e.g.
    ``[a, b, c]`` yields ``(a, b), (a, c), (b, c)``.
    """
    pair_size = 2
    return itertools.combinations(lst, r=pair_size)

16 

17 

def to_binary(matrix):
    """
    Map every non-zero entry of ``matrix`` to 1 and every zero entry to 0.

    :param matrix: array supporting elementwise ``!=`` (e.g. a numpy array).
    :return: integer array of the same shape.
    """
    is_nonzero = matrix != 0
    return is_nonzero.astype(int)

24 

25 

class DAG2Ancestral:
    """
    Turn a DAG into an ancestral graph given a list of variables to hide.

    Adjacency convention: ``adj[i, j] != 0`` means an arrow ``j -> i``.
    In the matrix produced by ``run`` a directed edge ``a -> b`` sets only
    entry ``[b, a]``; an edge between two nodes with no ancestral relation
    (``a <-> b``) sets both ``[a, b]`` and ``[b, a]``.
    """

    def __init__(self, adj):
        # private copy so the caller's matrix is never mutated
        self.old_adj = copy.deepcopy(adj)
        # working matrix, edited in place by deal_children / deal_parent
        self.mat4ancestral = copy.deepcopy(self.old_adj)
        # ancestor relation of the full DAG, filled by pre_cal_n_hop
        self.bmat_ancestor = None

    def pre_cal_n_hop(self):
        """
        Compute the ancestor relation of the full DAG via ``ancestor_matrix``.
        """
        self.bmat_ancestor = ancestor_matrix(self.old_adj)

    def run(self, list_hidden):
        """
        Convert the DAG to an ancestral graph over the non-hidden nodes.

        For every hidden node ``h``, the observed nodes reaching ``h`` along
        directed paths whose intermediate nodes are all hidden (its
        "observed parents") get an arrow into every observed node ``h``
        reaches the same way (its "observed children"), and each pair of
        observed children is linked as in ``deal_children``. Walking through
        hidden intermediates means e.g. ``p -> h1 -> h2 -> c`` with ``h1``
        and ``h2`` hidden still yields ``p -> c``.

        :param list_hidden: indices (into the original adjacency matrix)
            of the nodes to hide.
        :return: 0/1 integer matrix over the remaining nodes, kept in their
            original relative order.
        """
        list_hidden = list(list_hidden)
        # start from the pristine DAG so run() can safely be called again
        self.mat4ancestral = copy.deepcopy(self.old_adj)
        if self.bmat_ancestor is None:
            # old_adj is never modified after __init__, so the ancestor
            # relation only needs to be computed once
            self.pre_cal_n_hop()
        set_hidden = set(list_hidden)
        for hidden in list_hidden:
            obs_children = self._observed_through_hidden(
                hidden, set_hidden, self.get_list_children)
            obs_parents = self._observed_through_hidden(
                hidden, set_hidden, self.get_list_parents)
            self.deal_children(hidden, obs_children)
            self.deal_parent(hidden, obs_parents, obs_children)
        # delete first axis
        temp_mat_row = np.delete(self.mat4ancestral, list_hidden, axis=0)
        # delete second axis
        mat_adj_subgraph = np.delete(temp_mat_row, list_hidden, axis=1)
        self.mat4ancestral = mat_adj_subgraph
        # same conversion as the module-level to_binary helper
        return (self.mat4ancestral != 0).astype(int)

    def _observed_through_hidden(self, start, set_hidden, next_nodes):
        """
        Collect non-hidden nodes reachable from ``start`` by repeatedly
        applying ``next_nodes``, passing only through hidden nodes.

        ``next_nodes`` is ``get_list_children`` (walk downwards) or
        ``get_list_parents`` (walk upwards). The result is sorted so that
        the output is deterministic.
        """
        found = set()
        visited = {start}
        stack = list(next_nodes(start))
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            if node in set_hidden:
                # hidden intermediate: keep walking through it
                stack.extend(next_nodes(node))
            else:
                found.add(node)
        return sorted(found)

    def is_ancestor(self, global_ind_node_1, global_ind_node_2):
        """
        Check if the first argument is an ancestor of the second.

        Reads ``bmat_ancestor[node_2, node_1]``; this assumes
        ``ancestor_matrix`` follows the same "row = descendant,
        column = ancestor" convention as the adjacency matrix.
        """
        flag = self.bmat_ancestor[global_ind_node_2, global_ind_node_1]
        return flag

    def deal_parent(self, hidden, list_parents=None, list_children=None):
        """
        Add an arrow from every parent of ``hidden`` to every child of it.

        :param hidden: index of the hidden node.
        :param list_parents: parents to connect; defaults to the direct
            parents of ``hidden`` in the original DAG.
        :param list_children: children to connect; defaults to the direct
            children of ``hidden`` in the original DAG.
        """
        if list_parents is None:
            list_parents = self.get_list_parents(hidden)
        if list_children is None:
            list_children = self.get_list_children(hidden)
        for global_ind_parent, global_ind_child in itertools.product(
                list_parents, list_children):
            # mat[i, j] means edge from j to i
            self.mat4ancestral[global_ind_child, global_ind_parent] = 1

    def deal_children(self, hidden, list_children=None):
        """
        Link every pair of children of ``hidden``.

        For a pair (c1, c2): if c1 is an ancestor of c2 in the full DAG,
        add c1 -> c2; if c2 is an ancestor of c1, add c2 -> c1; otherwise
        set both entries (c1 <-> c2).

        :param hidden: index of the hidden node.
        :param list_children: children to link; defaults to the direct
            children of ``hidden`` in the original DAG.
        """
        if list_children is None:
            list_children = self.get_list_children(hidden)
        for c1_global_ind, c2_global_ind in itertools.combinations(
                list_children, 2):
            if self.is_ancestor(c1_global_ind, c2_global_ind):
                # mat[i, j] means edge from j to i
                self.mat4ancestral[c2_global_ind, c1_global_ind] = 1
            elif self.is_ancestor(c2_global_ind, c1_global_ind):
                self.mat4ancestral[c1_global_ind, c2_global_ind] = 1
            else:
                # no ancestral relation: symmetric entries
                self.mat4ancestral[c1_global_ind, c2_global_ind] = 1
                self.mat4ancestral[c2_global_ind, c1_global_ind] = 1

    def get_list_children(self, hidden):
        """
        Return the direct children of ``hidden`` in the original DAG.

        adj[i, j] indicates an arrow from j to i, so the children are the
        non-zero rows of column ``hidden``.
        """
        return np.flatnonzero(self.old_adj[:, hidden]).tolist()

    def get_list_parents(self, hidden):
        """
        Return the direct parents of ``hidden`` in the original DAG.

        adj[i, j] indicates an arrow from j to i, so the parents are the
        non-zero columns of row ``hidden``.
        """
        return np.flatnonzero(self.old_adj[hidden, :]).tolist()