Coverage for causalspyne/draw_dags.py: 89%
19 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-19 14:58 +0000
1"""
2draw DAG using networkx
3"""
5import matplotlib.pyplot as plt
6import networkx as nx
7from networkx.drawing.nx_agraph import graphviz_layout
def draw_dags_nx(
    adj_matrix, dict_ind2name=None, title="dag", ax=None, show=False,
    graphviz=False
):
    """
    Draw a DAG given as an adjacency matrix in the CausalSpyne convention.

    networkx adjacency matrix (i,j) entry refers to edge from i pointing to j,
    which is opposite to the CausalSpyne convention, so the matrix is
    transposed before the networkx graph is built.

    Args:
        adj_matrix: 2-D adjacency matrix in CausalSpyne convention. Anything
            ``numpy.asarray`` accepts (ndarray, nested lists, ...) is allowed.
        dict_ind2name: optional mapping from node index to display label;
            ignored when ``None`` or empty.
        title: plot title and base file name for the saved figures; only
            used when ``ax`` is None.
        ax: optional matplotlib Axes to draw into. When given, the caller
            owns the figure: nothing is closed, titled, saved or shown here.
        show: call ``plt.show()`` after saving (only when ``ax`` is None).
        graphviz: use the graphviz "dot" layout instead of a spring layout.

    Side effects when ``ax`` is None:
        closes all open pyplot figures before drawing, sets the title,
        hides the axis, and writes ``title + ".pdf"`` and ``title + ".svg"``.
    """
    # networkx.from_numpy_array needs an ndarray; numpy is already a
    # networkx dependency, so this import adds no new requirement.
    import numpy as np

    if ax is None:
        # Start from a clean pyplot state so the implicit figure only holds
        # this DAG. Skipped when the caller supplies ``ax``: closing "all"
        # would also drop the caller's figure from pyplot's registry.
        plt.close("all")

    # Transpose: CausalSpyne (i, j) is the reverse of networkx's i -> j.
    nx_graph = nx.from_numpy_array(np.asarray(adj_matrix).transpose(),
                                   create_using=nx.DiGraph)
    if dict_ind2name:
        nx_graph = nx.relabel_nodes(nx_graph, dict_ind2name)

    if graphviz:
        pos = graphviz_layout(nx_graph, prog="dot")
    else:
        pos = nx.spring_layout(nx_graph, k=0.5, scale=2)

    nx.draw(nx_graph, pos=pos, ax=ax, arrows=True, with_labels=True,
            node_color="white")

    if ax is None:
        plt.title(title)
        plt.axis("off")
        plt.savefig(title + ".pdf", format="pdf")
        plt.savefig(title + ".svg", format="svg")
        if show:
            plt.show()