Coverage for causalspyne/utils_topological_sort.py: 85%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2topological sort on adjacency matrix 

3""" 

4 

5import numpy as np 

6 

7from causalspyne.is_dag import is_dag 

8 

9 

def is_binary_matrix(matrix):
    """
    Check whether every entry of ``matrix`` equals 0 or 1.

    Parameters
    ----------
    matrix : array-like
        Anything ``np.asarray`` accepts (nested lists, ndarray, ...).

    Returns
    -------
    numpy.bool_
        True when all elements compare equal to 0 or 1 (so ``True``/``False``
        and ``0.0``/``1.0`` also qualify); an empty input yields True because
        there is no offending element.
    """
    # asarray only converts when the input is not already an ndarray, so an
    # ndarray argument is inspected in place instead of being copied.
    arr = np.asarray(matrix)

    # Element-wise test: each entry must be either 0 or 1.
    return np.all((arr == 0) | (arr == 1))

19 

20 

def topological_sort(binary_adj_mat):
    """
    Return the node indices of a DAG in topological order (Kahn's algorithm).

    Edge convention: ``binary_adj_mat[node_arrow_head][node_arrow_tail] != 0``
    means there is an edge from ``node_arrow_tail`` to ``node_arrow_head``.
    Row ``i`` therefore lists the edges pointing *into* node ``i``, and a row
    of all zeros marks a source node. For a lower triangular matrix the first
    row is such a source, since no other node points into it.

    Parameters
    ----------
    binary_adj_mat : array-like
        Adjacency matrix containing only 0 and 1. It is indexed as
        ``[row][col]`` with both indices in ``range(len(binary_adj_mat))``,
        so it is assumed to be square -- NOTE(review): squareness is not
        checked here; confirm ``is_dag`` enforces it.

    Returns
    -------
    list of int
        Node indices ordered so that every node appears after all nodes
        that point into it. Nodes become eligible in FIFO order.

    Raises
    ------
    RuntimeError
        If ``is_dag`` rejects the matrix, if the matrix has entries other
        than 0/1, or if no node has in-degree zero.
    ValueError
        If the traversal finishes without having emitted every node.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import section.
    from collections import deque

    if not is_dag(binary_adj_mat):
        raise RuntimeError("not a DAG!")
    if not is_binary_matrix(binary_adj_mat):
        raise RuntimeError("input matrix must only have 1, 0 for counting!")
    num_nodes = len(binary_adj_mat)
    # In-degree of node i is the sum of row i, hence axis=1, e.g.
    #   np.sum([[0, 0], [1, 0]], axis=1)
    #   array([0, 1])   # node 0 is a source, node 1 has one incoming edge
    # np.sum returns a fresh array, so the decrements below never modify
    # the caller's matrix.
    arr_node_in_degree_volatile = np.sum(binary_adj_mat, axis=1)
    # The queue initially contains exactly the source nodes (in-degree 0);
    # a deque gives O(1) pops from the left, unlike list.pop(0).
    queue_src_node_inds = deque(
        i for i in range(num_nodes) if arr_node_in_degree_volatile[i] == 0
    )
    if not queue_src_node_inds:
        raise RuntimeError("no source nodes!")
    list_sorted_node_inds = []

    while queue_src_node_inds:
        node_src = queue_src_node_inds.popleft()
        list_sorted_node_inds.append(node_src)
        for neighbor in range(num_nodes):
            if binary_adj_mat[neighbor][node_src] != 0:
                # arrow: node_src -> neighbor; emitting node_src removes
                # one pending incoming edge from neighbor
                arr_node_in_degree_volatile[neighbor] -= 1
                if arr_node_in_degree_volatile[neighbor] == 0:
                    # all predecessors of neighbor are already emitted
                    queue_src_node_inds.append(neighbor)

    # Sanity check: any node left out still has unresolved incoming edges.
    if len(list_sorted_node_inds) != num_nodes:
        raise ValueError(
            "sorted nodes not completed!: "
            + str(list_sorted_node_inds)
            + str(arr_node_in_degree_volatile)
        )

    return list_sorted_node_inds