Files
biopython/Tests/test_PDB_KDTree.py
2025-02-13 12:11:31 +09:00

193 lines
7.6 KiB
Python

# Copyright 2009-2010 by Eric Talevich. All rights reserved.
# Revisions copyright 2010 by Peter Cock. All rights reserved.
#
# Converted by Eric Talevich from an older unit test copyright 2002
# by Thomas Hamelryck.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.
"""Unit tests for those parts of the Bio.PDB module using Bio.PDB.kdtrees."""
import unittest
try:
from numpy import argsort
from numpy import array
from numpy import dot
from numpy import sqrt
from numpy.random import random
except ImportError:
from Bio import MissingExternalDependencyError
raise MissingExternalDependencyError(
"Install NumPy if you want to use Bio.PDB."
) from None
try:
from Bio.PDB import kdtrees
except ImportError:
from Bio import MissingExternalDependencyError
raise MissingExternalDependencyError(
"C module Bio.PDB.kdtrees not compiled"
) from None
from Bio.PDB.NeighborSearch import NeighborSearch
class NeighborTest(unittest.TestCase):
    def test_neighbor_search(self):
        """NeighborSearch: Find nearby randomly generated coordinates.

        Based on the self test in Bio.PDB.NeighborSearch. Every atom pair
        returned by ``search_all`` must really lie within the search radius,
        and the number of pairs must match a brute-force count. A point far
        away from all atoms must have no neighbours at any entity level.
        """

        class RandomAtom:
            def __init__(self):
                self.coord = 100 * random(3)

            def get_coord(self):
                return self.coord

        radius = 5.0
        for _ in range(20):
            atoms = [RandomAtom() for _ in range(100)]
            ns = NeighborSearch(atoms)
            hits = ns.search_all(radius)
            self.assertIsInstance(hits, list)
            # Each hit (default level "A") is a pair of atoms; verify that
            # the two atoms really are within the search radius.
            for atom1, atom2 in hits:
                v = atom1.get_coord() - atom2.get_coord()
                self.assertLessEqual(sqrt(dot(v, v)), radius)
            # Brute-force count of unordered pairs within the radius, using a
            # full pairwise distance matrix (100 x 100 is cheap).
            coords = array([atom.get_coord() for atom in atoms])
            diff = coords[:, None, :] - coords[None, :, :]
            within = sqrt((diff * diff).sum(axis=-1)) <= radius
            # The diagonal (each atom with itself, distance 0) is always
            # counted; drop it and halve the symmetric remainder.
            expected = (int(within.sum()) - len(atoms)) // 2
            self.assertEqual(len(hits), expected)
            x = array([250, 250, 250])  # Far away from our random atoms
            for level in ("A", "R", "C", "M", "S"):
                with self.subTest(level=level):
                    self.assertEqual([], ns.search(x, radius, level))
class KDTreeTest(unittest.TestCase):
    nr_points = 5000  # number of points used in test
    bucket_size = 5  # number of points per tree node
    radius = 0.05  # radius of search (typically 0.05 or so)

    def _assert_same_neighbors(self, neighbors1, neighbors2):
        """Assert two lists of kdtrees.Neighbor objects describe the same pairs.

        The lists are compared independent of their order: both are sorted
        by (index1, index2) first. Indices must match exactly; radii must
        match to assertAlmostEqual precision.
        """
        self.assertEqual(len(neighbors1), len(neighbors2))
        ordered1 = sorted(neighbors1, key=lambda n: (n.index1, n.index2))
        ordered2 = sorted(neighbors2, key=lambda n: (n.index1, n.index2))
        for neighbor1, neighbor2 in zip(ordered1, ordered2):
            self.assertEqual(neighbor1.index1, neighbor2.index1)
            self.assertEqual(neighbor1.index2, neighbor2.index2)
            self.assertAlmostEqual(neighbor1.radius, neighbor2.radius)

    def test_KDTree_exceptions(self):
        """Test that KDTree rejects out-of-range or wrongly shaped input."""
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        # Coordinates scaled far outside the range accepted by the C module.
        coords = random((nr_points, 3)) * 100000000000000
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(coords, bucket_size)
        self.assertIn(
            "coordinate values should lie between -1e6 and 1e6", str(context.exception)
        )
        # An Nx1 array instead of the required Nx3 array.
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(random((nr_points, 1)), bucket_size)
        self.assertIn("expected a Nx3 numpy array", str(context.exception))

    def test_KDTree_point_search(self):
        """Test searching all points within a certain radius of center.

        Using the kdtrees C module, search all points that are within
        radius of a random center, and compare the results to a manual
        (brute-force) search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        for radius in (self.radius, 100 * self.radius):
            for _ in range(10):
                # kd tree search
                coords = random((nr_points, 3))
                center = random(3)
                kdt = kdtrees.KDTree(coords, bucket_size)
                points1 = kdt.search(center, radius)
                points1.sort(key=lambda point: point.index)
                # manual search: distance of every point to the center,
                # computed for all rows at once via broadcasting.
                offsets = coords - center
                distances = sqrt((offsets * offsets).sum(axis=1))
                points2 = [
                    kdtrees.Point(index, r)
                    for index, r in enumerate(distances)
                    if r <= radius
                ]
                # compare results (points2 is already in index order)
                self.assertEqual(len(points1), len(points2))
                for point1, point2 in zip(points1, points2):
                    self.assertEqual(point1.index, point2.index)
                    self.assertAlmostEqual(point1.radius, point2.radius)

    def test_KDTree_neighbor_search_simple(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a simple but
        slow algorithm.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        radius = self.radius
        for _ in range(10):
            # KD tree search
            coords = random((nr_points, 3))
            kdt = kdtrees.KDTree(coords, bucket_size)
            neighbors1 = kdt.neighbor_search(radius)
            # same search, using a simple but slow algorithm
            neighbors2 = kdt.neighbor_simple_search(radius)
            self._assert_same_neighbors(neighbors1, neighbors2)

    def test_KDTree_neighbor_search_manual(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a manual search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points // 10  # fewer points to speed up the test
        for radius in (self.radius, 3 * self.radius):
            for _ in range(5):
                # KD tree search
                coords = random((nr_points, 3))
                kdt = kdtrees.KDTree(coords, bucket_size)
                neighbors1 = kdt.neighbor_search(radius)
                # manual search: sweep over points sorted by x coordinate, so
                # the inner loop can stop once the x gap exceeds the radius.
                neighbors2 = []
                indices = argsort(coords[:, 0])
                for j1 in range(nr_points):
                    index1 = indices[j1]
                    p1 = coords[index1]
                    for j2 in range(j1 + 1, nr_points):
                        index2 = indices[j2]
                        p2 = coords[index2]
                        if p2[0] - p1[0] > radius:
                            break
                        v = p1 - p2
                        r = sqrt(dot(v, v))
                        if r <= radius:
                            # kdtrees reports each pair with index1 < index2
                            if index1 < index2:
                                i1, i2 = index1, index2
                            else:
                                i1, i2 = index2, index1
                            neighbor = kdtrees.Neighbor(i1, i2, r)
                            neighbors2.append(neighbor)
                self._assert_same_neighbors(neighbors1, neighbors2)
if __name__ == "__main__":
    # Executed as a script: run this module's tests with verbose output.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))