Coverage for causalspyne/dag_stack_indexer.py: 100%
21 statements
coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
Index nodes globally when several DAGs are stacked together.
3"""
5from scipy.linalg import block_diag
6from causalspyne.dag_interface import MatDAG
class DAGStackIndexer:
    """
    Assign global node indices to the nodes of several DAGs stacked together.

    On construction, every DAG in ``host.dict_macro_node2dag`` is placed on
    the diagonal of one block-diagonal adjacency matrix, which is wrapped in
    a ``MatDAG`` and stored on ``host.dag_refined``. The per-DAG node offsets
    computed along the way let ``get_global_ind`` translate a
    (macro node, local node) pair into an index of the stacked matrix.
    """

    def __init__(self, host):
        """
        Build the stacked DAG and attach it to ``host``.

        Args:
            host: object exposing ``dict_macro_node2dag`` (whose values have
                ``mat_adjacency`` and ``num_nodes``) and ``rng``; its
                ``dag_refined`` attribute is assigned here.
        """
        self.host = host
        self.host.dag_refined = self.stack_dags()

    def get_global_ind(self, ind_macro_node, ind_local_node):
        """
        Translate a node index local to one DAG into the stacked index space.

        Args:
            ind_macro_node: position used to index ``self.list_accum_count``;
                presumably the macro node's position in
                ``host.dict_macro_node2dag`` iteration order -- TODO confirm.
            ind_local_node: index of the node inside that macro node's DAG.

        Returns:
            The starting offset of that DAG plus ``ind_local_node``.
        """
        offset = self.list_accum_count[ind_macro_node]
        return offset + ind_local_node

    def stack_dags(self):
        """
        Stack the host's DAGs into a single block-diagonal ``MatDAG``.

        Side effects: sets ``self.dict_num`` (macro node -> node count) and
        ``self.list_accum_count`` (starting offset of each DAG, in dict
        iteration order).

        Returns:
            The stacked ``MatDAG``, after ``gen_node_names_stacked`` has been
            called on it with the host's DAG dictionary.
        """
        dict_macro2dag = self.host.dict_macro_node2dag
        # Only diagonal blocks are filled: nodes of different DAGs share no edges.
        list_mat_adj = [dag.mat_adjacency for dag in dict_macro2dag.values()]
        dag_stacked = MatDAG(block_diag(*list_mat_adj), rng=self.host.rng)
        self.dict_num = {
            macro: dag.num_nodes for macro, dag in dict_macro2dag.items()
        }
        # Running totals shifted right by one give each DAG's first global index.
        list_running_total = list(fun_accum_sum(self.dict_num))
        self.list_accum_count = [0, *list_running_total[:-1]]
        dag_stacked.gen_node_names_stacked(dict_macro2dag)
        return dag_stacked
def fun_accum_sum(dict_num):
    """
    Lazily yield the running (cumulative) sum of a dictionary's values.

    Values are summed in the dictionary's iteration order; an empty
    dictionary yields nothing.

    Args:
        dict_num: mapping whose values are numbers (e.g. node counts).

    Yields:
        The cumulative total after each value.

    >>> list(fun_accum_sum({'0': 3, '1': 5, '2': 4}))
    [3, 8, 12]
    """
    # Local import keeps this block self-contained without touching the
    # module's import header.
    from itertools import accumulate

    # ``yield from`` keeps this a generator function, so nothing is read
    # from ``dict_num`` until the caller starts iterating (as before).
    yield from accumulate(dict_num.values())