Statistics
| Branch: | Revision:

iof-tools / networkxMiCe / networkx-master / networkx / readwrite / gexf.py @ 5cef0f13

History | View | Annotate | Download (38.5 KB)

1
# Copyright (C) 2013-2019 by
2
#
3
# Authors: Aric Hagberg <hagberg@lanl.gov>
4
#          Dan Schult <dschult@colgate.edu>
5
#          Pieter Swart <swart@lanl.gov>
6
# All rights reserved.
7
# BSD license.
8
# Based on GraphML NetworkX GraphML reader
9
"""Read and write graphs in GEXF format.
10

11
GEXF (Graph Exchange XML Format) is a language for describing complex
12
network structures, their associated data and dynamics.
13

14
This implementation does not support mixed graphs (directed and
15
undirected edges together).
16

17
Format
18
------
19
GEXF is an XML format.  See https://gephi.org/gexf/format/schema.html for the
20
specification and https://gephi.org/gexf/format/basic.html for examples.
21
"""
22
import itertools
23
import time
24

    
25
import networkx as nx
26
from networkx.utils import open_file, make_str
27
try:
28
    from xml.etree.cElementTree import Element, ElementTree, SubElement, tostring
29
except ImportError:
30
    try:
31
        from xml.etree.ElementTree import Element, ElementTree, SubElement, tostring
32
    except ImportError:
33
        pass
34

    
35
__all__ = ['write_gexf', 'read_gexf', 'relabel_gexf_graph', 'generate_gexf']
36

    
37

    
38
@open_file(1, mode='wb')
def write_gexf(G, path, encoding='utf-8', prettyprint=True, version='1.2draft'):
    """Write G in GEXF format to path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Node attributes are checked according to the version of the GEXF
    schemas used for parameters which are not user defined,
    e.g. visualization 'viz' [2]_. See example for usage.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    path : file or string
       File or file name to write.
       File names ending in .gz or .bz2 will be compressed.
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (optional, default: '1.2draft')
       Version of the GEXF file format to write.
       Supported values: "1.1draft", "1.2draft"

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.write_gexf(G, "test.gexf")

    # visualization data
    >>> G.nodes[0]['viz'] = {'size': 54}
    >>> G.nodes[0]['viz']['position'] = {'x' : 0, 'y' : 1}
    >>> G.nodes[0]['viz']['color'] = {'r' : 0, 'g' : 0, 'b' : 255}

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    .. [2] GEXF viz schema 1.1, https://gephi.org/gexf/1.1draft/viz
    """
    gexf_writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                             version=version)
    # Build the XML tree first, then serialize it to the opened file handle.
    gexf_writer.add_graph(G)
    gexf_writer.write(path)
90

    
91

    
92
def generate_gexf(G, encoding='utf-8', prettyprint=True, version='1.2draft'):
    """Generate lines of GEXF format representation of G.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    G : graph
       A NetworkX graph
    encoding : string (optional, default: 'utf-8')
       Encoding for text data.
    prettyprint : bool (optional, default: True)
       If True use line breaks and indenting in output XML.
    version : string (default: 1.2draft)
       Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html).
       Supported values: "1.1draft", "1.2draft"

    Yields
    ------
    line : string
       One line of the serialized GEXF document, without its line ending.

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> linefeed = chr(10) # linefeed=\n
    >>> s = linefeed.join(nx.generate_gexf(G))  # doctest: +SKIP
    >>> for line in nx.generate_gexf(G):  # doctest: +SKIP
    ...    print(line)

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    The node id attribute is set to be the string of the node label.
    If you want to specify an id, set it as node data, e.g.
    node['a']['id']=1 to set the id of node 'a' to 1.

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf_writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint,
                             version=version)
    gexf_writer.add_graph(G)
    # The whole document is rendered once, then handed out line by line.
    xml_text = str(gexf_writer)
    for xml_line in xml_text.splitlines():
        yield xml_line
137

    
138

    
139
@open_file(0, mode='rb')
def read_gexf(path, node_type=None, relabel=False, version='1.2draft'):
    """Read graph in GEXF format from path.

    "GEXF (Graph Exchange XML Format) is a language for describing
    complex networks structures, their associated data and dynamics" [1]_.

    Parameters
    ----------
    path : file or string
       File or file name to read.
       File names ending in .gz or .bz2 will be decompressed.
    node_type: Python type (default: None)
       Convert node ids to this type if not None.
    relabel : bool (default: False)
       If True relabel the nodes to use the GEXF node "label" attribute
       instead of the node "id" attribute as the NetworkX node label.
    version : string (default: 1.2draft)
       Version of GEXF File Format (see https://gephi.org/gexf/format/schema.html).
       Supported values: "1.1draft", "1.2draft"

    Returns
    -------
    graph: NetworkX graph
        If no parallel edges are found a Graph or DiGraph is returned.
        Otherwise a MultiGraph or MultiDiGraph is returned.

    Notes
    -----
    This implementation does not support mixed graphs (directed and undirected
    edges together).

    References
    ----------
    .. [1] GEXF File Format, https://gephi.org/gexf/format/
    """
    gexf_reader = GEXFReader(node_type=node_type, version=version)
    graph = gexf_reader(path)
    if relabel:
        # Swap GEXF ids for GEXF labels as the NetworkX node names.
        graph = relabel_gexf_graph(graph)
    return graph
181

    
182

    
183
class GEXF(object):
    """Constants and version handling shared by the GEXF reader and writer.

    Class attributes
    ----------------
    versions : dict
        Maps a version string ('1.1draft', '1.2draft') to the namespace,
        schema-location and version constants used for that GEXF version.
    types : list of (python type, GEXF type name) pairs
    xml_type : dict
        Python type -> GEXF type name (used when writing).
    python_type : dict
        GEXF type name -> Python type (used when reading).
    convert_bool : dict
        Accepted spellings of boolean values -> bool.
    """
    versions = {}
    d = {'NS_GEXF': "http://www.gexf.net/1.1draft",
         'NS_VIZ': "http://www.gexf.net/1.1draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.1draft',
                                     'http://www.gexf.net/1.1draft/gexf.xsd']),
         'VERSION': '1.1'}
    versions['1.1draft'] = d
    d = {'NS_GEXF': "http://www.gexf.net/1.2draft",
         'NS_VIZ': "http://www.gexf.net/1.2draft/viz",
         'NS_XSI': "http://www.w3.org/2001/XMLSchema-instance",
         'SCHEMALOCATION': ' '.join(['http://www.gexf.net/1.2draft',
                                     'http://www.gexf.net/1.2draft/gexf.xsd']),
         'VERSION': '1.2'}
    versions['1.2draft'] = d

    # Order matters: both lookup tables below are built with dict(), so for a
    # repeated key the LAST pair in this list wins.
    types = [(int, "integer"),
             (float, "float"),
             (float, "double"),
             (bool, "boolean"),
             (list, "string"),
             (dict, "string"),
             (int, "long"),
             (str, "liststring"),
             (str, "anyURI"),
             (str, "string")]

    # These additions to types allow writing numpy types
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # Look the scalar types up by name instead of by attribute access:
        # aliases such as ``np.int`` and ``np.float_`` are not available in
        # every NumPy release, and a plain ``np.int`` would then raise
        # AttributeError while this class body is being executed.
        numpy_types = []
        for xml_name, np_names in (
                ("float", ("float64", "float32", "float16", "float_")),
                ("int", ("int", "int8", "int16", "int32", "int64",
                         "uint8", "uint16", "uint32", "uint64",
                         "int_", "intc", "intp"))):
            for np_name in np_names:
                np_type = getattr(np, np_name, None)
                if np_type is not None:
                    numpy_types.append((np_type, xml_name))
        # prepend so that python types are created upon read (last entry wins)
        types = numpy_types + types
        # keep the loop temporaries out of the class namespace
        del numpy_types, xml_name, np_names, np_name, np_type

    xml_type = dict(types)
    python_type = dict(reversed(a) for a in types)

    # http://www.w3.org/TR/xmlschema-2/#boolean
    convert_bool = {
        'true': True, 'false': False,
        'True': True, 'False': False,
        '0': False, 0: False,
        '1': True, 1: True
    }

    def set_version(self, version):
        """Select the GEXF version and copy its constants onto the instance.

        Parameters
        ----------
        version : string
            A key of ``versions`` ('1.1draft' or '1.2draft').

        Raises
        ------
        NetworkXError
            If ``version`` is not a key of ``versions``.

        Sets ``NS_GEXF``, ``NS_VIZ``, ``NS_XSI``, ``SCHEMALOCATION``,
        ``VERSION`` (short form, e.g. '1.2') and ``version`` (the full
        string that was passed in).
        """
        d = self.versions.get(version)
        if d is None:
            raise nx.NetworkXError('Unknown GEXF version %s.' % version)
        self.NS_GEXF = d['NS_GEXF']
        self.NS_VIZ = d['NS_VIZ']
        self.NS_XSI = d['NS_XSI']
        # Fix: this used to copy d['NS_XSI'], so the xsi:schemaLocation
        # attribute carried the XSI namespace instead of the schema location.
        self.SCHEMALOCATION = d['SCHEMALOCATION']
        self.VERSION = d['VERSION']
        self.version = version
249

    
250

    
251
class GEXFWriter(GEXF):
    """Build a GEXF XML document from NetworkX graphs.

    This class is normally driven through write_gexf() or generate_gexf().
    """

    def __init__(self, graph=None, encoding='utf-8', prettyprint=True,
                 version='1.2draft'):
        """Create the root <gexf> element and, optionally, add ``graph``.

        Parameters
        ----------
        graph : NetworkX graph, optional
            If given, it is added immediately via add_graph().
        encoding : string
            Encoding used by write() and __str__().
        prettyprint : bool
            If True the tree is indented before serialization.
        version : string
            GEXF version key understood by set_version().
        """
        try:
            import xml.etree.ElementTree as ET
        except ImportError:
            raise ImportError('GEXF writer requires '
                              'xml.elementtree.ElementTree')
        self.prettyprint = prettyprint
        self.encoding = encoding
        self.set_version(version)
        self.xml = Element('gexf',
                           {'xmlns': self.NS_GEXF,
                            'xmlns:xsi': self.NS_XSI,
                            'xsi:schemaLocation': self.SCHEMALOCATION,
                            'version': self.VERSION})

        # serialize viz elements with the 'viz' prefix
        ET.register_namespace('viz', self.NS_VIZ)

        # counters for edge and attribute identifiers
        self.edge_id = itertools.count()
        self.attr_id = itertools.count()
        self.all_edge_ids = set()
        # ids of already declared attributes: class -> mode -> title -> id
        self.attr = {
            'node': {'dynamic': {}, 'static': {}},
            'edge': {'dynamic': {}, 'static': {}},
        }

        if graph is not None:
            self.add_graph(graph)

    def __str__(self):
        """Return the document as a string (indented if prettyprint)."""
        if self.prettyprint:
            self.indent(self.xml)
        s = tostring(self.xml).decode(self.encoding)
        return s

    def add_graph(self, G):
        """Append a <graph> element describing ``G`` to the document."""
        # first pass through G collecting edge ids, so that generated ids
        # never collide with ids supplied by the user
        for u, v, dd in G.edges(data=True):
            eid = dd.get('id')
            if eid is not None:
                self.all_edge_ids.add(make_str(eid))
        # set graph attributes
        mode = 'dynamic' if G.graph.get('mode') == 'dynamic' else 'static'
        default = 'directed' if G.is_directed() else 'undirected'
        name = G.graph.get('name', '')
        graph_element = Element('graph', defaultedgetype=default, mode=mode,
                                name=name)
        # the helpers below reach the current graph through this attribute
        self.graph_element = graph_element
        self.add_meta(G, graph_element)
        self.add_nodes(G, graph_element)
        self.add_edges(G, graph_element)
        self.xml.append(graph_element)

    def add_meta(self, G, graph_element):
        """Append a <meta> element with creator and last-modified date."""
        meta_element = Element('meta')
        SubElement(meta_element, 'creator').text = 'NetworkX {}'.format(nx.__version__)
        SubElement(meta_element, 'lastmodified').text = time.strftime('%d/%m/%Y')
        graph_element.append(meta_element)

    def _pop_lifetime(self, data, kw):
        """Move 'start'/'end' from ``data`` into the XML attribute dict ``kw``.

        Each value found also updates the graph mode/timeformat.
        """
        for bound in ('start', 'end'):
            if bound in data:
                value = data.pop(bound)
                kw[bound] = make_str(value)
                self.alter_graph_mode_timeformat(value)

    def _add_slices_or_spells(self, element, data):
        """Write <slices> for GEXF 1.1 and <spells> for later versions."""
        # Fix: the original compared self.version (e.g. '1.1draft') with
        # '1.1', which can never match; self.VERSION holds the short form.
        if self.VERSION == '1.1':
            return self.add_slices(element, data)
        return self.add_spells(element, data)

    def add_nodes(self, G, graph_element):
        """Append a <nodes> element with one <node> per node of ``G``."""
        nodes_element = Element('nodes')
        default = G.graph.get('node_default', {})
        for node, data in G.nodes(data=True):
            # work on a copy: the helpers pop the keys they consume
            node_data = data.copy()
            kw = {'id': make_str(node_data.pop('id', node))}
            kw['label'] = make_str(node_data.pop('label', node))
            if 'pid' in node_data:
                kw['pid'] = make_str(node_data.pop('pid'))
            self._pop_lifetime(node_data, kw)
            # add node element with attributes
            node_element = Element('node', **kw)
            # add node element and attr subelements
            node_data = self.add_parents(node_element, node_data)
            node_data = self._add_slices_or_spells(node_element, node_data)
            node_data = self.add_viz(node_element, node_data)
            node_data = self.add_attributes('node', node_element,
                                            node_data, default)
            nodes_element.append(node_element)
        graph_element.append(nodes_element)

    def _new_edge_id(self):
        """Return the next counter value that is not yet used as an edge id."""
        edge_id = next(self.edge_id)
        while make_str(edge_id) in self.all_edge_ids:
            edge_id = next(self.edge_id)
        self.all_edge_ids.add(make_str(edge_id))
        return edge_id

    def add_edges(self, G, graph_element):
        """Append an <edges> element with one <edge> per edge of ``G``."""
        def edge_key_data(G):
            # helper function to unify multigraph and graph edge iterator;
            # yields (u, v, edge id, copy of the edge data without 'id')
            if G.is_multigraph():
                for u, v, key, data in G.edges(data=True, keys=True):
                    edge_data = data.copy()
                    # the multigraph key is written as an ordinary attribute
                    edge_data.update(key=key)
                    yield u, v, edge_data
            else:
                for u, v, data in G.edges(data=True):
                    yield u, v, data.copy()

        edges_element = Element('edges')
        default = G.graph.get('edge_default', {})
        for u, v, edge_data in edge_key_data(G):
            edge_id = edge_data.pop('id', None)
            if edge_id is None:
                edge_id = self._new_edge_id()
            kw = {'id': make_str(edge_id)}
            for name in ('weight', 'type'):
                if name in edge_data:
                    kw[name] = make_str(edge_data.pop(name))
            self._pop_lifetime(edge_data, kw)
            # endpoints use the node's explicit 'id' if it has one
            source_id = make_str(G.nodes[u].get('id', u))
            target_id = make_str(G.nodes[v].get('id', v))
            edge_element = Element('edge',
                                   source=source_id, target=target_id, **kw)
            edge_data = self._add_slices_or_spells(edge_element, edge_data)
            edge_data = self.add_viz(edge_element, edge_data)
            edge_data = self.add_attributes('edge', edge_element,
                                            edge_data, default)
            edges_element.append(edge_element)
        graph_element.append(edges_element)

    def add_attributes(self, node_or_edge, xml_obj, data, default):
        """Append an <attvalues> element for the remaining ``data`` items.

        Parameters
        ----------
        node_or_edge : 'node' or 'edge'
        xml_obj : Element
            The <node> or <edge> element receiving the values.
        data : dict
            Attribute name -> value. A list value is treated as dynamic
            data, i.e. an iterable of (value, start, end) triples.
        default : dict
            Attribute name -> default value, used when declaring attributes.

        Returns
        -------
        data, unchanged.

        Raises
        ------
        TypeError
            If the type of a value is not a key of ``xml_type``.
        """
        if not data:
            return data
        attvalues = Element('attvalues')
        for k, v in data.items():
            # rename generic multigraph key to avoid any name conflict
            if k == 'key':
                k = 'networkx_key'
            val_type = type(v)
            if val_type not in self.xml_type:
                raise TypeError('attribute value type is not allowed: %s' % val_type)
            if isinstance(v, list):
                # dynamic data
                # Fix: decide the mode per attribute; previously a 'dynamic'
                # result leaked into the next list-valued attribute.
                mode = 'static'
                for val, start, end in v:
                    val_type = type(val)
                    if start is not None or end is not None:
                        mode = 'dynamic'
                        self.alter_graph_mode_timeformat(start)
                        self.alter_graph_mode_timeformat(end)
                        break
                attr_id = self.get_attr_id(make_str(k), self.xml_type[val_type],
                                           node_or_edge, default, mode)
                for val, start, end in v:
                    e = Element('attvalue')
                    e.attrib['for'] = attr_id
                    e.attrib['value'] = make_str(val)
                    if start is not None:
                        e.attrib['start'] = make_str(start)
                    if end is not None:
                        e.attrib['end'] = make_str(end)
                    attvalues.append(e)
            else:
                # static data
                mode = 'static'
                attr_id = self.get_attr_id(make_str(k), self.xml_type[val_type],
                                           node_or_edge, default, mode)
                e = Element('attvalue')
                e.attrib['for'] = attr_id
                if isinstance(v, bool):
                    # booleans are written lower case ('true'/'false')
                    e.attrib['value'] = make_str(v).lower()
                else:
                    e.attrib['value'] = make_str(v)
                attvalues.append(e)
        xml_obj.append(attvalues)
        return data

    def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
        """Return the id of attribute ``title``, declaring it if necessary.

        A new <attribute> element is added to the <attributes> element of
        the current graph that matches ``edge_or_node`` and ``mode``; that
        <attributes> element is created first if it does not exist yet.
        """
        known_ids = self.attr[edge_or_node][mode]
        if title in known_ids:
            return known_ids[title]
        # generate new id
        new_id = str(next(self.attr_id))
        known_ids[title] = new_id
        attr_kwargs = {'id': new_id, 'title': title, 'type': attr_type}
        attribute = Element('attribute', **attr_kwargs)
        # add subelement for data default value if present
        default_title = default.get(title)
        if default_title is not None:
            default_element = Element('default')
            default_element.text = make_str(default_title)
            attribute.append(default_element)
        # now insert it into the XML
        attributes_element = None
        for a in self.graph_element.findall('attributes'):
            # find existing attributes element by class and mode
            a_class = a.get('class')
            a_mode = a.get('mode', 'static')
            if a_class == edge_or_node and a_mode == mode:
                attributes_element = a
        if attributes_element is None:
            # create new attributes element
            attr_kwargs = {'mode': mode, 'class': edge_or_node}
            attributes_element = Element('attributes', **attr_kwargs)
            self.graph_element.insert(0, attributes_element)
        attributes_element.append(attribute)
        return new_id

    def _viz_element(self, tag, values, keys):
        """Return a viz element carrying the entries of ``values`` in ``keys``.

        Entries that are missing or None are left out, so the literal
        string 'None' is never written as an attribute value.
        """
        attrs = {}
        for key in keys:
            value = values.get(key)
            if value is not None:
                attrs[key] = str(value)
        return Element('{%s}%s' % (self.NS_VIZ, tag), **attrs)

    def add_viz(self, element, node_data):
        """Pop 'viz' from ``node_data`` and append the viz sub-elements.

        Handles the keys 'color', 'size', 'thickness', 'shape' and
        'position'. Returns ``node_data`` without the 'viz' entry.
        """
        viz = node_data.pop('viz', False)
        if not viz:
            return node_data

        color = viz.get('color')
        if color is not None:
            channels = ('r', 'g', 'b')
            if self.VERSION != '1.1':
                # the alpha channel is only written for versions after 1.1
                channels += ('a',)
            element.append(self._viz_element('color', color, channels))

        size = viz.get('size')
        if size is not None:
            e = Element('{%s}size' % self.NS_VIZ, value=str(size))
            element.append(e)

        thickness = viz.get('thickness')
        if thickness is not None:
            e = Element('{%s}thickness' % self.NS_VIZ, value=str(thickness))
            element.append(e)

        shape = viz.get('shape')
        if shape is not None:
            if shape.startswith('http'):
                # a URL is written as an image shape pointing at that URL
                e = Element('{%s}shape' % self.NS_VIZ,
                            value='image', uri=str(shape))
            else:
                e = Element('{%s}shape' % self.NS_VIZ, value=str(shape))
            element.append(e)

        position = viz.get('position')
        if position is not None:
            element.append(
                self._viz_element('position', position, ('x', 'y', 'z')))
        return node_data

    def add_parents(self, node_element, node_data):
        """Pop 'parents' from ``node_data`` and append a <parents> element."""
        parents = node_data.pop('parents', False)
        if parents:
            parents_element = Element('parents')
            for p in parents:
                e = Element('parent')
                e.attrib['for'] = str(p)
                parents_element.append(e)
            node_element.append(parents_element)
        return node_data

    def add_slices(self, node_or_edge_element, node_or_edge_data):
        """Pop 'slices' ((start, end) pairs) and append a <slices> element."""
        slices = node_or_edge_data.pop('slices', False)
        if slices:
            slices_element = Element('slices')
            for start, end in slices:
                e = Element('slice', start=str(start), end=str(end))
                slices_element.append(e)
            node_or_edge_element.append(slices_element)
        return node_or_edge_data

    def add_spells(self, node_or_edge_element, node_or_edge_data):
        """Pop 'spells' ((start, end) pairs) and append a <spells> element.

        A bound that is None is omitted; every bound that is present also
        updates the graph mode/timeformat.
        """
        spells = node_or_edge_data.pop('spells', False)
        if spells:
            spells_element = Element('spells')
            for start, end in spells:
                e = Element('spell')
                if start is not None:
                    e.attrib['start'] = make_str(start)
                    self.alter_graph_mode_timeformat(start)
                if end is not None:
                    e.attrib['end'] = make_str(end)
                    self.alter_graph_mode_timeformat(end)
                spells_element.append(e)
            node_or_edge_element.append(spells_element)
        return node_or_edge_data

    def alter_graph_mode_timeformat(self, start_or_end):
        """Switch a static graph to dynamic when a start/end value appears.

        The timeformat is derived from the type of ``start_or_end``:
        str -> 'date', float -> 'double', int -> 'long'. Nothing happens if
        the value is None or the graph is not in 'static' mode.

        Raises
        ------
        NetworkXError
            If the value is not a str, float or int.
        """
        if self.graph_element.get('mode') != 'static':
            return
        if start_or_end is None:
            return
        if isinstance(start_or_end, str):
            timeformat = 'date'
        elif isinstance(start_or_end, float):
            timeformat = 'double'
        elif isinstance(start_or_end, int):
            timeformat = 'long'
        else:
            raise nx.NetworkXError(
                'timeformat should be of the type int, float or str')
        self.graph_element.set('timeformat', timeformat)
        self.graph_element.set('mode', 'dynamic')

    def write(self, fh):
        """Serialize the document, with an XML declaration, to open ``fh``."""
        if self.prettyprint:
            self.indent(self.xml)
        document = ElementTree(self.xml)
        document.write(fh, encoding=self.encoding, xml_declaration=True)

    def indent(self, elem, level=0):
        """In-place prettyprint formatter: set text/tail whitespace."""
        i = "\n" + "  " * level
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "  "
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            child = None
            for child in elem:
                self.indent(child, level + 1)
            # the last child closes the parent, so it is dedented one level
            if not child.tail or not child.tail.strip():
                child.tail = i
        else:
            if level and (not elem.tail or not elem.tail.strip()):
                elem.tail = i
631

    
632

    
633
class GEXFReader(GEXF):
    """Parse a GEXF document into a NetworkX graph.

    Instances are callable: ``reader(stream)`` parses ``stream`` and returns
    the resulting graph.  Use the read_gexf() function rather than this
    class directly.

    Notes
    -----
    ``set_version``, ``versions``, ``python_type``, ``convert_bool`` and the
    ``NS_GEXF`` / ``NS_VIZ`` namespace strings are used but not defined in
    this class; they are expected to be provided by the ``GEXF`` base class.
    """

    def __init__(self, node_type=None, version='1.2draft'):
        """Create a reader.

        Parameters
        ----------
        node_type : callable or None
            If not None, applied to every node id (and to edge source and
            target ids) read from the file, e.g. ``int``.
        version : string
            GEXF version passed to ``set_version``.

        Raises
        ------
        ImportError
            If ``xml.etree.ElementTree`` cannot be imported.
        """
        try:
            import xml.etree.ElementTree
        except ImportError:
            raise ImportError('GEXF reader requires '
                              'xml.etree.ElementTree.')
        self.node_type = node_type
        # assume simple graph and test for multigraph on read
        self.simple_graph = True
        self.set_version(version)

    def __call__(self, stream):
        """Parse ``stream`` and return the graph it describes.

        The <graph> element is first looked up in the namespace of the
        currently configured version; if that fails, every version in
        ``self.versions`` is tried in turn.

        Raises
        ------
        NetworkXError
            If no <graph> element is found under any version's namespace.
        """
        self.xml = ElementTree(file=stream)
        g = self.xml.find('{%s}graph' % self.NS_GEXF)
        if g is not None:
            return self.make_graph(g)
        # try all the versions
        for version in self.versions:
            self.set_version(version)
            g = self.xml.find('{%s}graph' % self.NS_GEXF)
            if g is not None:
                return self.make_graph(g)
        raise nx.NetworkXError('No <graph> element in GEXF file.')

    def make_graph(self, graph_xml):
        """Build a NetworkX graph from a <graph> XML element.

        Returns a MultiDiGraph/MultiGraph if parallel edges were seen
        (``self.simple_graph`` is False), otherwise a DiGraph/Graph.

        Raises
        ------
        NetworkXError
            If an <attributes> element has a class other than 'node' or
            'edge'.
        """
        # start with empty MultiDiGraph or MultiGraph
        edgedefault = graph_xml.get('defaultedgetype', None)
        if edgedefault == 'directed':
            G = nx.MultiDiGraph()
        else:
            G = nx.MultiGraph()

        # graph attributes
        graph_name = graph_xml.get('name', '')
        if graph_name != '':
            G.graph['name'] = graph_name
        graph_start = graph_xml.get('start')
        if graph_start is not None:
            G.graph['start'] = graph_start
        graph_end = graph_xml.get('end')
        if graph_end is not None:
            G.graph['end'] = graph_end
        graph_mode = graph_xml.get('mode', '')
        if graph_mode == 'dynamic':
            G.graph['mode'] = 'dynamic'
        else:
            G.graph['mode'] = 'static'

        # timeformat: 'date' values are handled with the 'string' converter
        self.timeformat = graph_xml.get('timeformat')
        if self.timeformat == 'date':
            self.timeformat = 'string'

        # node and edge attributes
        attributes_elements = graph_xml.findall('{%s}attributes' % self.NS_GEXF)
        # dictionaries to hold attributes and attribute defaults
        node_attr = {}
        node_default = {}
        edge_attr = {}
        edge_default = {}
        for a in attributes_elements:
            attr_class = a.get('class')
            if attr_class == 'node':
                na, nd = self.find_gexf_attributes(a)
                node_attr.update(na)
                node_default.update(nd)
                G.graph['node_default'] = node_default
            elif attr_class == 'edge':
                ea, ed = self.find_gexf_attributes(a)
                edge_attr.update(ea)
                edge_default.update(ed)
                G.graph['edge_default'] = edge_default
            else:
                # A bare ``raise`` here had no active exception to re-raise
                # and surfaced as an unrelated RuntimeError.
                raise nx.NetworkXError(
                    'Unknown attribute class %r in GEXF file.' % attr_class)

        # Hack to handle Gephi0.7beta bug
        # add weight attribute
        ea = {'weight': {'type': 'double', 'mode': 'static', 'title': 'weight'}}
        ed = {}
        edge_attr.update(ea)
        edge_default.update(ed)
        G.graph['edge_default'] = edge_default

        # add nodes
        nodes_element = graph_xml.find('{%s}nodes' % self.NS_GEXF)
        if nodes_element is not None:
            for node_xml in nodes_element.findall('{%s}node' % self.NS_GEXF):
                self.add_node(G, node_xml, node_attr)

        # add edges
        edges_element = graph_xml.find('{%s}edges' % self.NS_GEXF)
        if edges_element is not None:
            for edge_xml in edges_element.findall('{%s}edge' % self.NS_GEXF):
                self.add_edge(G, edge_xml, edge_attr)

        # switch to Graph or DiGraph if no parallel edges were found.
        if self.simple_graph:
            if G.is_directed():
                G = nx.DiGraph(G)
            else:
                G = nx.Graph(G)
        return G

    def add_node(self, G, node_xml, node_attr, node_pid=None):
        """Add the node described by ``node_xml`` (and its subnodes) to G.

        ``node_attr`` maps attribute ids to their declarations, as returned
        by find_gexf_attributes().  ``node_pid`` is the id of the enclosing
        node when called recursively for nested <nodes>; an explicit 'pid'
        XML attribute takes precedence over it.
        """
        # get attributes and subattributes for node
        data = self.decode_attr_elements(node_attr, node_xml)
        data = self.add_parents(data, node_xml)  # add any parents
        if self.version == '1.1':
            data = self.add_slices(data, node_xml)  # add slices
        else:
            data = self.add_spells(data, node_xml)  # add spells
        data = self.add_viz(data, node_xml)  # add viz
        data = self.add_start_end(data, node_xml)  # add start/end

        # find the node id and cast it to the appropriate type
        node_id = node_xml.get('id')
        if self.node_type is not None:
            node_id = self.node_type(node_id)

        # every node should have a label (stored as None when absent)
        node_label = node_xml.get('label')
        data['label'] = node_label

        # parent node id
        node_pid = node_xml.get('pid', node_pid)
        if node_pid is not None:
            data['pid'] = node_pid

        # check for subnodes, recursive
        subnodes = node_xml.find('{%s}nodes' % self.NS_GEXF)
        if subnodes is not None:
            for node_xml in subnodes.findall('{%s}node' % self.NS_GEXF):
                self.add_node(G, node_xml, node_attr, node_pid=node_id)

        G.add_node(node_id, **data)

    def add_start_end(self, data, xml):
        """Copy the 'start'/'end' XML attributes, if present, into ``data``.

        Values are converted with ``self.python_type[self.timeformat]``.
        Returns ``data``.
        """
        # start and end times
        ttype = self.timeformat
        node_start = xml.get('start')
        if node_start is not None:
            data['start'] = self.python_type[ttype](node_start)
        node_end = xml.get('end')
        if node_end is not None:
            data['end'] = self.python_type[ttype](node_end)
        return data

    def add_viz(self, data, node_xml):
        """Collect viz child elements of ``node_xml`` into ``data['viz']``.

        Handles color, size, thickness, shape and position.  ``data['viz']``
        is only set when at least one of them is present.  Returns ``data``.
        """
        viz = {}
        color = node_xml.find('{%s}color' % self.NS_VIZ)
        if color is not None:
            # NOTE(review): this tests self.VERSION while add_node/add_edge
            # test self.version; both are set outside this class - confirm
            # they are kept in sync.
            if self.VERSION == '1.1':
                viz['color'] = {'r': int(color.get('r')),
                                'g': int(color.get('g')),
                                'b': int(color.get('b'))}
            else:
                # alpha defaults to 1 when the attribute is missing
                viz['color'] = {'r': int(color.get('r')),
                                'g': int(color.get('g')),
                                'b': int(color.get('b')),
                                'a': float(color.get('a', 1))}

        size = node_xml.find('{%s}size' % self.NS_VIZ)
        if size is not None:
            viz['size'] = float(size.get('value'))

        thickness = node_xml.find('{%s}thickness' % self.NS_VIZ)
        if thickness is not None:
            viz['thickness'] = float(thickness.get('value'))

        shape = node_xml.find('{%s}shape' % self.NS_VIZ)
        if shape is not None:
            viz['shape'] = shape.get('shape')
            # for image shapes the uri is stored in place of the shape name
            if viz['shape'] == 'image':
                viz['shape'] = shape.get('uri')

        position = node_xml.find('{%s}position' % self.NS_VIZ)
        if position is not None:
            viz['position'] = {'x': float(position.get('x', 0)),
                               'y': float(position.get('y', 0)),
                               'z': float(position.get('z', 0))}

        if len(viz) > 0:
            data['viz'] = viz
        return data

    def add_parents(self, data, node_xml):
        """Store the 'for' ids of any <parents>/<parent> children.

        Sets ``data['parents']`` to a list (possibly empty) only when a
        <parents> element exists.  Returns ``data``.
        """
        parents_element = node_xml.find('{%s}parents' % self.NS_GEXF)
        if parents_element is not None:
            data['parents'] = []
            for p in parents_element.findall('{%s}parent' % self.NS_GEXF):
                parent = p.get('for')
                data['parents'].append(parent)
        return data

    def add_slices(self, data, node_or_edge_xml):
        """Store <slices>/<slice> children as (start, end) tuples.

        The start/end values are kept as read from the XML (no type
        conversion).  Returns ``data``.
        """
        slices_element = node_or_edge_xml.find('{%s}slices' % self.NS_GEXF)
        if slices_element is not None:
            data['slices'] = []
            for s in slices_element.findall('{%s}slice' % self.NS_GEXF):
                start = s.get('start')
                end = s.get('end')
                data['slices'].append((start, end))
        return data

    def add_spells(self, data, node_or_edge_xml):
        """Store <spells>/<spell> children as (start, end) tuples.

        The start/end values are converted with
        ``self.python_type[self.timeformat]``.  Returns ``data``.
        """
        spells_element = node_or_edge_xml.find('{%s}spells' % self.NS_GEXF)
        if spells_element is not None:
            data['spells'] = []
            ttype = self.timeformat
            for s in spells_element.findall('{%s}spell' % self.NS_GEXF):
                start = self.python_type[ttype](s.get('start'))
                end = self.python_type[ttype](s.get('end'))
                data['spells'].append((start, end))
        return data

    def add_edge(self, G, edge_element, edge_attr):
        """Add the edge described by ``edge_element`` to G.

        ``edge_attr`` maps attribute ids to their declarations, as returned
        by find_gexf_attributes().  Sets ``self.simple_graph`` to False when
        an edge between the same endpoints already exists.

        Raises
        ------
        NetworkXError
            If the edge's 'type' conflicts with the directedness of G.
        """
        # raise error if we find mixed directed and undirected edges
        edge_direction = edge_element.get('type')
        if G.is_directed() and edge_direction == 'undirected':
            raise nx.NetworkXError(
                'Undirected edge found in directed graph.')
        if (not G.is_directed()) and edge_direction == 'directed':
            raise nx.NetworkXError(
                'Directed edge found in undirected graph.')

        # Get source and target and recast type if required
        source = edge_element.get('source')
        target = edge_element.get('target')
        if self.node_type is not None:
            source = self.node_type(source)
            target = self.node_type(target)

        data = self.decode_attr_elements(edge_attr, edge_element)
        data = self.add_start_end(data, edge_element)

        if self.version == '1.1':
            data = self.add_slices(data, edge_element)  # add slices
        else:
            data = self.add_spells(data, edge_element)  # add spells

        # GEXF stores edge ids as an attribute
        # NetworkX uses them as keys in multigraphs
        # if networkx_key is not specified as an attribute
        edge_id = edge_element.get('id')
        if edge_id is not None:
            data['id'] = edge_id

        # check if there is a 'networkx_key' attribute and use that as the
        # multigraph key instead of the GEXF edge id
        multigraph_key = data.pop('networkx_key', None)
        if multigraph_key is not None:
            edge_id = multigraph_key

        weight = edge_element.get('weight')
        if weight is not None:
            data['weight'] = float(weight)

        edge_label = edge_element.get('label')
        if edge_label is not None:
            data['label'] = edge_label

        if G.has_edge(source, target):
            # seen this edge before - this is a multigraph
            self.simple_graph = False
        G.add_edge(source, target, key=edge_id, **data)
        # a 'mutual' edge is also added in the reverse direction
        if edge_direction == 'mutual':
            G.add_edge(target, source, key=edge_id, **data)

    def decode_attr_elements(self, gexf_keys, obj_xml):
        """Decode the <attvalues> of a node or edge into a dict.

        ``gexf_keys`` maps attribute ids to dicts with 'title', 'type' and
        'mode' entries.  Static attributes map title -> value; dynamic
        attributes map title -> list of (value, start, end) tuples.

        Raises
        ------
        NetworkXError
            If an <attvalue> refers to an id missing from ``gexf_keys``.
        """
        # Use the key information to decode the attr XML
        attr = {}
        # look for outer '<attvalues>' element
        attr_element = obj_xml.find('{%s}attvalues' % self.NS_GEXF)
        if attr_element is not None:
            # loop over <attvalue> elements
            for a in attr_element.findall('{%s}attvalue' % self.NS_GEXF):
                key = a.get('for')  # for is required
                try:  # should be in our gexf_keys dictionary
                    title = gexf_keys[key]['title']
                except KeyError:
                    raise nx.NetworkXError('No attribute defined for=%s.' % key)
                atype = gexf_keys[key]['type']
                value = a.get('value')
                if atype == 'boolean':
                    value = self.convert_bool[value]
                else:
                    value = self.python_type[atype](value)
                if gexf_keys[key]['mode'] == 'dynamic':
                    # for dynamic graphs use list of three-tuples
                    # [(value1,start1,end1), (value2,start2,end2), etc]
                    ttype = self.timeformat
                    start = self.python_type[ttype](a.get('start'))
                    end = self.python_type[ttype](a.get('end'))
                    if title in attr:
                        attr[title].append((value, start, end))
                    else:
                        attr[title] = [(value, start, end)]
                else:
                    # for static graphs just assign the value
                    attr[title] = value
        return attr

    def find_gexf_attributes(self, attributes_element):
        """Extract attribute declarations and defaults from <attributes>.

        Returns
        -------
        attrs : dict
            attribute id -> {'title': ..., 'type': ..., 'mode': ...}, where
            'mode' is the mode of the enclosing <attributes> element.
        defaults : dict
            attribute title -> converted default value, for attributes that
            have a <default> child.
        """
        attrs = {}
        defaults = {}
        mode = attributes_element.get('mode')
        for k in attributes_element.findall('{%s}attribute' % self.NS_GEXF):
            attr_id = k.get('id')
            title = k.get('title')
            atype = k.get('type')
            attrs[attr_id] = {'title': title, 'type': atype, 'mode': mode}
            # check for the 'default' subelement of key element and add
            default = k.find('{%s}default' % self.NS_GEXF)
            if default is not None:
                if atype == 'boolean':
                    value = self.convert_bool[default.text]
                else:
                    value = self.python_type[atype](default.text)
                defaults[title] = value
        return attrs, defaults
961

    
962

    
963
def relabel_gexf_graph(G):
    """Relabel graph using "label" node keyword for node label.

    Parameters
    ----------
    G : graph
       A NetworkX graph read from GEXF data

    Returns
    -------
    H : graph
      A NetworkX graph with relabeled nodes

    Raises
    ------
    NetworkXError
        If node labels are missing or not unique while relabel=True.

    Notes
    -----
    This function relabels the nodes in a NetworkX graph with the
    "label" attribute.  It also handles relabeling the specific GEXF
    node attributes "parents", and "pid".  The original node identifier
    is kept in the "id" attribute of each relabeled node, and the "label"
    attribute is removed.  A graph without nodes is returned relabeled
    trivially (nothing to rename).
    """
    # build mapping of node -> label, do some error checking
    try:
        mapping = {u: G.nodes[u]['label'] for u in G}
    except KeyError:
        raise nx.NetworkXError('Failed to relabel nodes: '
                               'missing node labels found. '
                               'Use relabel=False.')
    # Compare against the values directly instead of unpacking
    # zip(*mapping), which fails when G has no nodes.
    if len(set(mapping.values())) != len(G):
        raise nx.NetworkXError('Failed to relabel nodes: '
                               'duplicate node labels found. '
                               'Use relabel=False.')
    H = nx.relabel_nodes(G, mapping)
    # relabel attributes
    for n in G:
        m = mapping[n]
        H.nodes[m]['id'] = n
        H.nodes[m].pop('label')
        if 'pid' in H.nodes[m]:
            H.nodes[m]['pid'] = mapping[G.nodes[n]['pid']]
        if 'parents' in H.nodes[m]:
            H.nodes[m]['parents'] = [mapping[p] for p in G.nodes[n]['parents']]
    return H
1011

    
1012

    
1013
# fixture for nose tests
1014
def setup_module(module):
    """nose fixture: skip this module when cElementTree is unavailable.

    Raises
    ------
    nose.SkipTest
        If ``xml.etree.cElementTree`` cannot be imported.
    """
    from nose import SkipTest
    try:
        import xml.etree.cElementTree
    except ImportError:
        # only a failed import means "not available"; anything else
        # (e.g. KeyboardInterrupt) must propagate
        raise SkipTest('xml.etree.cElementTree not available.')
1020

    
1021

    
1022
# fixture for nose tests
1023
def teardown_module(module):
    """nose fixture: remove the 'test.gexf' file from the working directory.

    Cleanup is best effort: a missing file or any other filesystem error
    is ignored.
    """
    import os
    try:
        os.unlink('test.gexf')
    except OSError:
        # best-effort cleanup; only filesystem errors are suppressed
        pass