Coverage for causalspyne/draw_dags.py: 89%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-19 14:58 +0000

1""" 

2draw DAG using networkx 

3""" 

4 

5import matplotlib.pyplot as plt 

6import networkx as nx 

7from networkx.drawing.nx_agraph import graphviz_layout 

8 

9 

def draw_dags_nx(
    adj_matrix, dict_ind2name=None, title="dag", ax=None, show=False,
    graphviz=False
):
    """
    Draw a DAG given as an adjacency matrix using networkx.

    networkx reads adjacency-matrix entry (i, j) as an edge from node i
    pointing to node j, which is the opposite of the CausalSpyne
    convention; the matrix is therefore transposed before the graph is
    built.

    Args:
        adj_matrix: square adjacency matrix in CausalSpyne convention; must
            support ``.transpose()`` (presumably a numpy array -- confirm
            against callers).
        dict_ind2name: optional mapping from node index to display name.
            Nodes are relabelled with it only when it is non-empty.
        title: plot title and base filename. When ``ax`` is None the figure
            is saved to ``<title>.pdf`` and ``<title>.svg`` in the current
            working directory.
        ax: optional matplotlib Axes to draw into. When given, the caller
            owns the figure: no figures are closed, no title is set, the
            axis frame is left alone and nothing is saved.
        show: if True, call ``plt.show()`` at the end.
        graphviz: if True, use the graphviz ``dot`` layout (via
            ``networkx.drawing.nx_agraph``); otherwise use a spring layout.
    """
    if ax is None:
        # We own the figure: start from a clean slate so figures do not
        # accumulate across repeated calls. This must not run when the
        # caller passed ``ax``, since it would also close the caller's
        # figure (which then no longer shows up in ``plt.show()``).
        plt.close("all")
    # Transpose to convert from the CausalSpyne convention to networkx's
    # "row i -> column j" convention.
    nx_graph = nx.from_numpy_array(adj_matrix.transpose(),
                                   create_using=nx.DiGraph)
    if dict_ind2name:
        nx_graph = nx.relabel_nodes(nx_graph, dict_ind2name)
    if graphviz:
        pos = graphviz_layout(nx_graph, prog="dot")
    else:
        pos = nx.spring_layout(nx_graph, k=0.5, scale=2)
    nx.draw(nx_graph, pos=pos, ax=ax, arrows=True, with_labels=True,
            node_color="white")
    if ax is None:
        plt.title(title)
        plt.axis("off")
        plt.savefig(title + ".pdf", format="pdf")
        plt.savefig(title + ".svg", format="svg")
    # NOTE(review): indentation in the source this was recovered from was
    # lost; ``show`` is assumed to apply regardless of ``ax`` -- confirm it
    # was not nested under ``if ax is None``.
    if show:
        plt.show()