Coverage for src/causalspyne/draw_dags.py: 85%
26 statements
coverage.py v7.11.0, created at 2026-05-15 16:30 +0000
1"""
2draw DAG using networkx
3"""
5import matplotlib.pyplot as plt
6import networkx as nx
# pygraphviz is an optional dependency. If it (or networkx's nx_agraph bridge)
# cannot be imported, record that via has_graphviz = False and leave
# graphviz_layout as None so callers fall back to a pure-networkx layout.
try:
    import pygraphviz  # noqa: F401  (only probes availability; not used below)
    has_graphviz = True
    from networkx.drawing.nx_agraph import graphviz_layout
except ImportError:
    graphviz_layout = None
    has_graphviz = False
def draw_dags_nx(
    adj_matrix, dict_ind2name=None, title="dag", ax=None, show=False,
    graphviz=False, seed=None
):
    """
    Draw a DAG given as an adjacency matrix in the CausalSpyne convention.

    networkx adjacency matrix (i,j) entry refers to edge from i pointing to j,
    which is opposite to the CausalSpyne convention, so the matrix is
    transposed before the networkx graph is built.

    Args:
        adj_matrix: square adjacency matrix (numpy array or nested sequence).
            A non-zero entry (i, j) is drawn as an edge j -> i.
        dict_ind2name: optional mapping from node index to display label.
            Falsy values (None or an empty dict) keep the integer labels.
        title: plot title. It is also the base file name for the ".pdf" and
            ".svg" files written when ``ax`` is None.
        ax: optional matplotlib Axes to draw into. When given, existing
            figures are left open, and no title is set and no file is
            written; the caller owns the figure.
        show: if True, call ``plt.show()`` after drawing.
        graphviz: request the graphviz "dot" layout. If pygraphviz is not
            available, a warning is emitted and a spring layout is used.
        seed: optional random seed for the spring layout, giving
            reproducible node positions. Ignored for the graphviz layout.
    """
    import warnings

    import numpy as np

    use_graphviz = graphviz and has_graphviz
    if graphviz and not has_graphviz:
        # Previously a silent fallback; make it visible to the caller.
        warnings.warn(
            "graphviz layout requested but pygraphviz is not available; "
            "falling back to spring layout",
            stacklevel=2,
        )

    if ax is None:
        # Only reset pyplot state when we own the figure. Closing all
        # figures while the caller supplied ``ax`` would discard the
        # caller's figure.
        plt.close("all")

    # Transpose: CausalSpyne (i, j) means j -> i, networkx means i -> j.
    nx_graph = nx.from_numpy_array(np.asarray(adj_matrix).T,
                                   create_using=nx.DiGraph)
    if dict_ind2name:
        nx_graph = nx.relabel_nodes(nx_graph, dict_ind2name)

    if use_graphviz:
        pos = graphviz_layout(nx_graph, prog="dot")
    else:
        pos = nx.spring_layout(nx_graph, k=0.5, scale=2, seed=seed)

    nx.draw(nx_graph, pos=pos, ax=ax, arrows=True, with_labels=True,
            node_color="white")

    if ax is None:
        # Standalone figure: decorate it and persist it in both formats.
        plt.title(title)
        plt.axis("off")
        plt.savefig(title + ".pdf", format="pdf")
        plt.savefig(title + ".svg", format="svg")
    if show:
        plt.show()