Revision 1ef4948a util/UnitDiskGraph.py

View differences:

util/UnitDiskGraph.py
1 1
# https://stackoverflow.com/questions/32424604/find-all-nearest-neighbors-within-a-specific-distance
2 2

  
3 3
import networkx as nx
4
from scipy import spatial
5
import code # code.interact(local=dict(globals(), **locals()))
6

  
4
from scipy.spatial import KDTree
5
from scipy.spatial.distance import cdist
6
import code  # code.interact(local=dict(globals(), **locals()))
7
import graph_tool as gt
7 8

  
8 9
class UnitDiskGraph:
9 10

  
10 11
    def __init__(self, points, radius):
11
        self.G = self.genereateGraphFromKDtree(points, radius)
12

  
12
        self.G = nx.Graph()
13
        self.G.add_nodes_from(range(len(points)))
14
        '''
15
        self.G = gt.Graph(directed=False)
16
        self.G.add_vertex(len(points))
17
        self.edge_weights = self.G.new_edge_property('double')
18
        self.G.edge_properties['weight'] = self.edge_weights'''
19
        
20
        self.generateGraphCDIST(points, radius)
21
        
22
        
13 23
    def genereateGraphFromKDtree(self, points, radius):
14
        tree = spatial.KDTree(points)
24
        tree = KDTree(points)
15 25
        edges = tree.query_pairs(r=radius)
16
        edges = [e+(1.0,) for e in edges]
17
        G = nx.Graph()
18
        #pos = {k:points[k] for k in range(0,len(points))}
19
        #code.interact(local=dict(globals(), **locals()))
20
        G.add_nodes_from(range(len(points)))
21
        G.add_weighted_edges_from(edges, weight='weight')
22
        return G
26
        #edges = [e+(1.0,) for e in edges]
27
        #pos = {k:points[k] for k in range(0,len(points))}        
28
        #self.G.add_weighted_edges_from(edges, weight='weight')
29
        for e in edges:
30
            e = self.G.add_edge(e[0],e[1])
31
            self.edge_weights[e] = 1.0
32

  
33

  
34
    def generateGraphCDIST(self, points, radius):
35
        distM = cdist(points, points, 'euclidean')
36
        edges = []
37
        for r in range(len(points)):
38
            for c in range(len(points)):
39
                if r==c:
40
                    continue
41
                if distM[r][c] <= radius:
42
                    edges.append((r,c,1.0))
43
        self.G.add_weighted_edges_from(edges, weight='weight')
44
        '''code.interact(local=dict(globals(), **locals()))
45
        for e in edges:
46
            e = self.G.add_edge(e[0],e[1])
47
            self.edge_weights[e] = 1.0'''
23 48

  
24 49
    def getGraph(self):
25 50
        return self.G

Also available in: Unified diff