Coverage for src/causalspyne/dag_gen.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2concrete class to generate simple DAGs 

3 

4""" 

5 

6import warnings 

7 

8from causalspyne.erdo_renyi_plp import Erdos_Renyi_PLP 

9from causalspyne.dag_interface import MatDAG 

10from causalspyne.weight import WeightGenWishart 

11from causalspyne.dag_manipulator import DAGManipulator 

12from causalspyne.utils_random import coerce_rng 

13 

14 

class GenDAG:
    """
    Generate a random weighted DAG wrapped in MatDAG and try to turn
    some of its nodes into confounders via DAGManipulator.
    """

    def __init__(self, num_nodes, degree, obj_gen_weight=None,
                 rng=None):
        """
        num_nodes: default number of nodes used by gen_dag
        degree: expected degree for each node
        obj_gen_weight: object whose gen(num_nodes) returns the weight
            matrix; WeightGenWishart(rng=rng) is used when None
        rng: passed through coerce_rng; the coerced generator is shared
            by the skeleton strategy, the default weight generator and
            every DAG built by gen_dag
        """
        rng = coerce_rng(rng)
        self.num_nodes = num_nodes
        self.degree = degree
        self.strategy_gen_dag = Erdos_Renyi_PLP(rng)
        if obj_gen_weight is None:
            obj_gen_weight = WeightGenWishart(rng=rng)
        self.obj_gen_weight = obj_gen_weight
        # populated by gen_dag; stays None until a DAG has been generated
        self.dag_manipulator = None
        self.rng = rng

    def gen_dag(self, num_nodes=None, prefix="", *, target_num_confounder):
        """
        generate DAG and wrap it around with interface

        num_nodes: number of nodes; falls back to self.num_nodes if None
        prefix: forwarded to MatDAG as name_prefix
        target_num_confounder: keyword-only; the confounder search stops
            as soon as dag.num_confounder reaches this value

        Returns the MatDAG instance. As a side effect, self.dag_manipulator
        is replaced by a DAGManipulator bound to the returned DAG.
        A warning is issued when fewer confounders than requested were
        obtained and dag.num_nodes exceeds the target by more than one.
        """
        if num_nodes is None:
            num_nodes = self.num_nodes
        mat_skeleton = self.strategy_gen_dag(num_nodes, self.degree)

        # 1.0 where the skeleton has an edge, 0.0 elsewhere
        mat_mask = (mat_skeleton != 0).astype(float)
        mat_weight = self.obj_gen_weight.gen(num_nodes)
        # Hadamard product: keep weights only where the skeleton has edges
        mat_weighted_adjacency = mat_mask * mat_weight

        dag = MatDAG(mat_weighted_adjacency, name_prefix=prefix, rng=self.rng)
        self.dag_manipulator = DAGManipulator(dag,
                                              self.obj_gen_weight, self.rng)
        ind_arbitrary = dag.get_top_last()
        # number of mk_confound calls that reported failure
        counter = 0
        for _ in range(dag.num_nodes):
            flag_success = self.dag_manipulator.mk_confound(
                ind_arbitrary_confound_input=ind_arbitrary)
            if not flag_success:
                counter += 1
            if dag.num_confounder >= target_num_confounder:
                break
            # FIXME: it can be the new ind_arbitrary has been tried out already
            ind_arbitrary = dag.climb(ind_arbitrary)
            if ind_arbitrary is None:
                # climb gave no further candidate, nothing left to try
                break
        num_confounder = len(dag.list_confounder)
        if num_confounder < target_num_confounder and \
                dag.num_nodes - target_num_confounder > 1:
            # Adjacent f-string literals instead of backslash continuation
            # inside one literal, so source indentation does not leak into
            # the message; stacklevel=2 attributes the warning to the caller.
            warnings.warn(
                f"\n failed to ensure {target_num_confounder} confounders "
                f"for adjacency matrix \n{dag.mat_adjacency}, "
                f"\n after {counter} failed trials, "
                f"\n{num_confounder} confounders only",
                stacklevel=2)
        return dag