Coverage for causalspyne/dag_gen.py: 100%
40 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
2concrete class to generate simple DAGs
4"""
6import warnings
7from numpy.random import default_rng
9from causalspyne.erdo_renyi_plp import Erdos_Renyi_PLP
10from causalspyne.dag_interface import MatDAG
11from causalspyne.weight import WeightGenWishart
12from causalspyne.dag_manipulator import DAGManipulator
class GenDAG:
    """
    Generate random weighted DAGs and try to inject confounders into them.

    The skeleton comes from ``Erdos_Renyi_PLP``. Edge weights come from
    ``obj_gen_weight``. Confounders are added afterwards through a
    ``DAGManipulator`` bound to each generated DAG.
    """

    def __init__(self, num_nodes, degree, obj_gen_weight=None, rng=None):
        """
        num_nodes: default number of nodes used by ``gen_dag``
        degree: expected degree for each node
        obj_gen_weight: object whose ``gen(num_nodes)`` returns a weight
            matrix; defaults to ``WeightGenWishart(rng=rng)``
        rng: numpy random Generator passed to every random component of
            this instance; a fresh ``default_rng()`` is created when None
        """
        # Create the generator per instance. A ``default_rng()`` default
        # argument would be evaluated once at import time and silently
        # shared (and advanced) by every GenDAG built without an rng.
        if rng is None:
            rng = default_rng()
        self.num_nodes = num_nodes
        self.degree = degree
        self.rng = rng
        self.strategy_gen_dag = Erdos_Renyi_PLP(rng)
        if obj_gen_weight is None:
            obj_gen_weight = WeightGenWishart(rng=rng)
        self.obj_gen_weight = obj_gen_weight
        # replaced by gen_dag with a manipulator bound to the latest DAG
        self.dag_manipulator = None

    def gen_dag(self, num_nodes=None, prefix="", target_num_confounder=2):
        """
        Generate a weighted DAG, wrap it in ``MatDAG`` and add confounders.

        num_nodes: number of nodes; defaults to ``self.num_nodes``
        prefix: passed to ``MatDAG`` as ``name_prefix``
        target_num_confounder: number of confounders to aim for; a warning
            is issued if it is not reached (see condition below)

        Returns the ``MatDAG``. As a side effect, ``self.dag_manipulator``
        is replaced by a ``DAGManipulator`` bound to the new DAG.
        """
        if num_nodes is None:
            num_nodes = self.num_nodes
        mat_skeleton = self.strategy_gen_dag(num_nodes, self.degree)
        # nonzero skeleton entries mark edges; keep only those weights
        mat_mask = (mat_skeleton != 0).astype(float)
        mat_weight = self.obj_gen_weight.gen(num_nodes)
        # Hadamard (element-wise) product
        mat_weighted_adjacency = mat_mask * mat_weight

        dag = MatDAG(mat_weighted_adjacency, name_prefix=prefix, rng=self.rng)
        self.dag_manipulator = DAGManipulator(dag, self.obj_gen_weight,
                                              self.rng)
        # Start from dag.get_top_last() and move with dag.climb(), asking
        # the manipulator to create a confounder at each visited node.
        ind_arbitrary = dag.get_top_last()
        num_failed_trials = 0
        for _ in range(dag.num_nodes):
            flag_success = self.dag_manipulator.mk_confound(
                ind_arbitrary_confound_input=ind_arbitrary)
            if not flag_success:
                num_failed_trials += 1
            if dag.num_confounder >= target_num_confounder:
                break
            # FIXME: climb() may return a node that was already tried; the
            # loop still terminates because range() fixed its length above.
            ind_arbitrary = dag.climb(ind_arbitrary)
            if ind_arbitrary is None:
                break
        num_confounder = len(dag.list_confounder)
        # only warn when the DAG has more than target_num_confounder + 1
        # nodes
        if num_confounder < target_num_confounder and \
                dag.num_nodes - target_num_confounder > 1:
            # Implicit concatenation instead of backslash continuations
            # inside the f-string, which leaked source indentation into
            # the message text.
            warnings.warn(
                f"\n failed to ensure {target_num_confounder} confounders for "
                f"adjacency matrix \n{dag.mat_adjacency}, "
                f"\n after {num_failed_trials} failed trials, "
                f"\n{num_confounder} confounders only",
                stacklevel=2)
        return dag