Statistics
| Branch: | Revision:

iof-tools / networkxMiCe / networkx-master / networkx / drawing / nx_pylab.py @ 5cef0f13

History | View | Annotate | Download (36.8 KB)

1
#    Copyright (C) 2004-2019 by
2
#    Aric Hagberg <hagberg@lanl.gov>
3
#    Dan Schult <dschult@colgate.edu>
4
#    Pieter Swart <swart@lanl.gov>
5
#    All rights reserved.
6
#    BSD license.
7
#
8
# Author: Aric Hagberg (hagberg@lanl.gov)
9
"""
10
**********
11
Matplotlib
12
**********
13

14
Draw networks with matplotlib.
15

16
See Also
17
--------
18

19
matplotlib:     http://matplotlib.org/
20

21
pygraphviz:     http://pygraphviz.github.io/
22

23
"""
24
from numbers import Number
25
import networkx as nx
26
from networkx.utils import is_string_like
27
from networkx.drawing.layout import shell_layout, \
28
    circular_layout, kamada_kawai_layout, spectral_layout, \
29
    spring_layout, random_layout, planar_layout
30

    
31
__all__ = ['draw',
32
           'draw_networkx',
33
           'draw_networkx_nodes',
34
           'draw_networkx_edges',
35
           'draw_networkx_labels',
36
           'draw_networkx_edge_labels',
37
           'draw_circular',
38
           'draw_kamada_kawai',
39
           'draw_random',
40
           'draw_spectral',
41
           'draw_spring',
42
           'draw_planar',
43
           'draw_shell']
44

    
45

    
46
def draw(G, pos=None, ax=None, **kwds):
47
    """Draw the graph G with Matplotlib.
48

49
    Draw the graph as a simple representation with no node
50
    labels or edge labels and using the full Matplotlib figure area
51
    and no axis labels by default.  See draw_networkx() for more
52
    full-featured drawing that allows title, axis labels etc.
53

54
    Parameters
55
    ----------
56
    G : graph
57
       A networkx graph
58

59
    pos : dictionary, optional
60
       A dictionary with nodes as keys and positions as values.
61
       If not specified a spring layout positioning will be computed.
62
       See :py:mod:`networkx.drawing.layout` for functions that
63
       compute node positions.
64

65
    ax : Matplotlib Axes object, optional
66
       Draw the graph in specified Matplotlib axes.
67

68
    kwds : optional keywords
69
       See networkx.draw_networkx() for a description of optional keywords.
70

71
    Examples
72
    --------
73
    >>> G = nx.dodecahedral_graph()
74
    >>> nx.draw(G)
75
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout
76

77
    See Also
78
    --------
79
    draw_networkx()
80
    draw_networkx_nodes()
81
    draw_networkx_edges()
82
    draw_networkx_labels()
83
    draw_networkx_edge_labels()
84

85
    Notes
86
    -----
87
    This function has the same name as pylab.draw and pyplot.draw
88
    so beware when using
89

90
    >>> from networkx import *
91

92
    since you might overwrite the pylab.draw function.
93

94
    With pyplot use
95

96
    >>> import matplotlib.pyplot as plt
97
    >>> import networkx as nx
98
    >>> G = nx.dodecahedral_graph()
99
    >>> nx.draw(G)  # networkx draw()
100
    >>> plt.draw()  # pyplot draw()
101

102
    Also see the NetworkX drawing examples at
103
    https://networkx.github.io/documentation/latest/auto_examples/index.html
104
    """
105
    try:
106
        import matplotlib.pyplot as plt
107
    except ImportError:
108
        raise ImportError("Matplotlib required for draw()")
109
    except RuntimeError:
110
        print("Matplotlib unable to open display")
111
        raise
112

    
113
    if ax is None:
114
        cf = plt.gcf()
115
    else:
116
        cf = ax.get_figure()
117
    cf.set_facecolor('w')
118
    if ax is None:
119
        if cf._axstack() is None:
120
            ax = cf.add_axes((0, 0, 1, 1))
121
        else:
122
            ax = cf.gca()
123

    
124
    if 'with_labels' not in kwds:
125
        kwds['with_labels'] = 'labels' in kwds
126

    
127
    try:
128
        draw_networkx(G, pos=pos, ax=ax, **kwds)
129
        ax.set_axis_off()
130
        plt.draw_if_interactive()
131
    except:
132
        raise
133
    return
134

    
135

    
136
def draw_networkx(G, pos=None, arrows=True, with_labels=True, **kwds):
137
    """Draw the graph G using Matplotlib.
138

139
    Draw the graph with Matplotlib with options for node positions,
140
    labeling, titles, and many other drawing features.
141
    See draw() for simple drawing without labels or axes.
142

143
    Parameters
144
    ----------
145
    G : graph
146
       A networkx graph
147

148
    pos : dictionary, optional
149
       A dictionary with nodes as keys and positions as values.
150
       If not specified a spring layout positioning will be computed.
151
       See :py:mod:`networkx.drawing.layout` for functions that
152
       compute node positions.
153

154
    arrows : bool, optional (default=True)
155
       For directed graphs, if True draw arrowheads.
156
       Note: Arrows will be the same color as edges.
157

158
    arrowstyle : str, optional (default='-|>')
159
        For directed graphs, choose the style of the arrowsheads.
160
        See :py:class: `matplotlib.patches.ArrowStyle` for more
161
        options.
162

163
    arrowsize : int, optional (default=10)
164
       For directed graphs, choose the size of the arrow head head's length and
165
       width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute
166
       `mutation_scale` for more info.
167

168
    with_labels :  bool, optional (default=True)
169
       Set to True to draw labels on the nodes.
170

171
    ax : Matplotlib Axes object, optional
172
       Draw the graph in the specified Matplotlib axes.
173

174
    nodelist : list, optional (default G.nodes())
175
       Draw only specified nodes
176

177
    edgelist : list, optional (default=G.edges())
178
       Draw only specified edges
179

180
    node_size : scalar or array, optional (default=300)
181
       Size of nodes.  If an array is specified it must be the
182
       same length as nodelist.
183

184
    node_color : color or array of colors (default='#1f78b4')
185
       Node color. Can be a single color or a sequence of colors with the same
186
       length as nodelist. Color can be string, or rgb (or rgba) tuple of
187
       floats from 0-1. If numeric values are specified they will be
188
       mapped to colors using the cmap and vmin,vmax parameters. See
189
       matplotlib.scatter for more details.
190

191
    node_shape :  string, optional (default='o')
192
       The shape of the node.  Specification is as matplotlib.scatter
193
       marker, one of 'so^>v<dph8'.
194

195
    alpha : float, optional (default=None)
196
       The node and edge transparency
197

198
    cmap : Matplotlib colormap, optional (default=None)
199
       Colormap for mapping intensities of nodes
200

201
    vmin,vmax : float, optional (default=None)
202
       Minimum and maximum for node colormap scaling
203

204
    linewidths : [None | scalar | sequence]
205
       Line width of symbol border (default =1.0)
206

207
    width : float, optional (default=1.0)
208
       Line width of edges
209

210
    edge_color : color or array of colors (default='k')
211
       Edge color. Can be a single color or a sequence of colors with the same
212
       length as edgelist. Color can be string, or rgb (or rgba) tuple of
213
       floats from 0-1. If numeric values are specified they will be
214
       mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
215

216
    edge_cmap : Matplotlib colormap, optional (default=None)
217
       Colormap for mapping intensities of edges
218

219
    edge_vmin,edge_vmax : floats, optional (default=None)
220
       Minimum and maximum for edge colormap scaling
221

222
    style : string, optional (default='solid')
223
       Edge line style (solid|dashed|dotted,dashdot)
224

225
    labels : dictionary, optional (default=None)
226
       Node labels in a dictionary keyed by node of text labels
227

228
    font_size : int, optional (default=12)
229
       Font size for text labels
230

231
    font_color : string, optional (default='k' black)
232
       Font color string
233

234
    font_weight : string, optional (default='normal')
235
       Font weight
236

237
    font_family : string, optional (default='sans-serif')
238
       Font family
239

240
    label : string, optional
241
        Label for graph legend
242

243
    Notes
244
    -----
245
    For directed graphs, arrows  are drawn at the head end.  Arrows can be
246
    turned off with keyword arrows=False.
247

248
    Examples
249
    --------
250
    >>> G = nx.dodecahedral_graph()
251
    >>> nx.draw(G)
252
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout
253

254
    >>> import matplotlib.pyplot as plt
255
    >>> limits = plt.axis('off')  # turn of axis
256

257
    Also see the NetworkX drawing examples at
258
    https://networkx.github.io/documentation/latest/auto_examples/index.html
259

260
    See Also
261
    --------
262
    draw()
263
    draw_networkx_nodes()
264
    draw_networkx_edges()
265
    draw_networkx_labels()
266
    draw_networkx_edge_labels()
267
    """
268
    try:
269
        import matplotlib.pyplot as plt
270
    except ImportError:
271
        raise ImportError("Matplotlib required for draw()")
272
    except RuntimeError:
273
        print("Matplotlib unable to open display")
274
        raise
275

    
276
    if pos is None:
277
        pos = nx.drawing.spring_layout(G)  # default to spring layout
278

    
279
    node_collection = draw_networkx_nodes(G, pos, **kwds)
280
    edge_collection = draw_networkx_edges(G, pos, arrows=arrows, **kwds)
281
    if with_labels:
282
        draw_networkx_labels(G, pos, **kwds)
283
    plt.draw_if_interactive()
284

    
285

    
286
def draw_networkx_nodes(G, pos,
287
                        nodelist=None,
288
                        node_size=300,
289
                        node_color='#1f78b4',
290
                        node_shape='o',
291
                        alpha=None,
292
                        cmap=None,
293
                        vmin=None,
294
                        vmax=None,
295
                        ax=None,
296
                        linewidths=None,
297
                        edgecolors=None,
298
                        label=None,
299
                        **kwds):
300
    """Draw the nodes of the graph G.
301

302
    This draws only the nodes of the graph G.
303

304
    Parameters
305
    ----------
306
    G : graph
307
       A networkx graph
308

309
    pos : dictionary
310
       A dictionary with nodes as keys and positions as values.
311
       Positions should be sequences of length 2.
312

313
    ax : Matplotlib Axes object, optional
314
       Draw the graph in the specified Matplotlib axes.
315

316
    nodelist : list, optional
317
       Draw only specified nodes (default G.nodes())
318

319
    node_size : scalar or array
320
       Size of nodes (default=300).  If an array is specified it must be the
321
       same length as nodelist.
322

323
    node_color : color or array of colors (default='#1f78b4')
324
       Node color. Can be a single color or a sequence of colors with the same
325
       length as nodelist. Color can be string, or rgb (or rgba) tuple of
326
       floats from 0-1. If numeric values are specified they will be
327
       mapped to colors using the cmap and vmin,vmax parameters. See
328
       matplotlib.scatter for more details.
329

330
    node_shape :  string
331
       The shape of the node.  Specification is as matplotlib.scatter
332
       marker, one of 'so^>v<dph8' (default='o').
333

334
    alpha : float or array of floats
335
       The node transparency.  This can be a single alpha value (default=None),
336
       in which case it will be applied to all the nodes of color. Otherwise,
337
       if it is an array, the elements of alpha will be applied to the colors
338
       in order (cycling through alpha multiple times if necessary).
339

340
    cmap : Matplotlib colormap
341
       Colormap for mapping intensities of nodes (default=None)
342

343
    vmin,vmax : floats
344
       Minimum and maximum for node colormap scaling (default=None)
345

346
    linewidths : [None | scalar | sequence]
347
       Line width of symbol border (default =1.0)
348

349
    edgecolors : [None | scalar | sequence]
350
       Colors of node borders (default = node_color)
351

352
    label : [None| string]
353
       Label for legend
354

355
    Returns
356
    -------
357
    matplotlib.collections.PathCollection
358
        `PathCollection` of the nodes.
359

360
    Examples
361
    --------
362
    >>> G = nx.dodecahedral_graph()
363
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
364

365
    Also see the NetworkX drawing examples at
366
    https://networkx.github.io/documentation/latest/auto_examples/index.html
367

368
    See Also
369
    --------
370
    draw()
371
    draw_networkx()
372
    draw_networkx_edges()
373
    draw_networkx_labels()
374
    draw_networkx_edge_labels()
375
    """
376
    from collections.abc import Iterable
377
    try:
378
        import matplotlib.pyplot as plt
379
        import numpy as np
380
    except ImportError:
381
        raise ImportError("Matplotlib required for draw()")
382
    except RuntimeError:
383
        print("Matplotlib unable to open display")
384
        raise
385

    
386
    if ax is None:
387
        ax = plt.gca()
388

    
389
    if nodelist is None:
390
        nodelist = list(G)
391

    
392
    if not nodelist or len(nodelist) == 0:  # empty nodelist, no drawing
393
        return None
394

    
395
    try:
396
        xy = np.asarray([pos[v] for v in nodelist])
397
    except KeyError as e:
398
        raise nx.NetworkXError('Node %s has no position.' % e)
399
    except ValueError:
400
        raise nx.NetworkXError('Bad value in node positions.')
401

    
402
    if isinstance(alpha, Iterable):
403
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
404
        alpha = None
405

    
406
    node_collection = ax.scatter(xy[:, 0], xy[:, 1],
407
                                 s=node_size,
408
                                 c=node_color,
409
                                 marker=node_shape,
410
                                 cmap=cmap,
411
                                 vmin=vmin,
412
                                 vmax=vmax,
413
                                 alpha=alpha,
414
                                 linewidths=linewidths,
415
                                 edgecolors=edgecolors,
416
                                 label=label)
417
    ax.tick_params(
418
        axis='both',
419
        which='both',
420
        bottom=False,
421
        left=False,
422
        labelbottom=False,
423
        labelleft=False)
424

    
425
    node_collection.set_zorder(2)
426
    return node_collection
427

    
428

    
429
def draw_networkx_edges(G, pos,
430
                        edgelist=None,
431
                        width=1.0,
432
                        edge_color='k',
433
                        style='solid',
434
                        alpha=None,
435
                        arrowstyle='-|>',
436
                        arrowsize=10,
437
                        edge_cmap=None,
438
                        edge_vmin=None,
439
                        edge_vmax=None,
440
                        ax=None,
441
                        arrows=True,
442
                        label=None,
443
                        node_size=300,
444
                        nodelist=None,
445
                        node_shape="o",
446
                        connectionstyle=None,
447
                        **kwds):
448
    """Draw the edges of the graph G.
449

450
    This draws only the edges of the graph G.
451

452
    Parameters
453
    ----------
454
    G : graph
455
       A networkx graph
456

457
    pos : dictionary
458
       A dictionary with nodes as keys and positions as values.
459
       Positions should be sequences of length 2.
460

461
    edgelist : collection of edge tuples
462
       Draw only specified edges(default=G.edges())
463

464
    width : float, or array of floats
465
       Line width of edges (default=1.0)
466

467
    edge_color : color or array of colors (default='k')
468
       Edge color. Can be a single color or a sequence of colors with the same
469
       length as edgelist. Color can be string, or rgb (or rgba) tuple of
470
       floats from 0-1. If numeric values are specified they will be
471
       mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
472

473
    style : string
474
       Edge line style (default='solid') (solid|dashed|dotted,dashdot)
475

476
    alpha : float
477
       The edge transparency (default=None)
478

479
    edge_ cmap : Matplotlib colormap
480
       Colormap for mapping intensities of edges (default=None)
481

482
    edge_vmin,edge_vmax : floats
483
       Minimum and maximum for edge colormap scaling (default=None)
484

485
    ax : Matplotlib Axes object, optional
486
       Draw the graph in the specified Matplotlib axes.
487

488
    arrows : bool, optional (default=True)
489
       For directed graphs, if True draw arrowheads.
490
       Note: Arrows will be the same color as edges.
491

492
    arrowstyle : str, optional (default='-|>')
493
       For directed graphs, choose the style of the arrow heads.
494
       See :py:class: `matplotlib.patches.ArrowStyle` for more
495
       options.
496

497
    arrowsize : int, optional (default=10)
498
       For directed graphs, choose the size of the arrow head head's length and
499
       width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute
500
       `mutation_scale` for more info.
501

502
    connectionstyle : str, optional (default=None)
503
       Pass the connectionstyle parameter to create curved arc of rounding
504
       radius rad. For example, connectionstyle='arc3,rad=0.2'.
505
       See :py:class: `matplotlib.patches.ConnectionStyle` and
506
       :py:class: `matplotlib.patches.FancyArrowPatch` for more info.
507

508
    label : [None| string]
509
       Label for legend
510

511
    Returns
512
    -------
513
    matplotlib.collection.LineCollection
514
        `LineCollection` of the edges
515

516
    list of matplotlib.patches.FancyArrowPatch
517
        `FancyArrowPatch` instances of the directed edges
518

519
    Depending whether the drawing includes arrows or not.
520

521
    Notes
522
    -----
523
    For directed graphs, arrows are drawn at the head end.  Arrows can be
524
    turned off with keyword arrows=False. Be sure to include `node_size` as a
525
    keyword argument; arrows are drawn considering the size of nodes.
526

527
    Examples
528
    --------
529
    >>> G = nx.dodecahedral_graph()
530
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
531

532
    >>> G = nx.DiGraph()
533
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
534
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
535
    >>> alphas = [0.3, 0.4, 0.5]
536
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
537
    ...     arc.set_alpha(alphas[i])
538

539
    Also see the NetworkX drawing examples at
540
    https://networkx.github.io/documentation/latest/auto_examples/index.html
541

542
    See Also
543
    --------
544
    draw()
545
    draw_networkx()
546
    draw_networkx_nodes()
547
    draw_networkx_labels()
548
    draw_networkx_edge_labels()
549
    """
550
    try:
551
        import matplotlib
552
        import matplotlib.pyplot as plt
553
        from matplotlib.colors import colorConverter, Colormap, Normalize
554
        from matplotlib.collections import LineCollection
555
        from matplotlib.patches import FancyArrowPatch
556
        import numpy as np
557
    except ImportError:
558
        raise ImportError("Matplotlib required for draw()")
559
    except RuntimeError:
560
        print("Matplotlib unable to open display")
561
        raise
562

    
563
    if ax is None:
564
        ax = plt.gca()
565

    
566
    if edgelist is None:
567
        edgelist = list(G.edges())
568

    
569
    if not edgelist or len(edgelist) == 0:  # no edges!
570
        return None
571

    
572
    if nodelist is None:
573
        nodelist = list(G.nodes())
574

    
575
    # FancyArrowPatch handles color=None different from LineCollection
576
    if edge_color is None:
577
        edge_color = 'k'
578

    
579
    # set edge positions
580
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
581

    
582
    # Check if edge_color is an array of floats and map to edge_cmap.
583
    # This is the only case handled differently from matplotlib
584
    if np.iterable(edge_color) and (len(edge_color) == len(edge_pos)) \
585
        and np.alltrue([isinstance(c,Number) for c in edge_color]):
586
            if edge_cmap is not None:
587
                assert(isinstance(edge_cmap, Colormap))
588
            else:
589
                edge_cmap = plt.get_cmap()
590
            if edge_vmin is None:
591
                edge_vmin = min(edge_color)
592
            if edge_vmax is None:
593
                edge_vmax = max(edge_color)
594
            color_normal = Normalize(vmin=edge_vmin, vmax=edge_vmax)
595
            edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
596

    
597
    if (not G.is_directed() or not arrows):
598
        edge_collection = LineCollection(edge_pos,
599
                                         colors=edge_color,
600
                                         linewidths=width,
601
                                         antialiaseds=(1,),
602
                                         linestyle=style,
603
                                         transOffset=ax.transData,
604
                                         alpha=alpha
605
                                         )
606

    
607
        edge_collection.set_zorder(1)  # edges go behind nodes
608
        edge_collection.set_label(label)
609
        ax.add_collection(edge_collection)
610

    
611
        return edge_collection
612

    
613
    arrow_collection = None
614

    
615
    if G.is_directed() and arrows:
616
        # Note: Waiting for someone to implement arrow to intersection with
617
        # marker.  Meanwhile, this works well for polygons with more than 4
618
        # sides and circle.
619

    
620
        def to_marker_edge(marker_size, marker):
621
            if marker in "s^>v<d":  # `large` markers need extra space
622
                return np.sqrt(2 * marker_size) / 2
623
            else:
624
                return np.sqrt(marker_size) / 2
625

    
626
        # Draw arrows with `matplotlib.patches.FancyarrowPatch`
627
        arrow_collection = []
628
        mutation_scale = arrowsize  # scale factor of arrow head
629

    
630
        # FancyArrowPatch doesn't handle color strings
631
        arrow_colors = colorConverter.to_rgba_array(edge_color,alpha)
632
        for i, (src, dst) in enumerate(edge_pos):
633
            x1, y1 = src
634
            x2, y2 = dst
635
            shrink_source = 0  # space from source to tail
636
            shrink_target = 0  # space from  head to target
637
            if np.iterable(node_size):  # many node sizes
638
                src_node, dst_node = edgelist[i][:2]
639
                index_node = nodelist.index(dst_node)
640
                marker_size = node_size[index_node]
641
                shrink_target = to_marker_edge(marker_size, node_shape)
642
            else:
643
                shrink_target = to_marker_edge(node_size, node_shape)
644

    
645
            if np.iterable(arrow_colors):
646
                if len(arrow_colors) == len(edge_pos):
647
                    arrow_color = arrow_colors[i]
648
                elif len(arrow_colors)==1:
649
                    arrow_color = arrow_colors[0]
650
                else: # Cycle through colors
651
                    arrow_color =  arrow_colors[i%len(arrow_colors)]
652
            else:
653
                arrow_color = edge_color
654

    
655
            if np.iterable(width):
656
                if len(width) == len(edge_pos):
657
                    line_width = width[i]
658
                else:
659
                    line_width = width[i%len(width)]
660
            else:
661
                line_width = width
662

    
663
            arrow = FancyArrowPatch((x1, y1), (x2, y2),
664
                                    arrowstyle=arrowstyle,
665
                                    shrinkA=shrink_source,
666
                                    shrinkB=shrink_target,
667
                                    mutation_scale=mutation_scale,
668
                                    color=arrow_color,
669
                                    linewidth=line_width,
670
                                    connectionstyle=connectionstyle,
671
                                    zorder=1)  # arrows go behind nodes
672

    
673
            # There seems to be a bug in matplotlib to make collections of
674
            # FancyArrowPatch instances. Until fixed, the patches are added
675
            # individually to the axes instance.
676
            arrow_collection.append(arrow)
677
            ax.add_patch(arrow)
678

    
679
    # update view
680
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
681
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
682
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
683
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
684

    
685
    w = maxx - minx
686
    h = maxy - miny
687
    padx,  pady = 0.05 * w, 0.05 * h
688
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
689
    ax.update_datalim(corners)
690
    ax.autoscale_view()
691

    
692
    ax.tick_params(
693
        axis='both',
694
        which='both',
695
        bottom=False,
696
        left=False,
697
        labelbottom=False,
698
        labelleft=False)
699

    
700
    return arrow_collection
701

    
702

    
703
def draw_networkx_labels(G, pos,
704
                         labels=None,
705
                         font_size=12,
706
                         font_color='k',
707
                         font_family='sans-serif',
708
                         font_weight='normal',
709
                         alpha=None,
710
                         bbox=None,
711
                         ax=None,
712
                         **kwds):
713
    """Draw node labels on the graph G.
714

715
    Parameters
716
    ----------
717
    G : graph
718
       A networkx graph
719

720
    pos : dictionary
721
       A dictionary with nodes as keys and positions as values.
722
       Positions should be sequences of length 2.
723

724
    labels : dictionary, optional (default=None)
725
       Node labels in a dictionary keyed by node of text labels
726
       Node-keys in labels should appear as keys in `pos`.
727
       If needed use: `{n:lab for n,lab in labels.items() if n in pos}`
728

729
    font_size : int
730
       Font size for text labels (default=12)
731

732
    font_color : string
733
       Font color string (default='k' black)
734

735
    font_family : string
736
       Font family (default='sans-serif')
737

738
    font_weight : string
739
       Font weight (default='normal')
740

741
    alpha : float or None
742
       The text transparency (default=None)
743

744
    ax : Matplotlib Axes object, optional
745
       Draw the graph in the specified Matplotlib axes.
746

747
    Returns
748
    -------
749
    dict
750
        `dict` of labels keyed on the nodes
751

752
    Examples
753
    --------
754
    >>> G = nx.dodecahedral_graph()
755
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))
756

757
    Also see the NetworkX drawing examples at
758
    https://networkx.github.io/documentation/latest/auto_examples/index.html
759

760
    See Also
761
    --------
762
    draw()
763
    draw_networkx()
764
    draw_networkx_nodes()
765
    draw_networkx_edges()
766
    draw_networkx_edge_labels()
767
    """
768
    try:
769
        import matplotlib.pyplot as plt
770
    except ImportError:
771
        raise ImportError("Matplotlib required for draw()")
772
    except RuntimeError:
773
        print("Matplotlib unable to open display")
774
        raise
775

    
776
    if ax is None:
777
        ax = plt.gca()
778

    
779
    if labels is None:
780
        labels = dict((n, n) for n in G.nodes())
781

    
782
    # set optional alignment
783
    horizontalalignment = kwds.get('horizontalalignment', 'center')
784
    verticalalignment = kwds.get('verticalalignment', 'center')
785

    
786
    text_items = {}  # there is no text collection so we'll fake one
787
    for n, label in labels.items():
788
        (x, y) = pos[n]
789
        if not is_string_like(label):
790
            label = str(label)  # this makes "1" and 1 labeled the same
791
        t = ax.text(x, y,
792
                    label,
793
                    size=font_size,
794
                    color=font_color,
795
                    family=font_family,
796
                    weight=font_weight,
797
                    alpha=alpha,
798
                    horizontalalignment=horizontalalignment,
799
                    verticalalignment=verticalalignment,
800
                    transform=ax.transData,
801
                    bbox=bbox,
802
                    clip_on=True,
803
                    )
804
        text_items[n] = t
805

    
806
    ax.tick_params(
807
        axis='both',
808
        which='both',
809
        bottom=False,
810
        left=False,
811
        labelbottom=False,
812
        labelleft=False)
813

    
814
    return text_items
815

    
816

    
817
def draw_networkx_edge_labels(G, pos,
818
                              edge_labels=None,
819
                              label_pos=0.5,
820
                              font_size=10,
821
                              font_color='k',
822
                              font_family='sans-serif',
823
                              font_weight='normal',
824
                              alpha=None,
825
                              bbox=None,
826
                              ax=None,
827
                              rotate=True,
828
                              **kwds):
829
    """Draw edge labels.
830

831
    Parameters
832
    ----------
833
    G : graph
834
       A networkx graph
835

836
    pos : dictionary
837
       A dictionary with nodes as keys and positions as values.
838
       Positions should be sequences of length 2.
839

840
    ax : Matplotlib Axes object, optional
841
       Draw the graph in the specified Matplotlib axes.
842

843
    alpha : float or None
844
       The text transparency (default=None)
845

846
    edge_labels : dictionary
847
       Edge labels in a dictionary keyed by edge two-tuple of text
848
       labels (default=None). Only labels for the keys in the dictionary
849
       are drawn.
850

851
    label_pos : float
852
       Position of edge label along edge (0=head, 0.5=center, 1=tail)
853

854
    font_size : int
855
       Font size for text labels (default=12)
856

857
    font_color : string
858
       Font color string (default='k' black)
859

860
    font_weight : string
861
       Font weight (default='normal')
862

863
    font_family : string
864
       Font family (default='sans-serif')
865

866
    bbox : Matplotlib bbox
867
       Specify text box shape and colors.
868

869
    clip_on : bool
870
       Turn on clipping at axis boundaries (default=True)
871

872
    Returns
873
    -------
874
    dict
875
        `dict` of labels keyed on the edges
876

877
    Examples
878
    --------
879
    >>> G = nx.dodecahedral_graph()
880
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))
881

882
    Also see the NetworkX drawing examples at
883
    https://networkx.github.io/documentation/latest/auto_examples/index.html
884

885
    See Also
886
    --------
887
    draw()
888
    draw_networkx()
889
    draw_networkx_nodes()
890
    draw_networkx_edges()
891
    draw_networkx_labels()
892
    """
893
    try:
894
        import matplotlib.pyplot as plt
895
        import numpy as np
896
    except ImportError:
897
        raise ImportError("Matplotlib required for draw()")
898
    except RuntimeError:
899
        print("Matplotlib unable to open display")
900
        raise
901

    
902
    if ax is None:
903
        ax = plt.gca()
904
    if edge_labels is None:
905
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
906
    else:
907
        labels = edge_labels
908
    text_items = {}
909
    for (n1, n2), label in labels.items():
910
        (x1, y1) = pos[n1]
911
        (x2, y2) = pos[n2]
912
        (x, y) = (x1 * label_pos + x2 * (1.0 - label_pos),
913
                  y1 * label_pos + y2 * (1.0 - label_pos))
914

    
915
        if rotate:
916
            # in degrees
917
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
918
            # make label orientation "right-side-up"
919
            if angle > 90:
920
                angle -= 180
921
            if angle < - 90:
922
                angle += 180
923
            # transform data coordinate angle to screen coordinate angle
924
            xy = np.array((x, y))
925
            trans_angle = ax.transData.transform_angles(np.array((angle,)),
926
                                                        xy.reshape((1, 2)))[0]
927
        else:
928
            trans_angle = 0.0
929
        # use default box of white with white border
930
        if bbox is None:
931
            bbox = dict(boxstyle='round',
932
                        ec=(1.0, 1.0, 1.0),
933
                        fc=(1.0, 1.0, 1.0),
934
                        )
935
        if not is_string_like(label):
936
            label = str(label)  # this makes "1" and 1 labeled the same
937

    
938
        # set optional alignment
939
        horizontalalignment = kwds.get('horizontalalignment', 'center')
940
        verticalalignment = kwds.get('verticalalignment', 'center')
941

    
942
        t = ax.text(x, y,
943
                    label,
944
                    size=font_size,
945
                    color=font_color,
946
                    family=font_family,
947
                    weight=font_weight,
948
                    alpha=alpha,
949
                    horizontalalignment=horizontalalignment,
950
                    verticalalignment=verticalalignment,
951
                    rotation=trans_angle,
952
                    transform=ax.transData,
953
                    bbox=bbox,
954
                    zorder=1,
955
                    clip_on=True,
956
                    )
957
        text_items[(n1, n2)] = t
958

    
959
    ax.tick_params(
960
        axis='both',
961
        which='both',
962
        bottom=False,
963
        left=False,
964
        labelbottom=False,
965
        labelleft=False)
966

    
967
    return text_items
968

    
969

    
970
def draw_circular(G, **kwargs):
971
    """Draw the graph G with a circular layout.
972

973
    Parameters
974
    ----------
975
    G : graph
976
       A networkx graph
977

978
    kwargs : optional keywords
979
       See networkx.draw_networkx() for a description of optional keywords,
980
       with the exception of the pos parameter which is not used by this
981
       function.
982
    """
983
    draw(G, circular_layout(G), **kwargs)
984

    
985

    
986
def draw_kamada_kawai(G, **kwargs):
987
    """Draw the graph G with a Kamada-Kawai force-directed layout.
988

989
    Parameters
990
    ----------
991
    G : graph
992
       A networkx graph
993

994
    kwargs : optional keywords
995
       See networkx.draw_networkx() for a description of optional keywords,
996
       with the exception of the pos parameter which is not used by this
997
       function.
998
    """
999
    draw(G, kamada_kawai_layout(G), **kwargs)
1000

    
1001

    
1002
def draw_random(G, **kwargs):
1003
    """Draw the graph G with a random layout.
1004

1005
    Parameters
1006
    ----------
1007
    G : graph
1008
       A networkx graph
1009

1010
    kwargs : optional keywords
1011
       See networkx.draw_networkx() for a description of optional keywords,
1012
       with the exception of the pos parameter which is not used by this
1013
       function.
1014
    """
1015
    draw(G, random_layout(G), **kwargs)
1016

    
1017

    
1018
def draw_spectral(G, **kwargs):
1019
    """Draw the graph G with a spectral 2D layout.
1020

1021
    Using the unnormalized Laplacion, the layout shows possible clusters of
1022
    nodes which are an approximation of the ratio cut. The positions are the
1023
    entries of the second and third eigenvectors corresponding to the
1024
    ascending eigenvalues starting from the second one.
1025

1026
    Parameters
1027
    ----------
1028
    G : graph
1029
       A networkx graph
1030

1031
    kwargs : optional keywords
1032
       See networkx.draw_networkx() for a description of optional keywords,
1033
       with the exception of the pos parameter which is not used by this
1034
       function.
1035
    """
1036
    draw(G, spectral_layout(G), **kwargs)
1037

    
1038

    
1039
def draw_spring(G, **kwargs):
1040
    """Draw the graph G with a spring layout.
1041

1042
    Parameters
1043
    ----------
1044
    G : graph
1045
       A networkx graph
1046

1047
    kwargs : optional keywords
1048
       See networkx.draw_networkx() for a description of optional keywords,
1049
       with the exception of the pos parameter which is not used by this
1050
       function.
1051
    """
1052
    draw(G, spring_layout(G), **kwargs)
1053

    
1054

    
1055
def draw_shell(G, **kwargs):
1056
    """Draw networkx graph with shell layout.
1057

1058
    Parameters
1059
    ----------
1060
    G : graph
1061
       A networkx graph
1062

1063
    kwargs : optional keywords
1064
       See networkx.draw_networkx() for a description of optional keywords,
1065
       with the exception of the pos parameter which is not used by this
1066
       function.
1067
    """
1068
    nlist = kwargs.get('nlist', None)
1069
    if nlist is not None:
1070
        del(kwargs['nlist'])
1071
    draw(G, shell_layout(G, nlist=nlist), **kwargs)
1072

    
1073

    
1074
def draw_planar(G, **kwargs):
1075
    """Draw a planar networkx graph with planar layout.
1076

1077
    Parameters
1078
    ----------
1079
    G : graph
1080
       A planar networkx graph
1081

1082
    kwargs : optional keywords
1083
       See networkx.draw_networkx() for a description of optional keywords,
1084
       with the exception of the pos parameter which is not used by this
1085
       function.
1086
    """
1087
    draw(G, planar_layout(G), **kwargs)
1088

    
1089

    
1090
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
1091
    """Apply an alpha (or list of alphas) to the colors provided.
1092

1093
    Parameters
1094
    ----------
1095

1096
    colors : color string, or array of floats
1097
       Color of element. Can be a single color format string (default='r'),
1098
       or a  sequence of colors with the same length as nodelist.
1099
       If numeric values are specified they will be mapped to
1100
       colors using the cmap and vmin,vmax parameters.  See
1101
       matplotlib.scatter for more details.
1102

1103
    alpha : float or array of floats
1104
       Alpha values for elements. This can be a single alpha value, in
1105
       which case it will be applied to all the elements of color. Otherwise,
1106
       if it is an array, the elements of alpha will be applied to the colors
1107
       in order (cycling through alpha multiple times if necessary).
1108

1109
    elem_list : array of networkx objects
1110
       The list of elements which are being colored. These could be nodes,
1111
       edges or labels.
1112

1113
    cmap : matplotlib colormap
1114
       Color map for use if colors is a list of floats corresponding to points
1115
       on a color mapping.
1116

1117
    vmin, vmax : float
1118
       Minimum and maximum values for normalizing colors if a color mapping is
1119
       used.
1120

1121
    Returns
1122
    -------
1123

1124
    rgba_colors : numpy ndarray
1125
        Array containing RGBA format values for each of the node colours.
1126

1127
    """
1128
    from itertools import islice, cycle
1129

    
1130
    try:
1131
        import numpy as np
1132
        from matplotlib.colors import colorConverter
1133
        import matplotlib.cm as cm
1134
    except ImportError:
1135
        raise ImportError("Matplotlib required for draw()")
1136

    
1137
    # If we have been provided with a list of numbers as long as elem_list,
1138
    # apply the color mapping.
1139
    if len(colors) == len(elem_list) and isinstance(colors[0], Number):
1140
        mapper = cm.ScalarMappable(cmap=cmap)
1141
        mapper.set_clim(vmin, vmax)
1142
        rgba_colors = mapper.to_rgba(colors)
1143
    # Otherwise, convert colors to matplotlib's RGB using the colorConverter
1144
    # object.  These are converted to numpy ndarrays to be consistent with the
1145
    # to_rgba method of ScalarMappable.
1146
    else:
1147
        try:
1148
            rgba_colors = np.array([colorConverter.to_rgba(colors)])
1149
        except ValueError:
1150
            rgba_colors = np.array([colorConverter.to_rgba(color)
1151
                                    for color in colors])
1152
    # Set the final column of the rgba_colors to have the relevant alpha values
1153
    try:
1154
        # If alpha is longer than the number of colors, resize to the number of
1155
        # elements.  Also, if rgba_colors.size (the number of elements of
1156
        # rgba_colors) is the same as the number of elements, resize the array,
1157
        # to avoid it being interpreted as a colormap by scatter()
1158
        if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
1159
            rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
1160
            rgba_colors[1:, 0] = rgba_colors[0, 0]
1161
            rgba_colors[1:, 1] = rgba_colors[0, 1]
1162
            rgba_colors[1:, 2] = rgba_colors[0, 2]
1163
        rgba_colors[:,  3] = list(islice(cycle(alpha), len(rgba_colors)))
1164
    except TypeError:
1165
        rgba_colors[:, -1] = alpha
1166
    return rgba_colors
1167

    
1168
# fixture for nose tests
1169

    
1170

    
1171
def setup_module(module):
1172
    from nose import SkipTest
1173
    try:
1174
        import matplotlib as mpl
1175
        mpl.use('PS', warn=False)
1176
        import matplotlib.pyplot as plt
1177
    except:
1178
        raise SkipTest("matplotlib not available")