Coverage for causalspyne/utils_topological_sort.py: 85%
27 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
"""
Topological sort on an adjacency matrix.
"""
from collections import deque

import numpy as np

from causalspyne.is_dag import is_dag
def is_binary_matrix(matrix):
    """
    Check whether every entry of ``matrix`` is 0 or 1.

    Parameters
    ----------
    matrix : array_like
        Anything accepted by ``numpy.asarray`` (nested lists, ndarray, ...).

    Returns
    -------
    bool
        True if every entry compares equal to 0 or 1 (so ``True``/``False``
        and ``0.0``/``1.0`` also count as binary), False otherwise. An empty
        input is vacuously binary and yields True.
    """
    # asarray avoids copying when an ndarray is passed in
    arr = np.asarray(matrix)
    # bool() so callers get a plain Python bool instead of numpy.bool_
    return bool(np.all((arr == 0) | (arr == 1)))
def topological_sort(binary_adj_mat):
    """
    Return a topological order of the nodes of a DAG (Kahn's algorithm).

    Edge convention: ``binary_adj_mat[node_arrow_head][node_arrow_tail] != 0``
    means there is an edge ``node_arrow_tail -> node_arrow_head``. Row ``i``
    therefore marks the parents of node ``i``, so its count of nonzero
    entries is the in-degree of node ``i``, and column ``j`` marks the
    children of node ``j``. By default a lower triangular matrix is assumed
    (edges only from lower to higher indices), whose first row is empty so
    node 0 is a source; the algorithm itself works for any DAG.

    Source nodes (in-degree 0) are emitted first in ascending index order;
    nodes whose remaining in-degree drops to 0 are queued first-in
    first-out.

    Parameters
    ----------
    binary_adj_mat : array_like
        Square adjacency matrix whose entries are 0 or 1.

    Returns
    -------
    list of int
        Node indices in topological order. An empty matrix yields ``[]``.

    Raises
    ------
    RuntimeError
        If ``is_dag`` rejects the matrix, if the matrix is not binary, or if
        no source node exists.
    ValueError
        If not every node could be placed in the order.
    """
    if not is_dag(binary_adj_mat):
        raise RuntimeError("not a DAG!")
    if not is_binary_matrix(binary_adj_mat):
        raise RuntimeError("input matrix must only have 1, 0 for counting!")
    # Convert once so row/column access behaves the same for nested lists
    # and ndarrays (and column slicing below is available).
    adj = np.asarray(binary_adj_mat)
    num_nodes = len(adj)
    if num_nodes == 0:
        # An empty graph has the trivial (empty) topological order.
        return []
    # In-degree of node i = number of nonzero entries in row i (its parents).
    # Decremented as parents get emitted, hence "volatile".
    arr_node_in_degree_volatile = np.count_nonzero(adj, axis=1)
    # The queue initially only contains the source nodes, in ascending order.
    queue_src_node_inds = deque(
        node
        for node in range(num_nodes)
        if arr_node_in_degree_volatile[node] == 0
    )
    if not queue_src_node_inds:
        raise RuntimeError("no source nodes!")
    list_sorted_node_inds = []

    while queue_src_node_inds:
        node_src = queue_src_node_inds.popleft()
        list_sorted_node_inds.append(node_src)
        # Column node_src marks the children of node_src (node_src -> child);
        # flatnonzero yields them in ascending order, like a range() scan.
        for child in np.flatnonzero(adj[:, node_src]):
            arr_node_in_degree_volatile[child] -= 1
            if arr_node_in_degree_volatile[child] == 0:
                # int() keeps the returned order a list of Python ints
                queue_src_node_inds.append(int(child))

    if len(list_sorted_node_inds) != num_nodes:
        raise ValueError(
            "sorted nodes not completed!: "
            + str(list_sorted_node_inds)
            + str(arr_node_in_degree_volatile)
        )

    return list_sorted_node_inds