π Guest Post: Introduction to DiskANN and the Vamana Algorithm*
In this tutorial, Frank Liu, Solutions Architect at Zilliz, will deep dive into DiskANN, a graph-based indexing strategy, their first foray into on-disk indexes. Like HNSW, DiskANN avoids the problem of figuring out how and where to partition a high-dimensional input space and instead relies on building a directed graph to the relationship between nearby vectors. As the volume of unstructured data continues to grow in the upcoming decade, the need for on-disk indexes will likely rise, as will research around this area. Letβs explore!
Approximate Nearest Neighbors Oh Yeah (Annoy for short) is a tree-based indexing algorithm that uses random projections to iteratively split the hyperspace of vectors, with the final split resulting in a binary tree. Annoy uses two tricks to improve accuracy/recall - 1) traversing down both halves of a split if the query point lies close to the dividing hyperplane, and 2) creating a forest of binary trees. Although Annoy isn't commonly used as an indexing algorithm in production environments today (`HNSW` and `IVF_PQ` are far more popular), Annoy still sets a strong baseline for tree-based vector indexes.
At its core, Annoy is still an in-memory index. In previous articles, weβve only looked at in-memory indexes - vector indexes that reside entirely in RAM. On commodity machines, in-memory indexes are excellent for smaller datasets (up to around 10 million 1024-dimensional vectors). Still, once we move past 100M vectors, in-memory indexes can be prohibitively expensive. For example, 100M vectors alone will require approximately 400GB of RAM.
Here's where an on-disk index - a vector index that utilizes both RAM and hard disk - would be helpful. In this tutorial, we'll dive into DiskANN, a graph-based vector index that enables large-scale storage, indexing, and search of vectors by persisting the bulk of the index on NVMe hard disks. We'll first cover Vamana, the core data structure behind DiskANN, before discussing how the on-disk portion of DiskANN utilizes a Vamana graph to perform queries efficiently. Like previous tutorials, we'll also develop our implementation of the Vamana algorithm in Python.
The Vamana algorithm
Vamana's key concept is the relative neighborhood graph (RNG). Formally, edges in an RNG for a single point are constructed iteratively so long as a new edge is not closer to any existing neighbor. If this is difficult to wrap your head around, no worries - the key concept is that RNGs are constructed so that only a subset of the most relevant nearest edges are added for any single point in the graph. As with HNSW, nearby vectors are determined by the distance metric that's being used in the vector database, e.g., cosine or L2.
There are two main problems with RNGs that make them still too inefficient for vector search. The first is that constructing an RNG is prohibitively expensive:
The second is that setting the diameter of an RNG is difficult. High-diameter RNGs are too dense, while RNGs with low diameters make graph traversal (querying the index) lengthy and inefficient. Despite this, RNGs remain a good starting point and form the basis for the Vamana algorithm.
In broad terms, the Vamana algorithm solves both of these problems by making use of two clever heuristics: the greedy search procedure and the robust prune procedure. Let's walk through both of these, along with an implementation, to see how these work together to create an optimized graph for vector search.
As the name implies, the greedy search algorithm iteratively searches for the closest neighbors to a specified point (vector) in the graph `p`. Loosely speaking, we maintain two sets: a set of nearest neighbors `nns` and a set of visited nodes `visit`.
def _greedy_search(graph, start, query, nq: int, L: int):
Β Β Β Β best = (np.linalg.norm(graph[start][0] - query), entry)
Β Β Β Β nns = [start]
Β Β Β Β visit = set()Β # set of visited nodes
Β Β Β Β nns = heapify(nns)
Β Β Β Β # find top-k nearest neighbors
Β Β Β Β while nns - visit:
Β Β Β Β Β Β Β Β nn = nns[0]
Β Β Β Β Β Β Β Β for idx in nn[1]:
Β Β Β Β Β Β Β Β Β Β Β Β d = np.linalg.norm(graph[idx][0] - query)
Β Β Β Β Β Β Β Β Β Β Β Β heappush(nns, (d, nn))
Β Β Β Β Β Β Β Β Β Β Β Β visit.add((d, nn))
Β Β Β Β Β Β Β Β # retain up to search list size elements
Β Β Β Β Β Β Β Β while len(nns) > L:
Β Β Β Β Β Β Β Β Β Β Β Β heappop(nns)
Β Β Β Β return (nns[:nq], visit)
`nns` is initialized with the starting node, and at each iteration, we take up to `L` steps in the direction closest to our query point. This continues until all nodes in `nns` have been visited.
Robust prune, on the other hand, is a bit more involved. This heuristic is designed to ensure that the distance between consecutive searched nodes in the greedy search procedure decreases exponentially. Formally, robust prune, when called on a single node, will ensure that the outbound edges are modified such that there are at most `R` edges, with a new edge pointing to a node at least `a` times distant from any existing neighbor.
def _robust_prune(graph, node: Tuple[np.ndarray, Set[int]], candid: Set[int], R: int):
Β Β Β Β candid.update(node[1])
Β Β Β Β node[1].clear()
Β Β Β Β while candid:
Β Β Β Β Β Β Β Β (min_d, nn) = (float("inf"), None)
Β Β Β Β Β Β Β Β # find the closest element/vector to input node
Β Β Β Β Β Β Β Β for k in candid:
Β Β Β Β Β Β Β Β Β Β Β Β p = graph[k]
Β Β Β Β Β Β Β Β Β Β Β Β d = np.linalg.norm(node[0] - p[0])
Β Β Β Β Β Β Β Β Β Β Β Β if d < min_d:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β (min_d, nn) = (d, p)
Β Β Β Β Β Β Β Β node[1].add(nn)
Β Β Β Β Β Β Β Β # set at most R out-neighbors for the selected node
Β Β Β Β Β Β Β Β if len(node[1]) == R:
Β Β Β Β Β Β Β Β Β Β Β Β break
Β Β Β Β Β Β Β Β # future iterations must obey distance threshold
Β Β Β Β Β Β Β Β for p in candid:
Β Β Β Β Β Β Β Β Β Β Β Β if a * min_d <= np.linalg.norm(node[0] - p[0]):
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β candid.remove(p)
With these two heuristics, we can now focus on the full Vamana algorithm. A Vamana graph is first initialized so each node has `R` random outbound edges. The algorithm then iteratively calls `_greedy_search` and `_robust_prune` for all nodes within the graph.
As we've done for all of our previous tutorials on vector indexes, let's now put it all together into a single script:
class VamanaIndex(_BaseIndex):
Β Β Β Β """Vamana graph algorithm implementation. Every element in each graph is a
Β Β Β Β 2-tuple containing the vector and a list of unidirectional connections
Β Β Β Β within the graph.
Β Β Β Β """
Β Β Β Β def __init__(self, L: int = 10, a: float = 1.5, R: int = 10):
Β Β Β Β Β Β Β Β super().__init__()
Β Β Β Β Β Β Β Β self._L = L
Β Β Β Β Β Β Β Β self._a = a
Β Β Β Β Β Β Β Β self._R = R
Β Β Β Β Β Β Β Β self._start = NoneΒ # index of starting vector
Β Β Β Β Β Β Β Β self._index = []
Β Β Β Β def create(self, dataset):
Β Β Β Β Β Β Β Β self._R = min(self._R, len(dataset))
Β Β Β Β Β Β Β Β # intialize graph with dataset
Β Β Β Β Β Β Β Β # set starting location as medoid vector
Β Β Β Β Β Β Β Β dist = float("inf")
Β Β Β Β Β Β Β Β medoid = np.median(dataset, axis=0)
Β Β Β Β Β Β Β Β for (n, vec) in enumerate(dataset):
Β Β Β Β Β Β Β Β Β Β Β Β d = np.linalg.norm(medoid - vec)
Β Β Β Β Β Β Β Β Β Β Β Β if d < dist:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β dist = d
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β self._start = n
Β Β Β Β Β Β Β Β Β Β Β Β self._index.append((vec, set()))
Β Β Β Β Β Β Β Β # randomize out-connections for each node
Β Β Β Β Β Β Β Β for (n, node) in enumerate(self._index):
Β Β Β Β Β Β Β Β Β Β Β Β idxs = np.random.choice(len(self._index) - 1, replace=False, size=(self._R,))
Β Β Β Β Β Β Β Β Β Β Β Β idxs[idxs >= n] += 1Β # ensure no node points to itself
Β Β Β Β Β Β Β Β Β Β Β Β node[1].update(idxs)
Β Β Β Β Β Β Β Β # random permutation + sequential graph update
Β Β Β Β Β Β Β Β for (n, node) in enumerate(self._index):
Β Β Β Β Β Β Β Β Β Β Β Β (_, V) = self.search(node, nq=1)
Β Β Β Β Β Β Β Β Β Β Β Β self._robust_prune(node, V)
Β Β Β Β Β Β Β Β Β Β Β Β for inb in node[1]:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β nbr = self._index[inb]
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β if len(nbrs[1]) > self._R:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β self._robust_prune(nbr, nbr[1].union(n))
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β else:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β nbr[1].add(n)
Β Β Β Β def insert(self, vector):
Β Β Β Β Β Β Β Β raise NotImplementedError("Vamana indexes are static")
Β Β Β Β def search(query, nq: int = 10):
Β Β Β Β Β Β Β Β """Greedy search.
Β Β Β Β Β Β Β Β """
Β Β Β Β Β Β Β Β best = (np.linalg.norm(self._index[self._start][0] - query), entry)
Β Β Β Β Β Β Β Β nns = [start]
Β Β Β Β Β Β Β Β visit = set()Β # set of visited nodes
Β Β Β Β Β Β Β Β nns = heapify(nns)
Β Β Β Β Β Β Β Β # find top-k nearest neighbors
Β Β Β Β Β Β Β Β while nns - visit:
Β Β Β Β Β Β Β Β Β Β Β Β nn = nns[0]
Β Β Β Β Β Β Β Β Β Β Β Β for idx in nn[1]:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β d = np.linalg.norm(self._index[idx][0] - query)
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β heappush(nns, (d, nn))
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β visit.add((d, nn))
Β Β Β Β Β Β Β Β Β Β Β Β # retain up to search list size elements
Β Β Β Β Β Β Β Β Β Β Β Β while len(nns) > self._L:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β heappop(nns)
Β Β Β Β Β Β Β Β return (nns[:nq], visit)
Β Β Β Β def _robust_prune(node: Tuple[np.ndarray, Set[int]], candid: Set[int]):
Β Β Β Β Β Β Β Β candid.update(node[1])
Β Β Β Β Β Β Β Β node[1].clear()
Β Β Β Β Β Β Β Β while candid:
Β Β Β Β Β Β Β Β Β Β Β Β (min_d, nn) = (float("inf"), None)
Β Β Β Β Β Β Β Β Β Β Β Β # find the closest element/vector to input node
Β Β Β Β Β Β Β Β Β Β Β Β for k in candid:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β p = self._index[k]
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β d = np.linalg.norm(node[0] - p[0])
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β if d < min_d:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β (min_d, nn) = (d, p)
Β Β Β Β Β Β Β Β Β Β Β Β node[1].add(nn)
Β Β Β Β Β Β Β Β Β Β Β Β # set at most R out-neighbors for the selected node
Β Β Β Β Β Β Β Β Β Β Β Β if len(node[1]) == R:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β break
Β Β Β Β Β Β Β Β Β Β Β Β # future iterations must obey distance threshold
Β Β Β Β Β Β Β Β Β Β Β Β for p in candid:
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β if a * min_d <= np.linalg.norm(node[0] - p[0]):
Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β Β candid.remove(p)
That's it for Vamana!
All code for this tutorial is freely available at https://github.com/fzliu/vector-search.