Statistics
| Branch: | Revision:

iof-tools / networkxMiCe / networkx-master / networkx / readwrite / tests / test_shp.py @ 5cef0f13

History | View | Annotate | Download (9.72 KB)

1
"""Unit tests for shp.
2
"""
3

    
4
import os
5
import tempfile
6
from nose import SkipTest
7
from nose.tools import assert_equal
8
from nose.tools import raises
9

    
10
import networkx as nx
11

    
12

    
13
class TestShp(object):
14
    @classmethod
15
    def setupClass(cls):
16
        global ogr
17
        try:
18
            from osgeo import ogr
19
        except ImportError:
20
            raise SkipTest('ogr not available.')
21

    
22
    def deletetmp(self, drv, *paths):
23
        for p in paths:
24
            if os.path.exists(p):
25
                drv.DeleteDataSource(p)
26

    
27
    def setUp(self):
28

    
29
        def createlayer(driver, layerType=ogr.wkbLineString):
30
            lyr = driver.CreateLayer("edges", None, layerType)
31
            namedef = ogr.FieldDefn("Name", ogr.OFTString)
32
            namedef.SetWidth(32)
33
            lyr.CreateField(namedef)
34
            return lyr
35

    
36
        drv = ogr.GetDriverByName("ESRI Shapefile")
37

    
38
        testdir = os.path.join(tempfile.gettempdir(), 'shpdir')
39
        shppath = os.path.join(tempfile.gettempdir(), 'tmpshp.shp')
40
        multi_shppath = os.path.join(tempfile.gettempdir(), 'tmp_mshp.shp')
41

    
42
        self.deletetmp(drv, testdir, shppath, multi_shppath)
43
        os.mkdir(testdir)
44

    
45
        self.names = ['a', 'b', 'c', 'c']  # edgenames
46
        self.paths = ([(1.0, 1.0), (2.0, 2.0)],
47
                      [(2.0, 2.0), (3.0, 3.0)],
48
                      [(0.9, 0.9), (4.0, 0.9), (4.0, 2.0)])
49

    
50
        self.simplified_names = ['a', 'b', 'c']  # edgenames
51
        self.simplified_paths = ([(1.0, 1.0), (2.0, 2.0)],
52
                                 [(2.0, 2.0), (3.0, 3.0)],
53
                                 [(0.9, 0.9), (4.0, 2.0)])
54

    
55
        self.multi_names = ['a', 'a', 'a', 'a']  # edgenames
56

    
57
        shp = drv.CreateDataSource(shppath)
58
        lyr = createlayer(shp)
59

    
60
        for path, name in zip(self.paths, self.names):
61
            feat = ogr.Feature(lyr.GetLayerDefn())
62
            g = ogr.Geometry(ogr.wkbLineString)
63
            for p in path:
64
                g.AddPoint_2D(*p)
65
            feat.SetGeometry(g)
66
            feat.SetField("Name", name)
67
            lyr.CreateFeature(feat)
68

    
69
        # create single record multiline shapefile for testing
70
        multi_shp = drv.CreateDataSource(multi_shppath)
71
        multi_lyr = createlayer(multi_shp, ogr.wkbMultiLineString)
72

    
73
        multi_g = ogr.Geometry(ogr.wkbMultiLineString)
74
        for path in self.paths:
75

    
76
            g = ogr.Geometry(ogr.wkbLineString)
77
            for p in path:
78
                g.AddPoint_2D(*p)
79

    
80
            multi_g.AddGeometry(g)
81

    
82
        multi_feat = ogr.Feature(multi_lyr.GetLayerDefn())
83
        multi_feat.SetGeometry(multi_g)
84
        multi_feat.SetField("Name", 'a')
85
        multi_lyr.CreateFeature(multi_feat)
86

    
87
        self.shppath = shppath
88
        self.multi_shppath = multi_shppath
89
        self.testdir = testdir
90
        self.drv = drv
91

    
92
    def testload(self):
93

    
94
        def compare_graph_paths_names(g, paths, names):
95
            expected = nx.DiGraph()
96
            for p in paths:
97
                nx.add_path(expected, p)
98
            assert_equal(sorted(expected.nodes), sorted(g.nodes))
99
            assert_equal(sorted(expected.edges()), sorted(g.edges()))
100
            g_names = [g.get_edge_data(s, e)['Name'] for s, e in g.edges()]
101
            assert_equal(names, sorted(g_names))
102

    
103
        # simplified
104
        G = nx.read_shp(self.shppath)
105
        compare_graph_paths_names(G, self.simplified_paths,
106
                                  self.simplified_names)
107

    
108
        # unsimplified
109
        G = nx.read_shp(self.shppath, simplify=False)
110
        compare_graph_paths_names(G, self.paths, self.names)
111

    
112
        # multiline unsimplified
113
        G = nx.read_shp(self.multi_shppath, simplify=False)
114
        compare_graph_paths_names(G, self.paths, self.multi_names)
115

    
116
    def checkgeom(self, lyr, expected):
117
        feature = lyr.GetNextFeature()
118
        actualwkt = []
119
        while feature:
120
            actualwkt.append(feature.GetGeometryRef().ExportToWkt())
121
            feature = lyr.GetNextFeature()
122
        assert_equal(sorted(expected), sorted(actualwkt))
123

    
124
    def test_geometryexport(self):
125
        expectedpoints_simple = (
126
            "POINT (1 1)",
127
            "POINT (2 2)",
128
            "POINT (3 3)",
129
            "POINT (0.9 0.9)",
130
            "POINT (4 2)"
131
        )
132
        expectedlines_simple = (
133
            "LINESTRING (1 1,2 2)",
134
            "LINESTRING (2 2,3 3)",
135
            "LINESTRING (0.9 0.9,4.0 0.9,4 2)"
136
        )
137
        expectedpoints = (
138
            "POINT (1 1)",
139
            "POINT (2 2)",
140
            "POINT (3 3)",
141
            "POINT (0.9 0.9)",
142
            "POINT (4.0 0.9)",
143
            "POINT (4 2)"
144
        )
145
        expectedlines = (
146
            "LINESTRING (1 1,2 2)",
147
            "LINESTRING (2 2,3 3)",
148
            "LINESTRING (0.9 0.9,4.0 0.9)",
149
            "LINESTRING (4.0 0.9,4 2)"
150
        )
151

    
152
        tpath = os.path.join(tempfile.gettempdir(), 'shpdir')
153
        G = nx.read_shp(self.shppath)
154
        nx.write_shp(G, tpath)
155
        shpdir = ogr.Open(tpath)
156
        self.checkgeom(shpdir.GetLayerByName("nodes"), expectedpoints_simple)
157
        self.checkgeom(shpdir.GetLayerByName("edges"), expectedlines_simple)
158

    
159
        # Test unsimplified
160
        # Nodes should have additional point,
161
        # edges should be 'flattened'
162
        G = nx.read_shp(self.shppath, simplify=False)
163
        nx.write_shp(G, tpath)
164
        shpdir = ogr.Open(tpath)
165
        self.checkgeom(shpdir.GetLayerByName("nodes"), expectedpoints)
166
        self.checkgeom(shpdir.GetLayerByName("edges"), expectedlines)
167

    
168
    def test_attributeexport(self):
169
        def testattributes(lyr, graph):
170
            feature = lyr.GetNextFeature()
171
            while feature:
172
                coords = []
173
                ref = feature.GetGeometryRef()
174
                last = ref.GetPointCount() - 1
175
                edge_nodes = (ref.GetPoint_2D(0), ref.GetPoint_2D(last))
176
                name = feature.GetFieldAsString('Name')
177
                assert_equal(graph.get_edge_data(*edge_nodes)['Name'], name)
178
                feature = lyr.GetNextFeature()
179

    
180
        tpath = os.path.join(tempfile.gettempdir(), 'shpdir')
181

    
182
        G = nx.read_shp(self.shppath)
183
        nx.write_shp(G, tpath)
184
        shpdir = ogr.Open(tpath)
185
        edges = shpdir.GetLayerByName("edges")
186
        testattributes(edges, G)
187

    
188
    # Test export of node attributes in nx.write_shp (#2778)
189
    def test_nodeattributeexport(self):
190
        tpath = os.path.join(tempfile.gettempdir(), 'shpdir')
191

    
192
        G = nx.DiGraph()
193
        A = (0, 0)
194
        B = (1, 1)
195
        C = (2, 2)
196
        G.add_edge(A, B)
197
        G.add_edge(A, C)
198
        label = 'node_label'
199
        for n, d in G.nodes(data=True):
200
            d['label'] = label
201
        nx.write_shp(G, tpath)
202

    
203
        H = nx.read_shp(tpath)
204
        for n, d in H.nodes(data=True):
205
            assert_equal(d['label'], label)
206

    
207
    def test_wkt_export(self):
208
        G = nx.DiGraph()
209
        tpath = os.path.join(tempfile.gettempdir(), 'shpdir')
210
        points = (
211
            "POINT (0.9 0.9)",
212
            "POINT (4 2)"
213
        )
214
        line = (
215
            "LINESTRING (0.9 0.9,4 2)",
216
        )
217
        G.add_node(1, Wkt=points[0])
218
        G.add_node(2, Wkt=points[1])
219
        G.add_edge(1, 2, Wkt=line[0])
220
        try:
221
            nx.write_shp(G, tpath)
222
        except Exception as e:
223
            assert False, e
224
        shpdir = ogr.Open(tpath)
225
        self.checkgeom(shpdir.GetLayerByName("nodes"), points)
226
        self.checkgeom(shpdir.GetLayerByName("edges"), line)
227

    
228
    def tearDown(self):
229
        self.deletetmp(self.drv, self.testdir, self.shppath)
230

    
231

    
232
@raises(RuntimeError)
233
def test_read_shp_nofile():
234
    try:
235
        from osgeo import ogr
236
    except ImportError:
237
        raise SkipTest('ogr not available.')
238
    G = nx.read_shp("hopefully_this_file_will_not_be_available")
239

    
240

    
241
class TestMissingGeometry(object):
242
    @classmethod
243
    def setup_class(cls):
244
        global ogr
245
        try:
246
            from osgeo import ogr
247
        except ImportError:
248
            raise SkipTest('ogr not available.')
249

    
250
    def setUp(self):
251
        self.setup_path()
252
        self.delete_shapedir()
253
        self.create_shapedir()
254

    
255
    def tearDown(self):
256
        self.delete_shapedir()
257

    
258
    def setup_path(self):
259
        self.path = os.path.join(tempfile.gettempdir(), 'missing_geometry')
260

    
261
    def create_shapedir(self):
262
        drv = ogr.GetDriverByName("ESRI Shapefile")
263
        shp = drv.CreateDataSource(self.path)
264
        lyr = shp.CreateLayer("nodes", None, ogr.wkbPoint)
265
        feature = ogr.Feature(lyr.GetLayerDefn())
266
        feature.SetGeometry(None)
267
        lyr.CreateFeature(feature)
268
        feature.Destroy()
269

    
270
    def delete_shapedir(self):
271
        drv = ogr.GetDriverByName("ESRI Shapefile")
272
        if os.path.exists(self.path):
273
            drv.DeleteDataSource(self.path)
274

    
275
    @raises(nx.NetworkXError)
276
    def test_missing_geometry(self):
277
        G = nx.read_shp(self.path)
278

    
279

    
280
class TestMissingAttrWrite(object):
281
    @classmethod
282
    def setup_class(cls):
283
        global ogr
284
        try:
285
            from osgeo import ogr
286
        except ImportError:
287
            raise SkipTest('ogr not available.')
288

    
289
    def setUp(self):
290
        self.setup_path()
291
        self.delete_shapedir()
292

    
293
    def tearDown(self):
294
        self.delete_shapedir()
295

    
296
    def setup_path(self):
297
        self.path = os.path.join(tempfile.gettempdir(), 'missing_attributes')
298

    
299
    def delete_shapedir(self):
300
        drv = ogr.GetDriverByName("ESRI Shapefile")
301
        if os.path.exists(self.path):
302
            drv.DeleteDataSource(self.path)
303

    
304
    def test_missing_attributes(self):
305
        G = nx.DiGraph()
306
        A = (0, 0)
307
        B = (1, 1)
308
        C = (2, 2)
309
        G.add_edge(A, B, foo=100)
310
        G.add_edge(A, C)
311

    
312
        nx.write_shp(G, self.path)
313
        H = nx.read_shp(self.path)
314

    
315
        for u, v, d in H.edges(data=True):
316
            if u == A and v == B:
317
                assert_equal(d['foo'], 100)
318
            if u == A and v == C:
319
                assert_equal(d['foo'], None)