Statistics
| Branch: | Revision:

iof-tools / networkxMiCe / networkx-master / networkx / utils / tests / test_decorators.py @ 5cef0f13

History | View | Annotate | Download (8.51 KB)

1
import tempfile
2
import os
3
import random
4

    
5
from nose.tools import *
6
from nose import SkipTest
7

    
8
import networkx as nx
9
from networkx.utils.decorators import open_file, not_implemented_for
10
from networkx.utils.decorators import nodes_or_number, preserve_random_state, \
11
    py_random_state, np_random_state, random_state
12
from networkx.utils.misc import PythonRandomInterface
13

    
14
def test_not_implemented_decorator():
15
    @not_implemented_for('directed')
16
    def test1(G):
17
        pass
18
    test1(nx.Graph())
19

    
20

    
21
@raises(KeyError)
22
def test_not_implemented_decorator_key():
23
    @not_implemented_for('foo')
24
    def test1(G):
25
        pass
26
    test1(nx.Graph())
27

    
28

    
29
@raises(nx.NetworkXNotImplemented)
30
def test_not_implemented_decorator_raise():
31
    @not_implemented_for('graph')
32
    def test1(G):
33
        pass
34
    test1(nx.Graph())
35

    
36

    
37
class TestOpenFileDecorator(object):
38
    def setUp(self):
39
        self.text = ['Blah... ', 'BLAH ', 'BLAH!!!!']
40
        self.fobj = tempfile.NamedTemporaryFile('wb+', delete=False)
41
        self.name = self.fobj.name
42

    
43
    def write(self, path):
44
        for text in self.text:
45
            path.write(text.encode('ascii'))
46

    
47
    @open_file(1, 'r')
48
    def read(self, path):
49
        return path.readlines()[0]
50

    
51
    @staticmethod
52
    @open_file(0, 'wb')
53
    def writer_arg0(path):
54
        path.write('demo'.encode('ascii'))
55

    
56
    @open_file(1, 'wb+')
57
    def writer_arg1(self, path):
58
        self.write(path)
59

    
60
    @open_file(2, 'wb')
61
    def writer_arg2default(self, x, path=None):
62
        if path is None:
63
            with tempfile.NamedTemporaryFile('wb+') as fh:
64
                self.write(fh)
65
        else:
66
            self.write(path)
67

    
68
    @open_file(4, 'wb')
69
    def writer_arg4default(self, x, y, other='hello', path=None, **kwargs):
70
        if path is None:
71
            with tempfile.NamedTemporaryFile('wb+') as fh:
72
                self.write(fh)
73
        else:
74
            self.write(path)
75

    
76
    @open_file('path', 'wb')
77
    def writer_kwarg(self, **kwargs):
78
        path = kwargs.get('path', None)
79
        if path is None:
80
            with tempfile.NamedTemporaryFile('wb+') as fh:
81
                self.write(fh)
82
        else:
83
            self.write(path)
84

    
85
    def test_writer_arg0_str(self):
86
        self.writer_arg0(self.name)
87

    
88
    def test_writer_arg0_fobj(self):
89
        self.writer_arg0(self.fobj)
90

    
91
    def test_writer_arg0_pathlib(self):
92
        try:
93
            import pathlib
94
            self.writer_arg0(pathlib.Path(self.name))
95
        except ImportError:
96
            return
97

    
98
    def test_writer_arg1_str(self):
99
        self.writer_arg1(self.name)
100
        assert_equal(self.read(self.name), ''.join(self.text))
101

    
102
    def test_writer_arg1_fobj(self):
103
        self.writer_arg1(self.fobj)
104
        assert_false(self.fobj.closed)
105
        self.fobj.close()
106
        assert_equal(self.read(self.name), ''.join(self.text))
107

    
108
    def test_writer_arg2default_str(self):
109
        self.writer_arg2default(0, path=None)
110
        self.writer_arg2default(0, path=self.name)
111
        assert_equal(self.read(self.name), ''.join(self.text))
112

    
113
    def test_writer_arg2default_fobj(self):
114
        self.writer_arg2default(0, path=self.fobj)
115
        assert_false(self.fobj.closed)
116
        self.fobj.close()
117
        assert_equal(self.read(self.name), ''.join(self.text))
118

    
119
    def test_writer_arg2default_fobj_path_none(self):
120
        self.writer_arg2default(0, path=None)
121

    
122
    def test_writer_arg4default_fobj(self):
123
        self.writer_arg4default(0, 1, dog='dog', other='other')
124
        self.writer_arg4default(0, 1, dog='dog', other='other', path=self.name)
125
        assert_equal(self.read(self.name), ''.join(self.text))
126

    
127
    def test_writer_kwarg_str(self):
128
        self.writer_kwarg(path=self.name)
129
        assert_equal(self.read(self.name), ''.join(self.text))
130

    
131
    def test_writer_kwarg_fobj(self):
132
        self.writer_kwarg(path=self.fobj)
133
        self.fobj.close()
134
        assert_equal(self.read(self.name), ''.join(self.text))
135

    
136
    def test_writer_kwarg_path_none(self):
137
        self.writer_kwarg(path=None)
138

    
139
    def tearDown(self):
140
        self.fobj.close()
141
        os.unlink(self.name)
142

    
143

    
144
@preserve_random_state
145
def test_preserve_random_state():
146
    try:
147
        import numpy.random
148
        r = numpy.random.random()
149
    except ImportError:
150
        return
151
    assert(abs(r - 0.61879477158568) < 1e-16)
152

    
153

    
154
class TestRandomState(object):
155
    @classmethod
156
    def setUp(cls):
157
        global np
158
        try:
159
            import numpy as np
160
        except ImportError:
161
            raise SkipTest('NumPy not available.')
162

    
163
    @random_state(1)
164
    def instantiate_random_state(self, random_state):
165
        assert_true(isinstance(random_state, np.random.RandomState))
166
        return random_state.random_sample()
167

    
168
    @np_random_state(1)
169
    def instantiate_np_random_state(self, random_state):
170
        assert_true(isinstance(random_state, np.random.RandomState))
171
        return random_state.random_sample()
172

    
173
    @py_random_state(1)
174
    def instantiate_py_random_state(self, random_state):
175
        assert_true(isinstance(random_state, random.Random) or
176
                    isinstance(random_state, PythonRandomInterface))
177
        return random_state.random()
178

    
179
    def test_random_state_None(self):
180
        np.random.seed(42)
181
        rv = np.random.random_sample()
182
        np.random.seed(42)
183
        assert_equal(rv, self.instantiate_random_state(None))
184
        np.random.seed(42)
185
        assert_equal(rv, self.instantiate_np_random_state(None))
186

    
187
        random.seed(42)
188
        rv = random.random()
189
        random.seed(42)
190
        assert_equal(rv, self.instantiate_py_random_state(None))
191

    
192
    def test_random_state_np_random(self):
193
        np.random.seed(42)
194
        rv = np.random.random_sample()
195
        np.random.seed(42)
196
        assert_equal(rv, self.instantiate_random_state(np.random))
197
        np.random.seed(42)
198
        assert_equal(rv, self.instantiate_np_random_state(np.random))
199
        np.random.seed(42)
200
        assert_equal(rv, self.instantiate_py_random_state(np.random))
201

    
202
    def test_random_state_int(self):
203
        np.random.seed(42)
204
        np_rv = np.random.random_sample()
205
        random.seed(42)
206
        py_rv = random.random()
207

    
208
        np.random.seed(42)
209
        seed = 1
210
        rval = self.instantiate_random_state(seed)
211
        rval_expected = np.random.RandomState(seed).rand()
212
        assert_true(rval, rval_expected)
213

    
214
        rval = self.instantiate_np_random_state(seed)
215
        rval_expected = np.random.RandomState(seed).rand()
216
        assert_true(rval, rval_expected)
217
        # test that global seed wasn't changed in function
218
        assert_equal(np_rv, np.random.random_sample())
219

    
220
        random.seed(42)
221
        rval = self.instantiate_py_random_state(seed)
222
        rval_expected = random.Random(seed).random()
223
        assert_true(rval, rval_expected)
224
        # test that global seed wasn't changed in function
225
        assert_equal(py_rv, random.random())
226

    
227
    def test_random_state_np_random_RandomState(self):
228
        np.random.seed(42)
229
        np_rv = np.random.random_sample()
230

    
231
        np.random.seed(42)
232
        seed = 1
233
        rng = np.random.RandomState(seed)
234
        rval = self.instantiate_random_state(rng)
235
        rval_expected = np.random.RandomState(seed).rand()
236
        assert_true(rval, rval_expected)
237

    
238
        rval = self.instantiate_np_random_state(seed)
239
        rval_expected = np.random.RandomState(seed).rand()
240
        assert_true(rval, rval_expected)
241

    
242
        rval = self.instantiate_py_random_state(seed)
243
        rval_expected = np.random.RandomState(seed).rand()
244
        assert_true(rval, rval_expected)
245
        # test that global seed wasn't changed in function
246
        assert_equal(np_rv, np.random.random_sample())
247

    
248
    def test_random_state_py_random(self):
249
        seed = 1
250
        rng = random.Random(seed)
251
        rv = self.instantiate_py_random_state(rng)
252
        assert_true(rv, random.Random(seed).random())
253

    
254
        assert_raises(ValueError, self.instantiate_random_state, rng)
255
        assert_raises(ValueError, self.instantiate_np_random_state, rng)
256

    
257

    
258
@raises(nx.NetworkXError)
259
def test_random_state_string_arg_index():
260
    @random_state('a')
261
    def make_random_state(rs):
262
        pass
263
    rstate = make_random_state(1)
264

    
265

    
266
@raises(nx.NetworkXError)
267
def test_py_random_state_string_arg_index():
268
    @py_random_state('a')
269
    def make_random_state(rs):
270
        pass
271
    rstate = make_random_state(1)
272

    
273

    
274
@raises(nx.NetworkXError)
275
def test_random_state_invalid_arg_index():
276
    @random_state(2)
277
    def make_random_state(rs):
278
        pass
279
    rstate = make_random_state(1)
280

    
281

    
282
@raises(nx.NetworkXError)
283
def test_py_random_state_invalid_arg_index():
284
    @py_random_state(2)
285
    def make_random_state(rs):
286
        pass
287
    rstate = make_random_state(1)