Coverage for causalspyne/dag_stack_indexer.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

Index nodes globally when stacking several DAGs together.

3""" 

4 

5from scipy.linalg import block_diag 

6from causalspyne.dag_interface import MatDAG 

7 

8 

class DAGStackIndexer:
    """
    Indexer for several DAGs stacked into one.

    On construction, the DAGs found in ``host.dict_macro_node2dag`` are
    stacked and the resulting DAG is stored on ``host.dag_refined``.
    """

    def __init__(self, host):
        """
        Remember ``host`` and attach the stacked DAG to it as ``dag_refined``.
        """
        self.host = host
        self.host.dag_refined = self.stack_dags()

    def get_global_ind(self, ind_macro_node, ind_local_node):
        """
        Return the global index of a local node.

        The result is the offset recorded for ``ind_macro_node`` in
        ``self.list_accum_count`` plus ``ind_local_node``.
        """
        offset = self.list_accum_count[ind_macro_node]
        return offset + ind_local_node

    def stack_dags(self):
        """
        Stack the dictionary of DAGs into one block-diagonal matrix DAG.

        Side effects: sets ``self.dict_num`` (number of nodes per key) and
        ``self.list_accum_count`` (starting offset of each DAG in the
        stacked matrix, in dictionary order).

        Returns the stacked ``MatDAG``.
        """
        host = self.host

        # One adjacency block per DAG, laid out along the diagonal.
        blocks = [dag.mat_adjacency for dag in host.dict_macro_node2dag.values()]
        stacked = MatDAG(block_diag(*blocks), rng=host.rng)

        self.dict_num = {
            name: dag.num_nodes
            for name, dag in host.dict_macro_node2dag.items()
        }

        # Running totals give the END of each block; shifting them right by
        # one (first block starts at 0) gives the START offset of each block.
        running_totals = list(fun_accum_sum(self.dict_num))
        self.list_accum_count = [0, *running_totals[:-1]]

        stacked.gen_node_names_stacked(host.dict_macro_node2dag)
        return stacked

40 

41 

def fun_accum_sum(dict_num):
    """
    Yield the cumulative sum of the values of a dictionary.

    Values are visited in the dictionary's iteration order and one running
    total is yielded per entry; an empty dictionary yields nothing.

    e.g. dict_num = {'0': 3, '1': 5, '2': 4} yields 3, 8, 12
    """
    # Function-local import so the block does not depend on the file's
    # top-level import section.
    from itertools import accumulate

    # `yield from` keeps this a lazy generator, as callers may rely on.
    yield from accumulate(dict_num.values())