Coverage for causalspyne/dag_gen.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2concrete class to generate simple DAGs 

3 

4""" 

5 

6import warnings 

7from numpy.random import default_rng 

8 

9from causalspyne.erdo_renyi_plp import Erdos_Renyi_PLP 

10from causalspyne.dag_interface import MatDAG 

11from causalspyne.weight import WeightGenWishart 

12from causalspyne.dag_manipulator import DAGManipulator 

13 

14 

class GenDAG:
    """
    Generate random weighted DAGs and try to inject confounders into them.

    A binary skeleton is sampled by ``Erdos_Renyi_PLP``, weighted
    element-wise with a matrix from ``obj_gen_weight`` and wrapped in a
    ``MatDAG``; ``gen_dag`` then uses a ``DAGManipulator`` to add
    confounders until a target count is reached or candidates run out.
    """

    def __init__(self, num_nodes, degree, obj_gen_weight=None, rng=None):
        """
        num_nodes: default number of nodes used by ``gen_dag``
        degree: expected degree for each node
        obj_gen_weight: weight generator exposing ``gen(num_nodes)``;
            defaults to ``WeightGenWishart(rng=rng)``
        rng: numpy random Generator shared by every random component of
            this instance; when None, a fresh ``default_rng()`` is created
            for this instance (instead of one generator shared by every
            instance, which a ``default_rng()`` default value would cause)
        """
        rng = self._resolve_rng(rng)
        self.num_nodes = num_nodes
        self.degree = degree
        self.strategy_gen_dag = Erdos_Renyi_PLP(rng)
        self.obj_gen_weight = obj_gen_weight
        if obj_gen_weight is None:
            self.obj_gen_weight = WeightGenWishart(rng=rng)
        # set by gen_dag to a manipulator bound to the most recent DAG
        self.dag_manipulator = None
        self.rng = rng

    @staticmethod
    def _resolve_rng(rng):
        """Return ``rng``, or a new numpy ``Generator`` when it is None."""
        return default_rng() if rng is None else rng

    @staticmethod
    def _msg_insufficient_confounders(target_num_confounder, mat_adjacency,
                                      num_failed, num_confounder):
        """Build the warning text used when too few confounders were made."""
        # Implicit concatenation: backslash-continued f-string lines would
        # embed the source indentation of each continuation into the text.
        return (f"\n failed to ensure {target_num_confounder} confounders for "
                f"adjacency matrix \n{mat_adjacency}, "
                f"\n after {num_failed} failed trials, "
                f"\n{num_confounder} confounders only")

    def gen_dag(self, num_nodes=None, prefix="", target_num_confounder=2):
        """
        Generate a weighted DAG, wrap it in ``MatDAG`` and add confounders.

        num_nodes: number of nodes; defaults to ``self.num_nodes``
        prefix: forwarded to ``MatDAG`` as ``name_prefix``
        target_num_confounder: stop trying once ``dag.num_confounder``
            reaches this value

        Side effect: ``self.dag_manipulator`` is replaced by a new
        ``DAGManipulator`` bound to the returned DAG.

        Emits a ``UserWarning`` when fewer than ``target_num_confounder``
        confounders were created and the DAG has more than
        ``target_num_confounder + 1`` nodes.

        Returns the ``MatDAG``.
        """
        if num_nodes is None:
            num_nodes = self.num_nodes
        mat_skeleton = self.strategy_gen_dag(num_nodes, self.degree)

        # 1.0 where the skeleton has an edge, 0.0 elsewhere
        mat_mask = (mat_skeleton != 0).astype(float)
        mat_weight = self.obj_gen_weight.gen(num_nodes)
        # Hadamard (element-wise) product: keep weights only on skeleton edges
        mat_weighted_adjacency = mat_mask * mat_weight

        dag = MatDAG(mat_weighted_adjacency, name_prefix=prefix, rng=self.rng)
        self.dag_manipulator = DAGManipulator(
            dag, self.obj_gen_weight, self.rng)
        # starting node for the confounder search; presumably the last node
        # in topological order (see MatDAG.get_top_last) -- TODO confirm
        ind_arbitrary = dag.get_top_last()
        counter = 0  # number of failed mk_confound attempts
        # at most dag.num_nodes attempts, so the search always terminates
        for _ in range(dag.num_nodes):
            flag_success = self.dag_manipulator.mk_confound(
                ind_arbitrary_confound_input=ind_arbitrary)
            if not flag_success:
                counter += 1
            if dag.num_confounder >= target_num_confounder:
                break
            # FIXME: it can be the new ind_arbitrary has been tried out already
            # climb returns None when there is no further node to try
            ind_arbitrary = dag.climb(ind_arbitrary)
            if ind_arbitrary is None:
                break
        num_confounder = len(dag.list_confounder)
        # very small graphs cannot host the requested confounders; do not
        # warn in that case
        if (num_confounder < target_num_confounder
                and dag.num_nodes - target_num_confounder > 1):
            warnings.warn(
                self._msg_insufficient_confounders(
                    target_num_confounder, dag.mat_adjacency,
                    counter, num_confounder),
                stacklevel=2)
        return dag