Coverage for src/causalspyne/dag2ancestral.py: 100%
56 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
"""
Turn a DAG into an ancestral graph given a list of variables to hide.
"""
5import itertools
6import copy
7import numpy as np
8from causalspyne.utils_closure import ancestor_matrix
def pairwise_combinations(lst):
    """
    Lazily produce every unordered pair of elements of ``lst``.

    Pairs come out in ``itertools.combinations`` order: each element is
    paired only with the elements after it, e.g.
    [a, b, c] -> (a, b), (a, c), (b, c).
    """
    pair_size = 2
    pairs = itertools.combinations(lst, pair_size)
    return pairs
def to_binary(matrix):
    """
    Map every non-zero entry of ``matrix`` to 1 and every zero entry to 0.

    The element-wise comparison yields a boolean mask, which is then cast to
    the platform's default integer dtype.
    """
    nonzero_mask = matrix != 0
    return nonzero_mask.astype(int)
class DAG2Ancestral:
    """
    Turn a DAG into an ancestral graph over the variables that remain after
    hiding a given list of variables.

    Adjacency convention used throughout: ``adj[i, j] != 0`` means an arrow
    from node ``j`` to node ``i``. In the produced matrix a directed edge
    ``u -> v`` is a 1 at ``[v, u]``; a bidirected edge ``u <-> v`` is a 1 at
    both ``[u, v]`` and ``[v, u]``.
    """

    def __init__(self, adj):
        # pristine copy of the input DAG; all neighbourhood queries read it
        self.old_adj = copy.deepcopy(adj)
        # working copy that receives the extra edges and is finally shrunk
        self.mat4ancestral = copy.deepcopy(self.old_adj)
        # filled by pre_cal_n_hop(); see is_ancestor() for how it is indexed
        self.bmat_ancestor = None

    def pre_cal_n_hop(self):
        """
        Precompute the ancestor matrix of the original DAG so that
        is_ancestor() is a plain matrix lookup.
        """
        self.bmat_ancestor = ancestor_matrix(self.old_adj)

    def run(self, list_hidden):
        """
        Convert the DAG into an ancestral graph over the non-hidden nodes.

        Args:
            list_hidden: indices (into the original adjacency matrix) of the
                nodes to hide; any iterable of ints is accepted.

        Returns:
            0/1 integer matrix over the remaining nodes (kept in their
            original relative order), using the same ``adj[i, j]`` means
            ``j -> i`` convention. The un-binarised reduced matrix is also
            stored in ``self.mat4ancestral``.
        """
        list_hidden = list(list_hidden)
        set_hidden = set(list_hidden)
        # start from the original graph each time so run() can be re-called
        self.mat4ancestral = copy.deepcopy(self.old_adj)
        self.pre_cal_n_hop()
        for hidden in list_hidden:
            # passing the whole hidden set lets edges propagate through chains
            # of hidden nodes (e.g. P -> H1 -> H2 -> C yields P -> C), which
            # would otherwise vanish when the intermediates are deleted below
            self.deal_children(hidden, set_hidden)
            self.deal_parent(hidden, set_hidden)
        # delete the hidden rows, then the hidden columns
        temp_mat_row = np.delete(self.mat4ancestral, list_hidden, axis=0)
        mat_adj_subgraph = np.delete(temp_mat_row, list_hidden, axis=1)
        self.mat4ancestral = mat_adj_subgraph
        return to_binary(self.mat4ancestral)

    def is_ancestor(self, global_ind_node_1, global_ind_node_2):
        """
        Check whether ``global_ind_node_1`` is an ancestor of
        ``global_ind_node_2`` in the original DAG.

        Reads ``bmat_ancestor[node_2, node_1]``, i.e. assumes the ancestor
        matrix follows the same row = target, column = source convention as
        the adjacency matrix (TODO confirm against utils_closure). Requires
        pre_cal_n_hop() (or run()) to have been called first.
        """
        flag = self.bmat_ancestor[global_ind_node_2, global_ind_node_1]
        return flag

    def deal_parent(self, hidden, set_hidden=None):
        """
        Add a directed edge from every parent of ``hidden`` to every child of
        ``hidden``.

        Args:
            hidden: index of the hidden node being marginalised.
            set_hidden: optional collection of all hidden node indices. When
                given, "parents"/"children" are the non-hidden nodes joined to
                ``hidden`` by directed paths whose intermediate nodes are all
                hidden. When omitted, the direct parents/children in the
                original DAG are used.
        """
        list_parents = self._visible_neighbours(
            hidden, set_hidden, forward=False)
        list_children = self._visible_neighbours(
            hidden, set_hidden, forward=True)
        for global_ind_parent in list_parents:
            for global_ind_child in list_children:
                # mat[i, j] means an edge from j to i
                self.mat4ancestral[global_ind_child, global_ind_parent] = 1

    def deal_children(self, hidden, set_hidden=None):
        """
        Link every pair of children of ``hidden``.

        For each pair (c1, c2): if c1 is an ancestor of c2 in the original
        DAG add c1 -> c2; else if c2 is an ancestor of c1 add c2 -> c1;
        otherwise add the bidirected edge c1 <-> c2 (both entries set).

        Args:
            hidden: index of the hidden node being marginalised.
            set_hidden: optional collection of all hidden node indices; see
                deal_parent() for how it widens the set of children.
        """
        list_children = self._visible_neighbours(
            hidden, set_hidden, forward=True)
        for c1_global_ind, c2_global_ind in itertools.combinations(
                list_children, 2):
            if self.is_ancestor(c1_global_ind, c2_global_ind):
                # mat[i, j] means an edge from j to i
                self.mat4ancestral[c2_global_ind, c1_global_ind] = 1
            elif self.is_ancestor(c2_global_ind, c1_global_ind):
                self.mat4ancestral[c1_global_ind, c2_global_ind] = 1
            else:
                self.mat4ancestral[c1_global_ind, c2_global_ind] = 1
                self.mat4ancestral[c2_global_ind, c1_global_ind] = 1

    def _visible_neighbours(self, hidden, set_hidden, forward):
        """
        Return the sorted non-hidden nodes joined to ``hidden`` by a directed
        path whose intermediate nodes are all hidden.

        ``forward=True`` follows arrows out of ``hidden`` (children side);
        ``forward=False`` follows arrows into it (parents side). With
        ``set_hidden=None`` only ``hidden`` itself counts as hidden, which
        reduces to its direct children/parents in the original DAG.
        """
        if set_hidden is None:
            set_hidden = {hidden}
        step = self.get_list_children if forward else self.get_list_parents
        visited = {hidden}
        stack = [hidden]
        found = set()
        while stack:
            node = stack.pop()
            for nbr in step(node):
                if nbr in visited:
                    continue
                visited.add(nbr)
                if nbr in set_hidden:
                    # keep walking through hidden intermediates
                    stack.append(nbr)
                else:
                    # a visible node ends the path; do not walk past it
                    found.add(nbr)
        return sorted(found)

    def get_list_children(self, hidden):
        """
        Return the direct children of node ``hidden`` in the original DAG as
        an ascending list of ints.

        adj[i, j] indicates an arrow from j to i, so the children are the
        non-zero rows of column ``hidden``.
        """
        return np.flatnonzero(self.old_adj[:, hidden]).tolist()

    def get_list_parents(self, hidden):
        """
        Return the direct parents of node ``hidden`` in the original DAG as
        an ascending list of ints.

        adj[i, j] indicates an arrow from j to i, so the parents are the
        non-zero columns of row ``hidden``.
        """
        return np.flatnonzero(self.old_adj[hidden, :]).tolist()