Coverage for src/causalspyne/draw_dags.py: 85%

26 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-05-15 16:30 +0000

1""" 

2draw DAG using networkx 

3""" 

4 

5import matplotlib.pyplot as plt 

6import networkx as nx 

try:
    # Probe for the optional pygraphviz dependency; the module name itself
    # is not referenced afterwards.
    import pygraphviz
    from networkx.drawing.nx_agraph import graphviz_layout
except ImportError:
    # Optional dependency missing: expose a sentinel layout function and a
    # flag that callers check before requesting the graphviz layout.
    has_graphviz = False
    graphviz_layout = None
else:
    has_graphviz = True

14 

15 

16 

def draw_dags_nx(
    adj_matrix, dict_ind2name=None, title="dag", ax=None, show=False,
    graphviz=False, seed=None
):
    """
    Draw a DAG given as an adjacency matrix in the CausalSpyne convention.

    networkx adjacency matrix (i,j) entry refers to edge from i pointing to j,
    which is opposite to the CausalSpyne convention; the matrix is therefore
    transposed before it is handed to networkx, so a non-zero entry (i, j)
    of ``adj_matrix`` is drawn as an edge j -> i.

    Parameters
    ----------
    adj_matrix : array-like
        Square adjacency matrix exposing ``.transpose()`` (e.g. a numpy
        array).
    dict_ind2name : dict, optional
        Mapping from integer node index to display label. When None or
        empty, the integer indices are used as labels.
    title : str
        Plot title and base filename of the saved ``.pdf`` / ``.svg`` files.
        Only used when ``ax`` is None.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, all open pyplot figures are closed, the
        graph is drawn on a fresh figure, titled, and saved to
        ``<title>.pdf`` and ``<title>.svg``. When given, the caller's figure
        is left untouched apart from the drawing itself and nothing is saved.
    show : bool
        Call ``plt.show()`` after drawing.
    graphviz : bool
        Use the graphviz ``dot`` layout. If pygraphviz is not installed, a
        warning is emitted and the spring layout is used instead.
    seed : int, optional
        Random seed forwarded to ``nx.spring_layout`` for reproducible node
        positions. Ignored when the graphviz layout is used.
    """
    import warnings

    if graphviz and not has_graphviz:
        # Previously this fell back silently; make the degradation visible.
        warnings.warn(
            "graphviz layout requested but pygraphviz is not available; "
            "falling back to spring layout",
            stacklevel=2,
        )
    use_graphviz = graphviz and has_graphviz

    if ax is None:
        # We own the figure: start from a clean pyplot state so the drawing
        # (and the saved files) contain only this graph. This must NOT run
        # when the caller supplied `ax`, since closing all figures would
        # unregister the caller's figure and a later plt.show() would not
        # display it.
        plt.close("all")

    nx_graph = nx.from_numpy_array(adj_matrix.transpose(),
                                   create_using=nx.DiGraph)
    if dict_ind2name:
        nx_graph = nx.relabel_nodes(nx_graph, dict_ind2name)

    if use_graphviz:
        pos = graphviz_layout(nx_graph, prog="dot")
    else:
        pos = nx.spring_layout(nx_graph, k=0.5, scale=2, seed=seed)

    nx.draw(nx_graph, pos=pos, ax=ax, arrows=True, with_labels=True,
            node_color="white")

    if ax is None:
        plt.title(title)
        plt.axis("off")
        plt.savefig(title + ".pdf", format="pdf")
        plt.savefig(title + ".svg", format="svg")
    if show:
        plt.show()