Skip to content
Snippets Groups Projects
Commit 8814b60d authored by giagitom's avatar giagitom
Browse files

Increased performances of lookup3d nearest search.

parent 813c4f84
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@
* Copyright (c) 2018-2020 mol* contributors, licensed under MIT, See LICENSE file for more info.
*
* @author David Sehnal <david.sehnal@gmail.com>
* @author Gianluca Tomasello <giagitom@gmail.com>
*/
import { GridLookup3D } from '../../geometry';
......@@ -24,9 +25,17 @@ describe('GridLookup3d', () => {
expect(r.count).toBe(1);
expect(r.indices[0]).toBe(0);
r = grid.nearest(0, 0, 0, 1);
expect(r.count).toBe(1);
expect(r.indices[0]).toBe(0);
r = grid.find(0, 0, 0, 1);
expect(r.count).toBe(3);
expect(sortArray(r.indices)).toEqual([0, 1, 2]);
r = grid.nearest(0, 0, 0, 3);
expect(r.count).toBe(3);
expect(sortArray(r.indices)).toEqual([0, 1, 2]);
});
it('radius', () => {
......@@ -38,9 +47,17 @@ describe('GridLookup3d', () => {
expect(r.count).toBe(1);
expect(r.indices[0]).toBe(0);
r = grid.nearest(0, 0, 0, 1);
expect(r.count).toBe(1);
expect(r.indices[0]).toBe(0);
r = grid.find(0, 0, 0, 0.5);
expect(r.count).toBe(2);
expect(sortArray(r.indices)).toEqual([0, 1]);
r = grid.nearest(0, 0, 0, 3);
expect(r.count).toBe(3);
expect(sortArray(r.indices)).toEqual([0, 1, 2]);
});
it('indexed', () => {
......@@ -51,8 +68,15 @@ describe('GridLookup3d', () => {
let r = grid.find(0, 0, 0, 0);
expect(r.count).toBe(0);
r = grid.nearest(0, 0, 0, 1);
expect(r.count).toBe(1);
r = grid.find(0, 0, 0, 0.5);
expect(r.count).toBe(1);
expect(sortArray(r.indices)).toEqual([0]);
r = grid.nearest(0, 0, 0, 3);
expect(r.count).toBe(1);
expect(sortArray(r.indices)).toEqual([0]);
});
});
\ No newline at end of file
});
......@@ -41,7 +41,7 @@ export namespace Result {
export interface Lookup3D<T = number> {
// The result is mutated with each call to find.
find(x: number, y: number, z: number, radius: number, result?: Result<T>): Result<T>,
nearest(x: number, y: number, z: number, k: number, result?: Result<T>): Result<T>,
nearest(x: number, y: number, z: number, k: number, stopIf?: Function, result?: Result<T>): Result<T>,
check(x: number, y: number, z: number, radius: number): boolean,
readonly boundary: { readonly box: Box3D, readonly sphere: Sphere3D }
/** transient result */
......
......@@ -10,7 +10,7 @@ import { Result, Lookup3D } from './common';
import { Box3D } from '../primitives/box3d';
import { Sphere3D } from '../primitives/sphere3d';
import { PositionData } from '../common';
import { Vec3, EPSILON } from '../../linear-algebra';
import { Vec3 } from '../../linear-algebra';
import { OrderedSet } from '../../../mol-data/int';
import { Boundary } from '../boundary';
import { FibonacciHeap } from '../../../mol-util/fibonacci-heap';
......@@ -42,11 +42,12 @@ class GridLookup3DImpl<T extends number = number> implements GridLookup3D<T> {
return ret;
}
nearest(x: number, y: number, z: number, k: number = 1, result?: Result<T>): Result<T> {
nearest(x: number, y: number, z: number, k: number = 1, stopIf?: Function, result?: Result<T>): Result<T> {
this.ctx.x = x;
this.ctx.y = y;
this.ctx.z = z;
this.ctx.k = k;
this.ctx.stopIf = stopIf;
const ret = result ?? this.result;
queryNearest(this.ctx, ret);
return ret;
......@@ -234,12 +235,13 @@ interface QueryContext {
y: number,
z: number,
k: number,
stopIf?: Function,
radius: number,
isCheck: boolean
}
function createContext(grid: Grid3D): QueryContext {
return { grid, x: 0.1, y: 0.1, z: 0.1, k: 1, radius: 0.1, isCheck: false };
return { grid, x: 0.1, y: 0.1, z: 0.1, k: 1, stopIf: undefined, radius: 0.1, isCheck: false };
}
function query<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
......@@ -294,124 +296,137 @@ function query<T extends number = number>(ctx: QueryContext, result: Result<T>):
const tmpDirVec = Vec3();
const tmpVec = Vec3();
const tmpMapG = new Map<number, boolean>();
const tmpSetG = new Set<number>();
const tmpSetG2 = new Set<number>();
const tmpArrG1 = [0.1];
const tmpArrG2 = [0.1];
const tmpArrG3 = [0.1];
const tmpHeapG = new FibonacciHeap();
function queryNearest<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
const { expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
const [minX, minY, minZ] = box.min;
const { x, y, z, k } = ctx;
const { min, expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
const { x, y, z, k, stopIf } = ctx;
const indicesCount = OrderedSet.size(indices);
Result.reset(result);
if (indicesCount === 0 || k <= 0) return false;
let gX, gY, gZ;
let gX, gY, gZ, stop = false, gCount = 1, expandGrid = true, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, maxRange = 0, expandRange = true, gridId: number, gridPointsFinished = false;
const expandedArrG = tmpArrG3, sqMaxRadius = maxRadius * maxRadius;
arrG.length = 0;
expandedArrG.length = 0;
tmpSetG.clear();
tmpHeapG.clear();
Vec3.set(tmpVec, x, y, z);
if (!Box3D.containsVec3(box, tmpVec)) {
// intersect ray pointing to box center
Box3D.nearestIntersectionWithRay(tmpVec, box, tmpVec, Vec3.normalize(tmpDirVec, Vec3.sub(tmpDirVec, center, tmpVec)));
gX = Math.max(0, Math.min(sX - 1, Math.floor((tmpVec[0] - minX) / delta[0])));
gY = Math.max(0, Math.min(sY - 1, Math.floor((tmpVec[1] - minY) / delta[1])));
gZ = Math.max(0, Math.min(sZ - 1, Math.floor((tmpVec[2] - minZ) / delta[2])));
gX = Math.max(0, Math.min(sX - 1, Math.floor((tmpVec[0] - min[0]) / delta[0])));
gY = Math.max(0, Math.min(sY - 1, Math.floor((tmpVec[1] - min[1]) / delta[1])));
gZ = Math.max(0, Math.min(sZ - 1, Math.floor((tmpVec[2] - min[2]) / delta[2])));
} else {
gX = Math.floor((x - minX) / delta[0]);
gY = Math.floor((y - minY) / delta[1]);
gZ = Math.floor((z - minZ) / delta[2]);
gX = Math.floor((x - min[0]) / delta[0]);
gY = Math.floor((y - min[1]) / delta[1]);
gZ = Math.floor((z - min[2]) / delta[2]);
}
let gCount = 1, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, prevFurthestDist = Number.MAX_VALUE, prevNearestDist = -Number.MAX_VALUE, nearestDist = Number.MAX_VALUE, findFurthest = true, furthestDist = - Number.MAX_VALUE, distSqG: number;
arrG.length = 0;
nextArrG.length = 0;
arrG.push(gX, gY, gZ);
tmpMapG.clear();
tmpHeapG.clear();
const dX = maxRadius !== 0 ? Math.max(1, Math.min(sX - 1, Math.ceil(maxRadius / delta[0]))) : 1;
const dY = maxRadius !== 0 ? Math.max(1, Math.min(sY - 1, Math.ceil(maxRadius / delta[1]))) : 1;
const dZ = maxRadius !== 0 ? Math.max(1, Math.min(sZ - 1, Math.ceil(maxRadius / delta[2]))) : 1;
arrG.push(gX, gY, gZ, (((gX * sY) + gY) * sZ) + gZ);
while (result.count < indicesCount) {
const arrGLen = gCount * 3;
for (let ig = 0; ig < arrGLen; ig += 3) {
gX = arrG[ig];
gY = arrG[ig + 1];
gZ = arrG[ig + 2];
const gridId = (((gX * sY) + gY) * sZ) + gZ;
if (tmpMapG.get(gridId)) continue; // already visited
tmpMapG.set(gridId, true);
distSqG = (gX - x) * (gX - x) + (gY - y) * (gY - y) + (gZ - z) * (gZ - z);
if (!findFurthest && distSqG > furthestDist && distSqG < nearestDist) continue;
// evaluate distances in the current grid point
const bucketIdx = grid[gridId];
if (bucketIdx !== 0) {
const ki = bucketIdx - 1;
const offset = bucketOffset[ki];
const count = bucketCounts[ki];
const end = offset + count;
for (let i = offset; i < end; i++) {
const idx = OrderedSet.getAt(indices, bucketArray[i]);
const dx = px[idx] - x;
const dy = py[idx] - y;
const dz = pz[idx] - z;
let distSq = dx * dx + dy * dy + dz * dz;
if (maxRadius !== 0) {
const r = radius![idx];
const sqR = r * r;
if (findFurthest && distSq > furthestDist) furthestDist = distSq + sqR;
distSq = distSq - sqR;
} else {
if (findFurthest && distSq > furthestDist) furthestDist = distSq;
}
if (distSq > prevNearestDist && distSq <= furthestDist) {
tmpHeapG.insert(distSq, idx);
nearestDist = tmpHeapG.findMinimum()!.key as unknown as number;
const arrGLen = gCount * 4;
for (let ig = 0; ig < arrGLen; ig += 4) {
gridId = arrG[ig + 3];
if (!tmpSetG.has(gridId)) {
tmpSetG.add(gridId);
gridPointsFinished = tmpSetG.size >= grid.length;
const bucketIdx = grid[gridId];
if (bucketIdx !== 0) {
const _maxRange = maxRange;
const ki = bucketIdx - 1;
const offset = bucketOffset[ki];
const count = bucketCounts[ki];
const end = offset + count;
for (let i = offset; i < end; i++) {
const bIdx = bucketArray[i];
const idx = OrderedSet.getAt(indices, bIdx);
const dx = px[idx] - x;
const dy = py[idx] - y;
const dz = pz[idx] - z;
let distSq = dx * dx + dy * dy + dz * dz;
if (maxRadius !== 0) {
const r = radius![idx];
distSq -= r * r;
}
if (expandRange && distSq > maxRange) {
maxRange = distSq;
}
tmpHeapG.insert(distSq, bIdx);
}
if (_maxRange < maxRange) expandRange = false;
}
if (prevFurthestDist < furthestDist) findFurthest = false;
}
}
// find next grid points
nextArrG.length = 0;
nextGCount = 0;
tmpSetG2.clear();
for (let ig = 0; ig < arrGLen; ig += 4) {
gX = arrG[ig];
gY = arrG[ig + 1];
gZ = arrG[ig + 2];
// fill grid points array with valid adiacent positions
for (let ix = -1; ix <= 1; ix++) {
for (let ix = -dX; ix <= dX; ix++) {
const xPos = gX + ix;
if (xPos < 0 || xPos >= sX) continue;
for (let iy = -1; iy <= 1; iy++) {
for (let iy = -dY; iy <= dY; iy++) {
const yPos = gY + iy;
if (yPos < 0 || yPos >= sY) continue;
for (let iz = -1; iz <= 1; iz++) {
for (let iz = -dZ; iz <= dZ; iz++) {
const zPos = gZ + iz;
if (zPos < 0 || zPos >= sZ) continue;
const gridId = (((xPos * sY) + yPos) * sZ) + zPos;
if (tmpMapG.get(gridId)) continue; // already visited
nextArrG.push(xPos, yPos, zPos);
gridId = (((xPos * sY) + yPos) * sZ) + zPos;
if (tmpSetG2.has(gridId)) continue; // already scanned
tmpSetG2.add(gridId);
if (tmpSetG.has(gridId)) continue; // already visited
if (!expandGrid) {
const xP = min[0] + xPos * delta[0] - x;
const yP = min[1] + yPos * delta[1] - y;
const zP = min[2] + zPos * delta[2] - z;
const distSqG = (xP * xP) + (yP * yP) + (zP * zP) - sqMaxRadius; // is sqMaxRadius necessary?
if (distSqG > maxRange) {
expandedArrG.push(xPos, yPos, zPos, gridId);
continue;
}
}
nextArrG.push(xPos, yPos, zPos, gridId);
nextGCount++;
}
}
}
}
expandGrid = false;
if (nextGCount === 0) {
while (!tmpHeapG.isEmpty() && result.count < k) {
while (!tmpHeapG.isEmpty() && (gridPointsFinished || tmpHeapG.findMinimum()!.key as unknown as number <= maxRange) && result.count < k) {
const node = tmpHeapG.extractMinimum();
if (!node) throw new Error('Cannot extract minimum, should not happen');
const { key: squaredDistance, value: index } = node;
const squaredDistance = node!.key, index = node!.value;
Result.add(result, index as number, squaredDistance as number);
if (stopIf && !stop) {
stop = stopIf(index, squaredDistance);
}
}
if (result.count >= k) return result.count > 0;
prevNearestDist = nearestDist;
if (furthestDist === nearestDist) {
findFurthest = true;
prevFurthestDist = furthestDist;
nearestDist = Number.MAX_VALUE;
} else {
nearestDist = furthestDist + EPSILON; // adding EPSILON fixes a bug
}
// resotre visibility of current gid points
for (let ig = 0; ig < arrGLen; ig += 3) {
gX = arrG[ig];
gY = arrG[ig + 1];
gZ = arrG[ig + 2];
const gridId = (((gX * sY) + gY) * sZ) + gZ;
tmpMapG.set(gridId, false);
if (result.count >= k || stop || result.count >= indicesCount) return result.count > 0;
expandGrid = true;
expandRange = true;
if (expandedArrG.length > 0) {
for (let i = 0, l = expandedArrG.length; i < l; i++) {
arrG.push(expandedArrG[i]);
}
expandedArrG.length = 0;
gCount = arrG.length;
}
} else {
const tmp = arrG;
arrG = nextArrG;
nextArrG = tmp;
nextArrG.length = 0;
gCount = nextGCount;
nextGCount = 0;
}
}
return result.count > 0;
......
......@@ -53,19 +53,16 @@ export function StructureLookup3DResultContext(): StructureLookup3DResultContext
return { result: StructureResult.create(), closeUnitsResult: Result.create(), unitGroupResult: Result.create() };
}
const tmpHeap = new FibonacciHeap();
export class StructureLookup3D {
private unitLookup: Lookup3D;
private pivot = Vec3();
private tmpHeap = new FibonacciHeap();
findUnitIndices(x: number, y: number, z: number, radius: number): Result<number> {
return this.unitLookup.find(x, y, z, radius);
}
nearestUnitIndices(x: number, y: number, z: number, k: number = 1): Result<number> {
return this.unitLookup.nearest(x, y, z, k);
}
private findContext = StructureLookup3DResultContext();
find(x: number, y: number, z: number, radius: number, ctx?: StructureLookup3DResultContext): StructureResult {
......@@ -99,11 +96,11 @@ export class StructureLookup3D {
_nearest(x: number, y: number, z: number, k: number, ctx: StructureLookup3DResultContext): StructureResult {
const result = ctx.result;
const heap = this.tmpHeap;
Result.reset(result);
this.tmpHeap.clear();
tmpHeap.clear();
const { units } = this.structure;
const closeUnits = this.unitLookup.nearest(x, y, z, units.length, ctx.closeUnitsResult); // sort all units based on distance to the point
let elementsCount = 0;
const closeUnits = this.unitLookup.nearest(x, y, z, units.length, (uid: number) => (elementsCount += units[uid].elements.length) >= k, ctx.closeUnitsResult); // sort units based on distance to the point
if (closeUnits.count === 0) return result;
let totalCount = 0, maxDistResult = -Number.MAX_VALUE;
for (let t = 0, _t = closeUnits.count; t < _t; t++) {
......@@ -115,16 +112,16 @@ export class StructureLookup3D {
Vec3.transformMat4(this.pivot, this.pivot, unit.conformation.operator.inverse);
}
const unitLookup = unit.lookup3d;
const groupResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2], k, ctx.unitGroupResult);
const groupResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2], k, void 0, ctx.unitGroupResult);
if (groupResult.count === 0) continue;
maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
totalCount += groupResult.count;
maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
for (let j = 0, _j = groupResult.count; j < _j; j++) {
heap.insert(groupResult.squaredDistances[j], { index: groupResult.indices[j], unit: unit });
tmpHeap.insert(groupResult.squaredDistances[j], { index: groupResult.indices[j], unit: unit });
}
}
while (!heap.isEmpty() && result.count < k) {
const node = heap.extractMinimum();
while (!tmpHeap.isEmpty() && result.count < k) {
const node = tmpHeap.extractMinimum();
if (!node) throw new Error('Cannot extract minimum, should not happen');
const { key: squaredDistance } = node;
const { unit, index } = node.value as { index: UnitIndex, unit: Unit };
......
/**
* Copyright (c) 2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
*
* @author Gianluca Tomasello <giagitom@gmail.com>
*/
import { FibonacciHeap } from '../fibonacci-heap';
describe('fibonacci-heap', () => {
it('basic', () => {
const heap = new FibonacciHeap();
heap.insert(1, 2);
heap.insert(4);
heap.insert(2);
heap.insert(3);
expect(heap.size()).toBe(4);
const node = heap.extractMinimum();
expect(node!.key).toBe(1);
expect(node!.value).toBe(2);
expect(heap.size()).toBe(3);
});
});
/**
* Copyright (c) 2018-2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
* Copyright (c) 2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
*
* @author Gianluca Tomasello <giagitom@gmail.com>
*
* Adapted from https://github.com/gwtw/ts-fibonacci-heap, Copyright (c) 2014 Daniel Imms, MIT
*/
type CompareFunction<K, V> = (a: INode<K, V>, b: INode<K, V>) => number;
interface INode<K, V> {
key: K;
value?: V;
}
type CompareFunction<K, V> = (a: INode<K, V>, b: INode<K, V>) => number;
class Node<K, V> implements INode<K, V> {
public key: K;
public value: V | undefined;
......@@ -35,28 +35,32 @@ class Node<K, V> implements INode<K, V> {
class NodeListIterator<K, V> {
private _index: number;
private _items: Node<K, V>[];
private _len: number;
/**
* Creates an Iterator used to simplify the consolidate() method. It works by
* making a shallow copy of the nodes in the root list and iterating over the
* shallow copy instead of the source as the source will be modified.
* @param start A node from the root list.
*/
constructor(start: Node<K, V>) {
constructor(start?: Node<K, V>) {
this._index = -1;
this._items = [];
let current = start;
do {
this._items.push(current);
current = current.next;
} while (start !== current);
this._len = 0;
if (start) {
let current = start, l = 0;
do {
this._items[l++] = current;
current = current.next;
} while (start !== current);
this._len = l;
}
}
/**
* @return Whether there is a next node in the iterator.
*/
public hasNext(): boolean {
return this._index < this._items.length - 1;
return this._index < this._len - 1;
}
/**
......@@ -65,8 +69,23 @@ class NodeListIterator<K, V> {
public next(): Node<K, V> {
return this._items[++this._index];
}
/**
* @return Resets iterator to reuse it.
*/
public reset(start: Node<K, V>) {
this._index = -1;
this._len = 0;
let current = start, l = 0;
do {
this._items[l++] = current;
current = current.next;
} while (start !== current);
this._len = l;
}
}
const tmpIt = new NodeListIterator<any, any>();
/**
* A Fibonacci heap data structure with a key and optional value.
*/
......@@ -277,9 +296,9 @@ export class FibonacciHeap<K, V> {
private _consolidate(minNode: Node<K, V>): Node<K, V> | null {
const aux = [];
const it = new NodeListIterator<K, V>(minNode);
while (it.hasNext()) {
let current = it.next();
tmpIt.reset(minNode);
while (tmpIt.hasNext()) {
let current = tmpIt.next();
// If there exists another node with the same degree, merge them
let auxCurrent = aux[current.degree];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment