mirror of
https://github.com/biopython/biopython.git
synced 2025-10-20 21:53:47 +08:00
193 lines
7.6 KiB
Python
193 lines
7.6 KiB
Python
# Copyright 2009-2010 by Eric Talevich. All rights reserved.
|
|
# Revisions copyright 2010 by Peter Cock. All rights reserved.
|
|
#
|
|
# Converted by Eric Talevich from an older unit test copyright 2002
|
|
# by Thomas Hamelryck.
|
|
#
|
|
# This file is part of the Biopython distribution and governed by your
|
|
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
|
|
# Please see the LICENSE file that should have been included as part of this
|
|
# package.
|
|
|
|
"""Unit tests for those parts of the Bio.PDB module using Bio.PDB.kdtrees."""
|
|
|
|
import unittest
|
|
|
|
try:
|
|
from numpy import argsort
|
|
from numpy import array
|
|
from numpy import dot
|
|
from numpy import sqrt
|
|
from numpy.random import random
|
|
except ImportError:
|
|
from Bio import MissingExternalDependencyError
|
|
|
|
raise MissingExternalDependencyError(
|
|
"Install NumPy if you want to use Bio.PDB."
|
|
) from None
|
|
|
|
try:
|
|
from Bio.PDB import kdtrees
|
|
except ImportError:
|
|
from Bio import MissingExternalDependencyError
|
|
|
|
raise MissingExternalDependencyError(
|
|
"C module Bio.PDB.kdtrees not compiled"
|
|
) from None
|
|
|
|
from Bio.PDB.NeighborSearch import NeighborSearch
|
|
|
|
|
|
class NeighborTest(unittest.TestCase):
    def test_neighbor_search(self):
        """NeighborSearch: Find nearby randomly generated coordinates.

        Based on the self test in Bio.PDB.NeighborSearch.

        For each of 20 random clouds of 100 atoms, the atom pairs reported by
        ``search_all`` must match a brute-force search over all unordered
        pairs, and a probe point far outside the cloud must find nothing at
        any entity level.
        """

        class RandomAtom:
            # Minimal atom stand-in exposing only get_coord(), which is all
            # this test needs; coordinates are uniform in [0, 100)^3.
            def __init__(self):
                self.coord = 100 * random(3)

            def get_coord(self):
                return self.coord

        radius = 5.0
        for _ in range(20):
            atoms = [RandomAtom() for _ in range(100)]
            ns = NeighborSearch(atoms)
            hits = ns.search_all(radius)
            self.assertIsInstance(hits, list)

            # Brute-force reference: every unordered pair of distinct atoms
            # whose separation is within the search radius.
            expected = set()
            for j1, atom1 in enumerate(atoms):
                for atom2 in atoms[j1 + 1 :]:
                    v = atom1.coord - atom2.coord
                    if sqrt(dot(v, v)) <= radius:
                        expected.add(frozenset((atom1, atom2)))

            # Order within a pair is irrelevant; duplicates would shrink the set.
            found = {frozenset(pair) for pair in hits}
            self.assertEqual(len(found), len(hits))
            self.assertEqual(found, expected)

            x = array([250, 250, 250])  # Far away from our random atoms
            for level in ("A", "R", "C", "M", "S"):
                self.assertEqual([], ns.search(x, radius, level))
|
|
|
|
|
|
class KDTreeTest(unittest.TestCase):
    """Tests for the Bio.PDB.kdtrees C module (KDTree construction and searches)."""

    nr_points = 5000  # number of points used in test
    bucket_size = 5  # number of points per tree node
    radius = 0.05  # radius of search (typically 0.05 or so)

    def _assert_same_neighbors(self, neighbors1, neighbors2):
        """Assert two lists of kdtrees.Neighbor hold the same pairs and radii.

        Both lists are sorted in place by (index1, index2) before comparing,
        so the order in which each search produced them does not matter.
        """
        self.assertEqual(len(neighbors1), len(neighbors2))

        def key(neighbor):
            return (neighbor.index1, neighbor.index2)

        neighbors1.sort(key=key)
        neighbors2.sort(key=key)
        for neighbor1, neighbor2 in zip(neighbors1, neighbors2):
            self.assertEqual(neighbor1.index1, neighbor2.index1)
            self.assertEqual(neighbor1.index2, neighbor2.index2)
            self.assertAlmostEqual(neighbor1.radius, neighbor2.radius)

    def test_KDTree_exceptions(self):
        """Test that KDTree rejects out-of-range coordinates and non-Nx3 input."""
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        # Scaling [0, 1) by 1e14 pushes virtually every value far beyond 1e6.
        coords = random((nr_points, 3)) * 100000000000000
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(coords, bucket_size)
        self.assertIn(
            "coordinate values should lie between -1e6 and 1e6", str(context.exception)
        )
        # A single-column array is not a valid Nx3 coordinate array.
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(random((nr_points, 1)), bucket_size)
        self.assertIn("expected a Nx3 numpy array", str(context.exception))

    def test_KDTree_point_search(self):
        """Test searching all points within a certain radius of center.

        Using the kdtrees C module, search all points that are within
        radius of a random center, and compare the results to a manual
        search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        for radius in (self.radius, 100 * self.radius):
            with self.subTest(radius=radius):
                for _ in range(10):
                    # kd tree search
                    coords = random((nr_points, 3))
                    center = random(3)
                    kdt = kdtrees.KDTree(coords, bucket_size)
                    points1 = kdt.search(center, radius)
                    points1.sort(key=lambda point: point.index)
                    # manual search; enumerate yields points in index order,
                    # matching the sort applied to the kd tree result above
                    points2 = []
                    for index, p in enumerate(coords):
                        v = p - center
                        r = sqrt(dot(v, v))
                        if r <= radius:
                            points2.append(kdtrees.Point(index, r))
                    # compare results
                    self.assertEqual(len(points1), len(points2))
                    for point1, point2 in zip(points1, points2):
                        self.assertEqual(point1.index, point2.index)
                        self.assertAlmostEqual(point1.radius, point2.radius)

    def test_KDTree_neighbor_search_simple(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a simple but
        slow algorithm.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        radius = self.radius
        for _ in range(10):
            # KD tree search
            coords = random((nr_points, 3))
            kdt = kdtrees.KDTree(coords, bucket_size)
            neighbors1 = kdt.neighbor_search(radius)
            # same search, using a simple but slow algorithm
            neighbors2 = kdt.neighbor_simple_search(radius)
            # compare results
            self._assert_same_neighbors(neighbors1, neighbors2)

    def test_KDTree_neighbor_search_manual(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a manual search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points // 10  # fewer points to speed up the test
        for radius in (self.radius, 3 * self.radius):
            with self.subTest(radius=radius):
                for _ in range(5):
                    # KD tree search
                    coords = random((nr_points, 3))
                    kdt = kdtrees.KDTree(coords, bucket_size)
                    neighbors1 = kdt.neighbor_search(radius)
                    # manual search: sweep over the points sorted by x; once
                    # the x gap alone exceeds radius, no later point in the
                    # sorted order can be a neighbor of p1.
                    neighbors2 = []
                    indices = argsort(coords[:, 0])
                    for j1, index1 in enumerate(indices):
                        p1 = coords[index1]
                        for index2 in indices[j1 + 1 :]:
                            p2 = coords[index2]
                            if p2[0] - p1[0] > radius:
                                break
                            v = p1 - p2
                            r = sqrt(dot(v, v))
                            if r <= radius:
                                # store each pair with the smaller index first
                                if index1 < index2:
                                    i1, i2 = index1, index2
                                else:
                                    i1, i2 = index2, index1
                                neighbors2.append(kdtrees.Neighbor(i1, i2, r))
                    self._assert_same_neighbors(neighbors1, neighbors2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase in this module, reporting each test by name.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
|