Statistics
| Branch: | Revision:

iof-tools / networkxMiCe / networkx-master / networkx / readwrite / graphml.py @ 5cef0f13

History | View | Annotate | Download (34.9 KB)

1
#    Copyright (C) 2008-2019 by
2
#    Aric Hagberg <hagberg@lanl.gov>
3
#    Dan Schult <dschult@colgate.edu>
4
#    Pieter Swart <swart@lanl.gov>
5
#    All rights reserved.
6
#    BSD license.
7
#
8
# Authors: Salim Fadhley
9
#          Aric Hagberg (hagberg@lanl.gov)
10
"""
11
*******
12
GraphML
13
*******
14
Read and write graphs in GraphML format.
15

16
This implementation does not support mixed graphs (directed and unidirected
17
edges together), hyperedges, nested graphs, or ports.
18

19
"GraphML is a comprehensive and easy-to-use file format for graphs. It
20
consists of a language core to describe the structural properties of a
21
graph and a flexible extension mechanism to add application-specific
22
data. Its main features include support of
23

24
    * directed, undirected, and mixed graphs,
25
    * hypergraphs,
26
    * hierarchical graphs,
27
    * graphical representations,
28
    * references to external data,
29
    * application-specific attribute data, and
30
    * light-weight parsers.
31

32
Unlike many other file formats for graphs, GraphML does not use a
33
custom syntax. Instead, it is based on XML and hence ideally suited as
34
a common denominator for all kinds of services generating, archiving,
35
or processing graphs."
36

37
http://graphml.graphdrawing.org/
38

39
Format
40
------
41
GraphML is an XML format.  See
42
http://graphml.graphdrawing.org/specification.html for the specification and
43
http://graphml.graphdrawing.org/primer/graphml-primer.html
44
for examples.
45
"""
46
import warnings
47
from collections import defaultdict
48

    
49
try:
50
    from xml.etree.cElementTree import Element, ElementTree
51
    from xml.etree.cElementTree import tostring, fromstring
52
except ImportError:
53
    try:
54
        from xml.etree.ElementTree import Element, ElementTree
55
        from xml.etree.ElementTree import tostring, fromstring
56
    except ImportError:
57
        pass
58

    
59
try:
60
    import lxml.etree as lxmletree
61
except ImportError:
62
    lxmletree = None
63

    
64
import networkx as nx
65
from networkx.utils import open_file, make_str
66

    
67
__all__ = ['write_graphml', 'read_graphml', 'generate_graphml',
68
           'write_graphml_xml', 'write_graphml_lxml',
69
           'parse_graphml', 'GraphMLWriter', 'GraphMLReader']
70

    
71

    
72
@open_file(1, mode='wb')
73
def write_graphml_xml(G, path, encoding='utf-8', prettyprint=True,
74
                      infer_numeric_types=False):
75
    """Write G in GraphML XML format to path
76

77
    Parameters
78
    ----------
79
    G : graph
80
       A networkx graph
81
    path : file or string
82
       File or filename to write.
83
       Filenames ending in .gz or .bz2 will be compressed.
84
    encoding : string (optional)
85
       Encoding for text data.
86
    prettyprint : bool (optional)
87
       If True use line breaks and indenting in output XML.
88
    infer_numeric_types : boolean
89
       Determine if numeric types should be generalized.
90
       For example, if edges have both int and float 'weight' attributes,
91
       we infer in GraphML that both are floats.
92

93
    Examples
94
    --------
95
    >>> G = nx.path_graph(4)
96
    >>> nx.write_graphml(G, "test.graphml")
97

98
    Notes
99
    -----
100
    It may be a good idea in Python2 to convert strings to unicode
101
    before giving the graph to write_gml. At least the strings with
102
    either many characters to escape.
103

104
    This implementation does not support mixed graphs (directed
105
    and unidirected edges together) hyperedges, nested graphs, or ports.
106
    """
107
    writer = GraphMLWriter(encoding=encoding, prettyprint=prettyprint,
108
                           infer_numeric_types=infer_numeric_types)
109
    writer.add_graph_element(G)
110
    writer.dump(path)
111

    
112

    
113
@open_file(1, mode='wb')
114
def write_graphml_lxml(G, path, encoding='utf-8', prettyprint=True,
115
                       infer_numeric_types=False):
116
    """Write G in GraphML XML format to path
117

118
    This function uses the LXML framework and should be faster than
119
    the version using the xml library.
120

121
    Parameters
122
    ----------
123
    G : graph
124
       A networkx graph
125
    path : file or string
126
       File or filename to write.
127
       Filenames ending in .gz or .bz2 will be compressed.
128
    encoding : string (optional)
129
       Encoding for text data.
130
    prettyprint : bool (optional)
131
       If True use line breaks and indenting in output XML.
132
    infer_numeric_types : boolean
133
       Determine if numeric types should be generalized.
134
       For example, if edges have both int and float 'weight' attributes,
135
       we infer in GraphML that both are floats.
136

137
    Examples
138
    --------
139
    >>> G = nx.path_graph(4)
140
    >>> nx.write_graphml_lxml(G, "fourpath.graphml")  # doctest: +SKIP
141

142
    Notes
143
    -----
144
    This implementation does not support mixed graphs (directed
145
    and unidirected edges together) hyperedges, nested graphs, or ports.
146
    """
147
    writer = GraphMLWriterLxml(path, graph=G, encoding=encoding,
148
                               prettyprint=prettyprint,
149
                               infer_numeric_types=infer_numeric_types)
150
    writer.dump()
151

    
152

    
153
def generate_graphml(G, encoding='utf-8', prettyprint=True):
154
    """Generate GraphML lines for G
155

156
    Parameters
157
    ----------
158
    G : graph
159
       A networkx graph
160
    encoding : string (optional)
161
       Encoding for text data.
162
    prettyprint : bool (optional)
163
       If True use line breaks and indenting in output XML.
164

165
    Examples
166
    --------
167
    >>> G = nx.path_graph(4)
168
    >>> linefeed = chr(10)  # linefeed = \n
169
    >>> s = linefeed.join(nx.generate_graphml(G))  # doctest: +SKIP
170
    >>> for line in nx.generate_graphml(G):  # doctest: +SKIP
171
    ...    print(line)
172

173
    Notes
174
    -----
175
    This implementation does not support mixed graphs (directed and unidirected
176
    edges together) hyperedges, nested graphs, or ports.
177
    """
178
    writer = GraphMLWriter(encoding=encoding, prettyprint=prettyprint)
179
    writer.add_graph_element(G)
180
    for line in str(writer).splitlines():
181
        yield line
182

    
183

    
184
@open_file(0, mode='rb')
185
def read_graphml(path, node_type=str, edge_key_type=int):
186
    """Read graph in GraphML format from path.
187

188
    Parameters
189
    ----------
190
    path : file or string
191
       File or filename to write.
192
       Filenames ending in .gz or .bz2 will be compressed.
193

194
    node_type: Python type (default: str)
195
       Convert node ids to this type
196

197
    edge_key_type: Python type (default: int)
198
       Convert graphml edge ids to this type as key of multi-edges
199

200

201
    Returns
202
    -------
203
    graph: NetworkX graph
204
        If no parallel edges are found a Graph or DiGraph is returned.
205
        Otherwise a MultiGraph or MultiDiGraph is returned.
206

207
    Notes
208
    -----
209
    Default node and edge attributes are not propagated to each node and edge.
210
    They can be obtained from `G.graph` and applied to node and edge attributes
211
    if desired using something like this:
212

213
    >>> default_color = G.graph['node_default']['color']  # doctest: +SKIP
214
    >>> for node, data in G.nodes(data=True):  # doctest: +SKIP
215
    ...     if 'color' not in data:
216
    ...         data['color']=default_color
217
    >>> default_color = G.graph['edge_default']['color']  # doctest: +SKIP
218
    >>> for u, v, data in G.edges(data=True):  # doctest: +SKIP
219
    ...     if 'color' not in data:
220
    ...         data['color']=default_color
221

222
    This implementation does not support mixed graphs (directed and unidirected
223
    edges together), hypergraphs, nested graphs, or ports.
224

225
    For multigraphs the GraphML edge "id" will be used as the edge
226
    key.  If not specified then they "key" attribute will be used.  If
227
    there is no "key" attribute a default NetworkX multigraph edge key
228
    will be provided.
229

230
    Files with the yEd "yfiles" extension will can be read but the graphics
231
    information is discarded.
232

233
    yEd compressed files ("file.graphmlz" extension) can be read by renaming
234
    the file to "file.graphml.gz".
235

236
    """
237
    reader = GraphMLReader(node_type=node_type, edge_key_type=edge_key_type)
238
    # need to check for multiple graphs
239
    glist = list(reader(path=path))
240
    if len(glist) == 0:
241
        # If no graph comes back, try looking for an incomplete header
242
        header = b'<graphml xmlns="http://graphml.graphdrawing.org/xmlns">'
243
        path.seek(0)
244
        old_bytes = path.read()
245
        new_bytes = old_bytes.replace(b'<graphml>', header)
246
        glist = list(reader(string=new_bytes))
247
        if len(glist) == 0:
248
            raise nx.NetworkXError('file not successfully read as graphml')
249
    return glist[0]
250

    
251

    
252
def parse_graphml(graphml_string, node_type=str):
253
    """Read graph in GraphML format from string.
254

255
    Parameters
256
    ----------
257
    graphml_string : string
258
       String containing graphml information
259
       (e.g., contents of a graphml file).
260

261
    node_type: Python type (default: str)
262
       Convert node ids to this type
263

264
    Returns
265
    -------
266
    graph: NetworkX graph
267
        If no parallel edges are found a Graph or DiGraph is returned.
268
        Otherwise a MultiGraph or MultiDiGraph is returned.
269

270
    Examples
271
    --------
272
    >>> G = nx.path_graph(4)
273
    >>> linefeed = chr(10)  # linefeed = \n
274
    >>> s = linefeed.join(nx.generate_graphml(G))
275
    >>> H = nx.parse_graphml(s)
276

277
    Notes
278
    -----
279
    Default node and edge attributes are not propagated to each node and edge.
280
    They can be obtained from `G.graph` and applied to node and edge attributes
281
    if desired using something like this:
282

283
    >>> default_color = G.graph['node_default']['color']  # doctest: +SKIP
284
    >>> for node, data in G.nodes(data=True):  # doctest: +SKIP
285
    ...    if 'color' not in data:
286
    ...        data['color']=default_color
287
    >>> default_color = G.graph['edge_default']['color']  # doctest: +SKIP
288
    >>> for u, v, data in G.edges(data=True):  # doctest: +SKIP
289
    ...    if 'color' not in data:
290
    ...        data['color']=default_color
291

292
    This implementation does not support mixed graphs (directed and unidirected
293
    edges together), hypergraphs, nested graphs, or ports.
294

295
    For multigraphs the GraphML edge "id" will be used as the edge
296
    key.  If not specified then they "key" attribute will be used.  If
297
    there is no "key" attribute a default NetworkX multigraph edge key
298
    will be provided.
299

300
    """
301
    reader = GraphMLReader(node_type=node_type)
302
    # need to check for multiple graphs
303
    glist = list(reader(string=graphml_string))
304
    if len(glist) == 0:
305
        # If no graph comes back, try looking for an incomplete header
306
        header = '<graphml xmlns="http://graphml.graphdrawing.org/xmlns">'
307
        new_string = graphml_string.replace('<graphml>', header)
308
        glist = list(reader(string=new_string))
309
        if len(glist) == 0:
310
            raise nx.NetworkXError('file not successfully read as graphml')
311
    return glist[0]
312

    
313

    
314
class GraphML(object):
315
    NS_GRAPHML = "http://graphml.graphdrawing.org/xmlns"
316
    NS_XSI = "http://www.w3.org/2001/XMLSchema-instance"
317
    # xmlns:y="http://www.yworks.com/xml/graphml"
318
    NS_Y = "http://www.yworks.com/xml/graphml"
319
    SCHEMALOCATION = \
320
        ' '.join(['http://graphml.graphdrawing.org/xmlns',
321
                  'http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd'])
322

    
323
    try:
324
        chr(12345)     # Fails on Py!=3.
325
        unicode = str  # Py3k's str is our unicode type
326
        long = int     # Py3K's int is our long type
327
    except ValueError:
328
        # Python 2.x
329
        pass
330

    
331
    types = [(int, "integer"),  # for Gephi GraphML bug
332
             (str, "yfiles"), (str, "string"), (unicode, "string"),
333
             (int, "int"), (long, "long"),
334
             (float, "float"), (float, "double"),
335
             (bool, "boolean")]
336

    
337
    # These additions to types allow writing numpy types
338
    try:
339
        import numpy as np
340
    except:
341
        pass
342
    else:
343
        # prepend so that python types are created upon read (last entry wins)
344
        types = [(np.float64, "float"), (np.float32, "float"),
345
                 (np.float16, "float"), (np.float_, "float"),
346
                 (np.int, "int"), (np.int8, "int"),
347
                 (np.int16, "int"), (np.int32, "int"),
348
                 (np.int64, "int"), (np.uint8, "int"),
349
                 (np.uint16, "int"), (np.uint32, "int"),
350
                 (np.uint64, "int"), (np.int_, "int"),
351
                 (np.intc, "int"), (np.intp, "int"),
352
                ] + types
353

    
354
    xml_type = dict(types)
355
    python_type = dict(reversed(a) for a in types)
356

    
357
    # This page says that data types in GraphML follow Java(TM).
358
    #  http://graphml.graphdrawing.org/primer/graphml-primer.html#AttributesDefinition
359
    # true and false are the only boolean literals:
360
    #  http://en.wikibooks.org/wiki/Java_Programming/Literals#Boolean_Literals
361
    convert_bool = {
362
        # We use data.lower() in actual use.
363
        'true': True, 'false': False,
364
        # Include integer strings for convenience.
365
        '0': False, 0: False,
366
        '1': True, 1: True
367
    }
368

    
369

    
370
class GraphMLWriter(GraphML):
371
    def __init__(self, graph=None, encoding="utf-8", prettyprint=True,
372
                 infer_numeric_types=False):
373
        try:
374
            import xml.etree.ElementTree
375
        except ImportError:
376
            msg = 'GraphML writer requires xml.elementtree.ElementTree'
377
            raise ImportError(msg)
378
        self.myElement = Element
379

    
380
        self.infer_numeric_types = infer_numeric_types
381
        self.prettyprint = prettyprint
382
        self.encoding = encoding
383
        self.xml = self.myElement("graphml",
384
                                  {'xmlns': self.NS_GRAPHML,
385
                                   'xmlns:xsi': self.NS_XSI,
386
                                   'xsi:schemaLocation': self.SCHEMALOCATION})
387
        self.keys = {}
388
        self.attributes = defaultdict(list)
389
        self.attribute_types = defaultdict(set)
390

    
391
        if graph is not None:
392
            self.add_graph_element(graph)
393

    
394
    def __str__(self):
395
        if self.prettyprint:
396
            self.indent(self.xml)
397
        s = tostring(self.xml).decode(self.encoding)
398
        return s
399

    
400
    def attr_type(self, name, scope, value):
401
        """Infer the attribute type of data named name. Currently this only
402
        supports inference of numeric types.
403

404
        If self.infer_numeric_types is false, type is used. Otherwise, pick the
405
        most general of types found across all values with name and scope. This
406
        means edges with data named 'weight' are treated separately from nodes
407
        with data named 'weight'.
408
        """
409
        if self.infer_numeric_types:
410
            types = self.attribute_types[(name, scope)]
411

    
412
            try:
413
                chr(12345)     # Fails on Py<3.
414
                local_long = int     # Py3's int is Py2's long type
415
                local_unicode = str  # Py3's str is Py2's unicode type
416
            except ValueError:
417
                # Python 2.x
418
                local_long = long
419
                local_unicode = unicode
420

    
421
            if len(types) > 1:
422
                if str in types:
423
                    return str
424
                elif local_unicode in types:
425
                    return local_unicode
426
                elif float in types:
427
                    return float
428
                elif local_long in types:
429
                    return local_long
430
                else:
431
                    return int
432
            else:
433
                return list(types)[0]
434
        else:
435
            return type(value)
436

    
437
    def get_key(self, name, attr_type, scope, default):
438
        keys_key = (name, attr_type, scope)
439
        try:
440
            return self.keys[keys_key]
441
        except KeyError:
442
            new_id = "d%i" % len(list(self.keys))
443
            self.keys[keys_key] = new_id
444
            key_kwargs = {"id": new_id,
445
                          "for": scope,
446
                          "attr.name": name,
447
                          "attr.type": attr_type}
448
            key_element = self.myElement("key", **key_kwargs)
449
            # add subelement for data default value if present
450
            if default is not None:
451
                default_element = self.myElement("default")
452
                default_element.text = make_str(default)
453
                key_element.append(default_element)
454
            self.xml.insert(0, key_element)
455
        return new_id
456

    
457
    def add_data(self, name, element_type, value,
458
                 scope="all",
459
                 default=None):
460
        """
461
        Make a data element for an edge or a node. Keep a log of the
462
        type in the keys table.
463
        """
464
        if element_type not in self.xml_type:
465
            msg = 'GraphML writer does not support %s as data values.'
466
            raise nx.NetworkXError(msg % element_type)
467
        keyid = self.get_key(name, self.xml_type[element_type], scope, default)
468
        data_element = self.myElement("data", key=keyid)
469
        data_element.text = make_str(value)
470
        return data_element
471

    
472
    def add_attributes(self, scope, xml_obj, data, default):
473
        """Appends attribute data to edges or nodes, and stores type information
474
        to be added later. See add_graph_element.
475
        """
476
        for k, v in data.items():
477
            self.attribute_types[(make_str(k), scope)].add(type(v))
478
            self.attributes[xml_obj].append([k, v, scope, default.get(k)])
479

    
480
    def add_nodes(self, G, graph_element):
481
        default = G.graph.get('node_default', {})
482
        for node, data in G.nodes(data=True):
483
            node_element = self.myElement("node", id=make_str(node))
484
            self.add_attributes("node", node_element, data, default)
485
            graph_element.append(node_element)
486

    
487
    def add_edges(self, G, graph_element):
488
        if G.is_multigraph():
489
            for u, v, key, data in G.edges(data=True, keys=True):
490
                edge_element = self.myElement("edge", source=make_str(u),
491
                                              target=make_str(v),
492
                                              id=make_str(key))
493
                default = G.graph.get('edge_default', {})
494
                self.add_attributes("edge", edge_element, data, default)
495
                graph_element.append(edge_element)
496
        else:
497
            for u, v, data in G.edges(data=True):
498
                edge_element = self.myElement("edge", source=make_str(u),
499
                                              target=make_str(v))
500
                default = G.graph.get('edge_default', {})
501
                self.add_attributes("edge", edge_element, data, default)
502
                graph_element.append(edge_element)
503

    
504
    def add_graph_element(self, G):
505
        """
506
        Serialize graph G in GraphML to the stream.
507
        """
508
        if G.is_directed():
509
            default_edge_type = 'directed'
510
        else:
511
            default_edge_type = 'undirected'
512

    
513
        graphid = G.graph.pop('id', None)
514
        if graphid is None:
515
            graph_element = self.myElement("graph",
516
                                           edgedefault=default_edge_type)
517
        else:
518
            graph_element = self.myElement("graph",
519
                                           edgedefault=default_edge_type,
520
                                           id=graphid)
521
        default = {}
522
        data = {k: v for (k, v) in G.graph.items()
523
                if k not in ['node_default', 'edge_default']}
524
        self.add_attributes("graph", graph_element, data, default)
525
        self.add_nodes(G, graph_element)
526
        self.add_edges(G, graph_element)
527

    
528
        # self.attributes contains a mapping from XML Objects to a list of
529
        # data that needs to be added to them.
530
        # We postpone processing in order to do type inference/generalization.
531
        # See self.attr_type
532
        for (xml_obj, data) in self.attributes.items():
533
            for (k, v, scope, default) in data:
534
                xml_obj.append(self.add_data(make_str(k),
535
                                             self.attr_type(k, scope, v),
536
                                             make_str(v), scope, default))
537
        self.xml.append(graph_element)
538

    
539
    def add_graphs(self, graph_list):
540
        """ Add many graphs to this GraphML document. """
541
        for G in graph_list:
542
            self.add_graph_element(G)
543

    
544
    def dump(self, stream):
545
        if self.prettyprint:
546
            self.indent(self.xml)
547
        document = ElementTree(self.xml)
548
        document.write(stream, encoding=self.encoding, xml_declaration=True)
549

    
550
    def indent(self, elem, level=0):
551
        # in-place prettyprint formatter
552
        i = "\n" + level * "  "
553
        if len(elem):
554
            if not elem.text or not elem.text.strip():
555
                elem.text = i + "  "
556
            if not elem.tail or not elem.tail.strip():
557
                elem.tail = i
558
            for elem in elem:
559
                self.indent(elem, level + 1)
560
            if not elem.tail or not elem.tail.strip():
561
                elem.tail = i
562
        else:
563
            if level and (not elem.tail or not elem.tail.strip()):
564
                elem.tail = i
565

    
566

    
567
class IncrementalElement(object):
568
    """Wrapper for _IncrementalWriter providing an Element like interface.
569

570
    This wrapper does not intend to be a complete implementation but rather to
571
    deal with those calls used in GraphMLWriter.
572
    """
573

    
574
    def __init__(self, xml, prettyprint):
575
        self.xml = xml
576
        self.prettyprint = prettyprint
577

    
578
    def append(self, element):
579
        self.xml.write(element, pretty_print=self.prettyprint)
580

    
581

    
582
class GraphMLWriterLxml(GraphMLWriter):
583
    def __init__(self, path, graph=None, encoding='utf-8', prettyprint=True,
584
                 infer_numeric_types=False):
585
        self.myElement = lxmletree.Element
586

    
587
        self._encoding = encoding
588
        self._prettyprint = prettyprint
589
        self.infer_numeric_types = infer_numeric_types
590

    
591
        self._xml_base = lxmletree.xmlfile(path, encoding=encoding)
592
        self._xml = self._xml_base.__enter__()
593
        self._xml.write_declaration()
594

    
595
        # We need to have a xml variable that support insertion. This call is
596
        # used for adding the keys to the document.
597
        # We will store those keys in a plain list, and then after the graph
598
        # element is closed we will add them to the main graphml element.
599
        self.xml = []
600
        self._keys = self.xml
601
        self._graphml = self._xml.element(
602
            'graphml',
603
            {
604
                'xmlns': self.NS_GRAPHML,
605
                'xmlns:xsi': self.NS_XSI,
606
                'xsi:schemaLocation': self.SCHEMALOCATION
607
            })
608
        self._graphml.__enter__()
609
        self.keys = {}
610
        self.attribute_types = defaultdict(set)
611

    
612
        if graph is not None:
613
            self.add_graph_element(graph)
614

    
615
    def add_graph_element(self, G):
616
        """
617
        Serialize graph G in GraphML to the stream.
618
        """
619
        if G.is_directed():
620
            default_edge_type = 'directed'
621
        else:
622
            default_edge_type = 'undirected'
623

    
624
        graphid = G.graph.pop('id', None)
625
        if graphid is None:
626
            graph_element = self._xml.element('graph',
627
                                              edgedefault=default_edge_type)
628
        else:
629
            graph_element = self._xml.element('graph',
630
                                              edgedefault=default_edge_type,
631
                                              id=graphid)
632

    
633
        # gather attributes types for the whole graph
634
        # to find the most general numeric format needed.
635
        # Then pass through attributes to create key_id for each.
636
        graphdata = {k: v for k, v in G.graph.items()
637
                     if k not in ('node_default', 'edge_default')}
638
        node_default = G.graph.get('node_default', {})
639
        edge_default = G.graph.get('edge_default', {})
640
        # Graph attributes
641
        for k, v in graphdata.items():
642
            self.attribute_types[(make_str(k), "graph")].add(type(v))
643
        for k, v in graphdata.items():
644
            element_type = self.xml_type[self.attr_type(k, "graph", v)]
645
            self.get_key(make_str(k), element_type, "graph", None)
646
        # Nodes and data
647
        attributes = {}
648
        for node, d in G.nodes(data=True):
649
            for k, v in d.items():
650
                self.attribute_types[(make_str(k), "node")].add(type(v))
651
                if k not in attributes:
652
                    attributes[k] = v
653
        for k, v in attributes.items():
654
            T = self.xml_type[self.attr_type(k, "node", v)]
655
            self.get_key(make_str(k), T, "node", node_default.get(k))
656
        # Edges and data
657
        if G.is_multigraph():
658
            attributes = {}
659
            for u, v, ekey, d in G.edges(keys=True, data=True):
660
                for k, v in d.items():
661
                    self.attribute_types[(make_str(k), "edge")].add(type(v))
662
                    if k not in attributes:
663
                        attributes[k] = v
664
            for k, v in attributes.items():
665
                T = self.xml_type[self.attr_type(k, "edge", v)]
666
                self.get_key(make_str(k), T, "edge", edge_default.get(k))
667
        else:
668
            attributes = {}
669
            for u, v, d in G.edges(data=True):
670
                for k, v in d.items():
671
                    self.attribute_types[(make_str(k), "edge")].add(type(v))
672
                    if k not in attributes:
673
                        attributes[k] = v
674
            for k, v in attributes.items():
675
                T = self.xml_type[self.attr_type(k, "edge", v)]
676
                self.get_key(make_str(k), T, "edge", edge_default.get(k))
677

    
678
        # Now add attribute keys to the xml file
679
        for key in self.xml:
680
            self._xml.write(key, pretty_print=self._prettyprint)
681

    
682
        # The incremental_writer writes each node/edge as it is created
683
        incremental_writer = IncrementalElement(self._xml, self._prettyprint)
684
        with graph_element:
685
            self.add_attributes('graph', incremental_writer, graphdata, {})
686
            self.add_nodes(G, incremental_writer)  # adds attributes too
687
            self.add_edges(G, incremental_writer)  # adds attributes too
688

    
689
    def add_attributes(self, scope, xml_obj, data, default):
690
        """Appends attribute data."""
691
        for k, v in data.items():
692
            data_element = self.add_data(make_str(k),
693
                                         self.attr_type(make_str(k), scope, v),
694
                                         make_str(v), scope, default.get(k))
695
            xml_obj.append(data_element)
696

    
697
    def __str__(self):
698
        return object.__str__(self)
699

    
700
    def dump(self):
701
        self._graphml.__exit__(None, None, None)
702
        self._xml_base.__exit__(None, None, None)
703

    
704

    
705
# Choose a writer function for default
706
if lxmletree is None:
707
    write_graphml = write_graphml_xml
708
else:
709
    write_graphml = write_graphml_lxml
710

    
711

    
712
class GraphMLReader(GraphML):
713
    """Read a GraphML document.  Produces NetworkX graph objects."""
714

    
715
    def __init__(self, node_type=str, edge_key_type=int):
716
        try:
717
            import xml.etree.ElementTree
718
        except ImportError:
719
            msg = 'GraphML reader requires xml.elementtree.ElementTree'
720
            raise ImportError(msg)
721
        self.node_type = node_type
722
        self.edge_key_type = edge_key_type
723
        self.multigraph = False  # assume multigraph and test for multiedges
724
        self.edge_ids = {}  # dict mapping (u,v) tuples to id edge attributes
725

    
726
    def __call__(self, path=None, string=None):
727
        if path is not None:
728
            self.xml = ElementTree(file=path)
729
        elif string is not None:
730
            self.xml = fromstring(string)
731
        else:
732
            raise ValueError("Must specify either 'path' or 'string' as kwarg")
733
        (keys, defaults) = self.find_graphml_keys(self.xml)
734
        for g in self.xml.findall("{%s}graph" % self.NS_GRAPHML):
735
            yield self.make_graph(g, keys, defaults)
736

    
737
    def make_graph(self, graph_xml, graphml_keys, defaults, G=None):
738
        # set default graph type
739
        edgedefault = graph_xml.get("edgedefault", None)
740
        if G is None:
741
            if edgedefault == 'directed':
742
                G = nx.MultiDiGraph()
743
            else:
744
                G = nx.MultiGraph()
745
        # set defaults for graph attributes
746
        G.graph['node_default'] = {}
747
        G.graph['edge_default'] = {}
748
        for key_id, value in defaults.items():
749
            key_for = graphml_keys[key_id]['for']
750
            name = graphml_keys[key_id]['name']
751
            python_type = graphml_keys[key_id]['type']
752
            if key_for == 'node':
753
                G.graph['node_default'].update({name: python_type(value)})
754
            if key_for == 'edge':
755
                G.graph['edge_default'].update({name: python_type(value)})
756
        # hyperedges are not supported
757
        hyperedge = graph_xml.find("{%s}hyperedge" % self.NS_GRAPHML)
758
        if hyperedge is not None:
759
            raise nx.NetworkXError("GraphML reader doesn't support hyperedges")
760
        # add nodes
761
        for node_xml in graph_xml.findall("{%s}node" % self.NS_GRAPHML):
762
            self.add_node(G, node_xml, graphml_keys, defaults)
763
        # add edges
764
        for edge_xml in graph_xml.findall("{%s}edge" % self.NS_GRAPHML):
765
            self.add_edge(G, edge_xml, graphml_keys)
766
        # add graph data
767
        data = self.decode_data_elements(graphml_keys, graph_xml)
768
        G.graph.update(data)
769

    
770
        # switch to Graph or DiGraph if no parallel edges were found.
771
        if not self.multigraph:
772
            if G.is_directed():
773
                G = nx.DiGraph(G)
774
            else:
775
                G = nx.Graph(G)
776
            nx.set_edge_attributes(G, values=self.edge_ids, name='id')
777

    
778
        return G
779

    
780
    def add_node(self, G, node_xml, graphml_keys, defaults):
781
        """Add a node to the graph.
782
        """
783
        # warn on finding unsupported ports tag
784
        ports = node_xml.find("{%s}port" % self.NS_GRAPHML)
785
        if ports is not None:
786
            warnings.warn("GraphML port tag not supported.")
787
        # find the node by id and cast it to the appropriate type
788
        node_id = self.node_type(node_xml.get("id"))
789
        # get data/attributes for node
790
        data = self.decode_data_elements(graphml_keys, node_xml)
791
        G.add_node(node_id, **data)
792
        # get child nodes
793
        if node_xml.attrib.get('yfiles.foldertype') == 'group':
794
            graph_xml = node_xml.find("{%s}graph" % self.NS_GRAPHML)
795
            self.make_graph(graph_xml, graphml_keys, defaults, G)
796

    
797
    def add_edge(self, G, edge_element, graphml_keys):
798
        """Add an edge to the graph.
799
        """
800
        # warn on finding unsupported ports tag
801
        ports = edge_element.find("{%s}port" % self.NS_GRAPHML)
802
        if ports is not None:
803
            warnings.warn("GraphML port tag not supported.")
804

    
805
        # raise error if we find mixed directed and undirected edges
806
        directed = edge_element.get("directed")
807
        if G.is_directed() and directed == 'false':
808
            msg = "directed=false edge found in directed graph."
809
            raise nx.NetworkXError(msg)
810
        if (not G.is_directed()) and directed == 'true':
811
            msg = "directed=true edge found in undirected graph."
812
            raise nx.NetworkXError(msg)
813

    
814
        source = self.node_type(edge_element.get("source"))
815
        target = self.node_type(edge_element.get("target"))
816
        data = self.decode_data_elements(graphml_keys, edge_element)
817
        # GraphML stores edge ids as an attribute
818
        # NetworkX uses them as keys in multigraphs too if no key
819
        # attribute is specified
820
        edge_id = edge_element.get("id")
821
        if edge_id:
822
            # self.edge_ids is used by `make_graph` method for non-multigraphs
823
            self.edge_ids[source, target] = edge_id
824
            try:
825
                edge_id = self.edge_key_type(edge_id)
826
            except ValueError:  # Could not convert.
827
                pass
828
        else:
829
            edge_id = data.get('key')
830

    
831
        if G.has_edge(source, target):
832
            # mark this as a multigraph
833
            self.multigraph = True
834

    
835
        # Use add_edges_from to avoid error with add_edge when `'key' in data`
836
        G.add_edges_from([(source, target, edge_id, data)])
837

    
838
    def decode_data_elements(self, graphml_keys, obj_xml):
839
        """Use the key information to decode the data XML if present."""
840
        data = {}
841
        for data_element in obj_xml.findall("{%s}data" % self.NS_GRAPHML):
842
            key = data_element.get("key")
843
            try:
844
                data_name = graphml_keys[key]['name']
845
                data_type = graphml_keys[key]['type']
846
            except KeyError:
847
                raise nx.NetworkXError("Bad GraphML data: no key %s" % key)
848
            text = data_element.text
849
            # assume anything with subelements is a yfiles extension
850
            if text is not None and len(list(data_element)) == 0:
851
                if data_type == bool:
852
                    # Ignore cases.
853
                    # http://docs.oracle.com/javase/6/docs/api/java/lang/
854
                    # Boolean.html#parseBoolean%28java.lang.String%29
855
                    data[data_name] = self.convert_bool[text.lower()]
856
                else:
857
                    data[data_name] = data_type(text)
858
            elif len(list(data_element)) > 0:
859
                # Assume yfiles as subelements, try to extract node_label
860
                node_label = None
861
                for node_type in ['ShapeNode', 'SVGNode', 'ImageNode']:
862
                    pref = "{%s}%s/{%s}" % (self.NS_Y, node_type, self.NS_Y)
863
                    geometry = data_element.find("%sGeometry" % pref)
864
                    if geometry is not None:
865
                        data['x'] = geometry.get('x')
866
                        data['y'] = geometry.get('y')
867
                    if node_label is None:
868
                        node_label = data_element.find("%sNodeLabel" % pref)
869
                if node_label is not None:
870
                    data['label'] = node_label.text
871

    
872
                # check all the different types of edges avaivable in yEd.
873
                for e in ['PolyLineEdge', 'SplineEdge', 'QuadCurveEdge',
874
                          'BezierEdge', 'ArcEdge']:
875
                    pref = "{%s}%s/{%s}" % (self.NS_Y, e, self.NS_Y)
876
                    edge_label = data_element.find("%sEdgeLabel" % pref)
877
                    if edge_label is not None:
878
                        break
879

    
880
                if edge_label is not None:
881
                    data['label'] = edge_label.text
882
        return data
883

    
884
    def find_graphml_keys(self, graph_element):
885
        """Extracts all the keys and key defaults from the xml.
886
        """
887
        graphml_keys = {}
888
        graphml_key_defaults = {}
889
        for k in graph_element.findall("{%s}key" % self.NS_GRAPHML):
890
            attr_id = k.get("id")
891
            attr_type = k.get('attr.type')
892
            attr_name = k.get("attr.name")
893
            yfiles_type = k.get("yfiles.type")
894
            if yfiles_type is not None:
895
                attr_name = yfiles_type
896
                attr_type = 'yfiles'
897
            if attr_type is None:
898
                attr_type = "string"
899
                warnings.warn("No key type for id %s. Using string" % attr_id)
900
            if attr_name is None:
901
                raise nx.NetworkXError("Unknown key for id %s." % attr_id)
902
            graphml_keys[attr_id] = {"name": attr_name,
903
                                     "type": self.python_type[attr_type],
904
                                     "for": k.get("for")}
905
            # check for "default" subelement of key element
906
            default = k.find("{%s}default" % self.NS_GRAPHML)
907
            if default is not None:
908
                graphml_key_defaults[attr_id] = default.text
909
        return graphml_keys, graphml_key_defaults
910

    
911

    
912
# fixture for nose tests
913
def setup_module(module):
914
    from nose import SkipTest
915
    try:
916
        import xml.etree.ElementTree
917
    except:
918
        raise SkipTest("xml.etree.ElementTree not available")
919

    
920

    
921
# fixture for nose tests
922
def teardown_module(module):
923
    import os
924
    try:
925
        os.unlink('test.graphml')
926
    except:
927
        pass