Coverage for src/causalspyne/utils_causallearn_g2ancestral.py: 0%
23 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
A graph is simpler than a simplicial complex in that it only characterizes
pairwise relationships, which enables us to project a complicated graph from
causal-learn onto an ancestral ADMG (acyclic directed mixed graph: directed
edges, plus bidirected edges for latent common causes).
Information is necessarily lost in this projection.

This module also offers a function to extract the causal order from a
causal-learn graph.
10"""
11from causallearn.graph.NodeType import NodeType
def project_causallearn_g2ancestral_admg(mat_graph_causallearn):
    """
    Project a causal-learn endpoint matrix onto a 0/1 adjacency matrix, in place.

    A general graph in causal-learn is not necessarily a PAG (the graph
    object has an attribute denoting whether it is one). When it is a PAG,
    e.g. ``g.graph`` returned by FCI
    (https://causal-learn.readthedocs.io/en/latest/search_methods_index/Constraint-based%20causal%20discovery%20methods/FCI.html#usage),
    the endpoint codes translate to our format (where ``G[B, A] = 1``
    encodes ``A -> B``) as follows:

    * ``A --> B``: G[A, B] = -1, G[B, A] = 1  ->  G[A, B] = 0, G[B, A] = 1
    * ``A o-> B``: G[A, B] = 2,  G[B, A] = 1  ->  G[A, B] = 0, G[B, A] = 1
      (B is not an ancestor of A)
    * ``A o-o B``: G[A, B] = 2,  G[B, A] = 2  ->  both entries become 0,
      i.e. the edge is dropped (no set d-separates A and B)
    * ``A <-> B``: G[A, B] = 1,  G[B, A] = 1  ->  unchanged
      (common latent cause)

    In general, every entry that is not exactly 1 is overwritten with 0.

    :param mat_graph_causallearn: matrix indexable as ``m[i][j]`` whose rows
        can be mutated in place (e.g. a nested list or a 2-D numpy array).
    :return: the very same object that was passed in, after modification.
    """
    for row in mat_graph_causallearn:
        # Rows are mutated directly: a nested list's inner list, or a view
        # into a 2-D numpy array, so the caller's matrix is updated.
        for col, endpoint in enumerate(row):
            if endpoint != 1:
                row[col] = 0
    return mat_graph_causallearn
def get_causalearn_order(g_causal_learn, node_names=None):
    """
    Get the causal order of a causal-learn graph, mapped to user node names.

    Observed nodes are assumed to be named ``X1, X2, ...`` by causal-learn,
    so node ``Xk`` is mapped to ``node_names[k - 1]``. Nodes whose type is
    ``NodeType.LATENT`` (e.g. introduced by GIN) have no entry in
    ``node_names``.

    :param g_causal_learn: causal-learn graph exposing
        ``get_causal_ordering()``.
    :param node_names: user-facing names of the observed variables, indexed
        by (causal-learn number - 1). When ``None``, the causal-learn node
        names (``"X1"``, ...) are returned unchanged.
    :return: tuple ``(real_name_order, real_na_order_latent)``:
        ``real_name_order`` lists the observed node names in causal order;
        ``real_na_order_latent`` is the same order with latent nodes kept at
        their positions, represented by their causal-learn node name.
    """
    import logging

    logger = logging.getLogger(__name__)
    real_name_order = []
    real_na_order_latent = []
    for node in g_causal_learn.get_causal_ordering():
        if node.get_node_type() == NodeType.LATENT:
            # GIN: latent nodes have no user-facing name; keep the
            # causal-learn name as a positional placeholder. (Previously the
            # bound method `node.get_node_type` itself was appended.)
            real_na_order_latent.append(node.get_name())
            continue
        fake_name = node.get_name()
        if node_names is None:
            # Without a name mapping, fall back to causal-learn's own names
            # instead of failing on `None[index]`.
            real_name = fake_name
        else:
            # "Xk" -> zero-based index k - 1 into node_names.
            index = int(fake_name.removeprefix("X")) - 1
            real_name = node_names[index]
        logger.debug("%s", real_name)
        real_name_order.append(real_name)
        real_na_order_latent.append(real_name)
    logger.debug("%s", real_name_order)
    return real_name_order, real_na_order_latent