k-nearest neighbors approximation with locality sensitive hashing in SQL
Nearest neighbors are difficult to calculate at scale. For one because of the quadratic runtime, but also because an exact solution has quadratic storage requirements even for a small k (or needs multiple passes over the data, making the runtime even worse). Both approaches are often not feasible in practice. There are quite a few different ways to approximate an exact solution, however. And since the available distance metric is often only an approximation anyway, this can be a good way to speed up the calculation.
One very interesting technique for such an approximation, locality sensitive hashing, is described in chapter 3 of Mining of Massive Datasets (MMDS) [1]. It is interesting because it can correctly approximate the chosen distance metric without any assumptions about the distribution of the data, while often needing only (near) linear storage and runtime. This does, however, require modeling the problem specifically for the chosen distance metric, as well as careful tuning of a few parameters.
Unfortunately, while the approximation quality depends less on the data distribution, both the exact runtime and the memory usage of this algorithm depend heavily on it (and on the required accuracy of the results). In practice, linear runtime and memory can be achieved, albeit with a possibly non-negligible constant factor. The algorithm can be parallelized, though the same is trivially true of the exact solution.
The high-level idea of the algorithm is the following: choose a random but deterministic lossy representation of the data points. Only data points that are equal within this representation (i.e. hashed into the same bucket) are candidates for nearest neighbors. The more lossy the chosen representation is, the more candidates we'll find. To avoid a representation that is too lossy without missing too many true neighbors, we can choose multiple less lossy representations and merge the candidate pairs found by each of them.
Simplified example: we want to find customers who bought the same products. We pick five random products, and all customers who bought these products are candidates for being nearest neighbors. To find more candidates, we pick another five random products. This will of course not find all similar customers, and neither does it guarantee that customers who bought these five products are really similar (according to our metric). Both kinds of error can be significantly reduced, though, by varying the number of random products we pick and increasing the number of picks we make.
This can be seen as a pre-processing step for an exact KNN algorithm that filters out a large number of candidate pairs which the exact algorithm then doesn't need to consider. The parameters, i.e. how many products we pick for each representation and how many of these sampled representations we create, need to be chosen carefully, though. Increasing the number of products decreases the number of candidates, but can also filter out pairs that are actually similar. Increasing the number of samples generally improves the results, but also increases the runtime.
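Because this is only a pre-processing step, the surviving candidate pairs still have to be scored exactly. The following is a minimal Python sketch of that second step, assuming each user's purchases are available as a set of product ids and the candidates as (user, user) pairs. The function names are hypothetical, and Jaccard similarity is used because it is the metric chosen for the customer example.

```python
from typing import Dict, List, Set, Tuple


def jaccard(a: Set[int], b: Set[int]) -> float:
    """Exact Jaccard similarity |A & B| / |A | B| (0.0 if both sets are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def rerank_candidates(
    products: Dict[int, Set[int]],
    candidates: Set[Tuple[int, int]],
    k: int,
) -> Dict[int, List[Tuple[int, float]]]:
    """Score only the candidate pairs exactly and keep the k most similar neighbors per user."""
    # Normalize to unordered pairs so (u, v) and (v, u) are scored only once.
    unordered = {(min(u, v), max(u, v)) for u, v in candidates if u != v}
    scored: Dict[int, List[Tuple[int, float]]] = {}
    for u, v in unordered:
        sim = jaccard(products[u], products[v])
        scored.setdefault(u, []).append((v, sim))
        scored.setdefault(v, []).append((u, sim))
    # Highest similarity first; ties are broken by neighbor id for a stable result.
    return {
        user: sorted(neighbors, key=lambda p: (-p[1], p[0]))[:k]
        for user, neighbors in scored.items()
    }


# Hypothetical example data: user id -> set of purchased product ids.
products = {1: {1, 2, 3}, 2: {1, 2, 4}, 3: {7, 8}, 4: {1, 2, 3, 4}}
candidates = {(1, 2), (1, 4), (4, 2)}
neighbors = rerank_candidates(products, candidates, k=1)
```

Only the candidate pairs are ever compared, so the cost of this step grows with the number of candidates rather than quadratically with the number of users. User 3 is not part of any candidate pair and therefore gets no entry at all.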
To see how the number of products and the number of samples influence the results, we can take a look at the following chart. The line shows the likelihood of two candidates being considered nearest neighbors, given their actual similarity. The higher the actual similarity, the more likely they are to be considered nearest neighbors. Increasing the number of products means fewer candidates will be considered; increasing the number of samples means more candidates will be considered. Usually the curve is quite steep, almost like a step function. This allows us to filter out (almost) all candidate pairs with a similarity of less than 50%, while keeping most candidate pairs whose similarity is well above that. If most candidate pairs have a high similarity, though, this will not help much.
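The shape of that curve can be computed directly. Assuming each signature position matches with probability equal to the pair's similarity s, independently across positions, one group of r signatures matches completely with probability s^r, and a pair becomes a candidate if at least one of b groups matches, i.e. with probability 1 - (1 - s^r)^b. This is the banding analysis from chapter 3 of MMDS; here r plays the role of the "number of products" and b the "number of samples". The function names below are illustrative only.

```python
def candidate_probability(similarity: float, rows: int, bands: int) -> float:
    """Probability that a pair with the given similarity becomes a candidate.

    One band of r signatures matches with probability s**r, so the pair is
    missed by all b bands with probability (1 - s**r)**b.
    """
    return 1.0 - (1.0 - similarity ** rows) ** bands


def approximate_threshold(rows: int, bands: int) -> float:
    """Common rule of thumb for the similarity where the curve is steepest: (1/b)**(1/r)."""
    return (1.0 / bands) ** (1.0 / rows)


# 128 signatures split into 16 bands of 8 rows each.
for s in (0.3, 0.5, 0.7, 0.9):
    print(f"similarity {s:.1f}: candidate probability {candidate_probability(s, 8, 16):.3f}")
print(f"approximate threshold: {approximate_threshold(8, 16):.3f}")
```

With 16 bands of 8 rows, the steepest part of the curve lies near a similarity of about 0.71, a pair with similarity 0.5 becomes a candidate only about 6% of the time, and a pair with similarity 0.9 almost always does.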
For the implementation of the actual algorithm, we first have to choose a distance metric. Then we have to calculate an array of signatures for each candidate. The probability that two candidates have the same signature at one position of the array should equal their actual similarity (i.e. one minus their distance).
For the example of customers who buy similar products, we use the Jaccard distance of their product sets as the distance metric. For the signature, we calculate a hash of each product a user bought and use the minimum of these hash values as the signature. This captures the Jaccard similarity (one minus the Jaccard distance): among all distinct products bought by either of the two users, each one is equally likely to have the overall minimum hash value, and (ignoring hash collisions) the two signatures are equal exactly when that product was bought by both users. The probability of equal signatures is therefore the number of products both users bought, divided by the number of distinct products in both lists. Calculating the signatures in this way is better than choosing random products, because it yields a meaningful signature for every user, not just for those who bought the chosen products.
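A small Python sketch can check this argument empirically. It assumes MD5 from Python's hashlib as a stand-in hash family: each hash function is distinguished by a hypothetical integer seed prefixed to the product id, and only the first 32 bits of the digest are used. The fraction of signature positions on which two users agree then estimates their Jaccard similarity.

```python
import hashlib
from typing import List, Set


def product_hash(product_id: int, seed: int) -> int:
    """One member of a family of hash functions: MD5 of "seed-product_id", first 32 bits."""
    digest = hashlib.md5(f"{seed}-{product_id}".encode()).hexdigest()
    return int(digest[:8], 16)


def minhash_signature(products: Set[int], num_hashes: int) -> List[int]:
    """For every hash function, keep the minimum hash over the (non-empty) product set."""
    return [min(product_hash(p, seed) for p in products) for seed in range(num_hashes)]


def estimated_similarity(a: Set[int], b: Set[int], num_hashes: int = 512) -> float:
    """Fraction of signature positions where both users have the same minimum hash."""
    sig_a = minhash_signature(a, num_hashes)
    sig_b = minhash_signature(b, num_hashes)
    return sum(x == y for x, y in zip(sig_a, sig_b)) / num_hashes


user_a = set(range(0, 30))   # products 0..29
user_b = set(range(10, 40))  # products 10..39: 20 shared out of 40 distinct
estimate = estimated_similarity(user_a, user_b)
print(f"exact Jaccard similarity: {20 / 40:.2f}, MinHash estimate: {estimate:.2f}")
```

The estimate gets closer to the exact value as the number of hash functions grows; its standard error is roughly sqrt(s(1 - s) / num_hashes).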
We can then compute multiple signatures by choosing a different hash function for each of them. To get the threshold shown in the graph above, we would have to calculate 128 signatures for each user, which we then split into 16 groups (bands) of 8 signatures each. If all 8 signatures within at least one of the 16 groups are equal for two users, we consider them candidates for being nearest neighbors.
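The banding rule itself can be sketched in a few lines of Python. This is an illustration under the assumption that all users' signature arrays fit in memory as lists of equal length; each band of a user's signature becomes a bucket key, and every pair of users sharing a bucket is a candidate. The function name is hypothetical.

```python
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Set, Tuple


def candidate_pairs(signatures: Dict[int, List[int]], rows_per_band: int) -> Set[Tuple[int, int]]:
    """Return user pairs whose signatures agree on every row of at least one band."""
    buckets: Dict[Tuple[int, Tuple[int, ...]], List[int]] = defaultdict(list)
    for user, signature in signatures.items():
        if len(signature) % rows_per_band != 0:
            raise ValueError("signature length must be a multiple of rows_per_band")
        for band in range(len(signature) // rows_per_band):
            rows = tuple(signature[band * rows_per_band:(band + 1) * rows_per_band])
            # The band index is part of the bucket key, so equal values in
            # different bands never put two users into the same bucket.
            buckets[(band, rows)].append(user)
    pairs: Set[Tuple[int, int]] = set()
    for users in buckets.values():
        # Every pair in a bucket is a candidate; the set removes pairs found in several bands.
        pairs.update(combinations(sorted(users), 2))
    return pairs


# 6 signature values per user, split into 3 bands of 2 rows.
signatures = {
    1: [5, 9, 2, 7, 4, 4],
    2: [5, 9, 3, 8, 1, 1],
    3: [6, 9, 2, 8, 4, 4],
    4: [0, 0, 0, 0, 0, 0],
}
print(sorted(candidate_pairs(signatures, rows_per_band=2)))  # [(1, 2), (1, 3)]
```

Users 2 and 3 agree on two individual positions, but never on a complete band, so they are not a candidate pair. With rows_per_band=1 they would be.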
Now this can be implemented quite well directly in SQL. We first generate some example data with users and products they bought.
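CREATE SCHEMA IF NOT EXISTS knn;
CREATE TABLE knn.user_products (
user_id INT,
product_id INT
);
-- Random example data with product ids 0-20; DISTINCT removes duplicate (user, product) rows.
INSERT INTO knn.user_products
SELECT DISTINCT
((23 * random() * n) / 10) :: INT AS user_id,
(random() * 20) :: INT AS product_id
FROM generate_series(1, 500) n;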
Then we calculate the signatures for each user. For each signature we calculate a random but deterministic order of all products, and take the first product that the user bought from that list as the signature_value. The probability that two users have the same signature_value at the same signature_index then equals (up to hash collisions) the Jaccard similarity of the sets of products both users bought.
CREATE TABLE knn.signature_parameters (
signature_index SMALLINT,
signature_key SMALLINT
);
-- Parameters for the different hash functions.
INSERT INTO knn.signature_parameters VALUES
-- We actually only calculate 6 signatures for each user here.
(0, 7), (1, 11), (2, 13), (3, 17), (4, 19), (5, 23);
CREATE TABLE knn.user_signature (
user_id INT,
signature_band SMALLINT,
signature_index SMALLINT,
signature_value INTEGER
);
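INSERT INTO knn.user_signature
SELECT
up.user_id,
-- And use only 3 bands with 2 signatures each.
(sp.signature_index / 2) :: INT AS signature_band,
sp.signature_index,
-- MinHash: hash every product with this signature's key and keep the minimum.
-- The key is concatenated instead of multiplied, so that product_id 0 does not
-- get the same hash in every signature; 8 hex digits fill all 32 bits.
min(('x' || substring(md5(sp.signature_key :: TEXT || '-' || up.product_id :: TEXT), 1, 8)) :: BIT(32) :: INT) AS signature_value
FROM knn.user_products up
CROSS JOIN knn.signature_parameters sp
GROUP BY up.user_id, sp.signature_index;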
And in the final step we group users with the same signature values for a set of signature indexes (the signature_band) into buckets.
WITH matching_bands AS (
SELECT
base.user_id AS "user",
other.user_id AS neighbor,
signature_band
FROM knn.user_signature base
-- Cross join all users with equal signature values at a specific index.
JOIN knn.user_signature other USING (signature_band, signature_index, signature_value)
WHERE base.user_id != other.user_id
GROUP BY base.user_id, other.user_id, signature_band
-- But only keep bands where all signatures match.
HAVING count(DISTINCT signature_index) = 2
),
num_bands AS (
SELECT
"user",
neighbor,
count(DISTINCT signature_band) AS num_matching_bands
FROM matching_bands
GROUP BY "user", neighbor
)
SELECT
"user",
neighbor,
rank
FROM (
SELECT
"user",
neighbor,
num_matching_bands,
row_number() OVER (PARTITION BY "user" ORDER BY num_matching_bands DESC, neighbor) AS rank
FROM num_bands
) u
WHERE rank <= 2 -- get top 2 nearest neighbors.
ORDER BY "user", rank ASC, neighbor;
Some additional notes:
- One of the main challenges when implementing the algorithm is figuring out how to calculate the signatures for the chosen distance metric.
- The runtime and memory requirements of this algorithm depend heavily on how many of the candidate pairs that don't need to be considered as nearest neighbors can be filtered out. If the distances between all candidates are very similar and/or very skewed across certain clusters, this filtering may not work well. In the worst case, if the threshold isn't configured properly, the pre-processing does basically nothing while still consuming a considerable amount of resources. Creating a histogram of the distances between a sample of candidate pairs can help to find proper parameters.
- The algorithm generally does not find the same number of nearest neighbors for each candidate. In particular, it may not find any nearest neighbors at all for some candidates (unless the runtime is increased considerably). For many use cases this is not much of a problem, and it can either be ignored or be solved, e.g. by using some even simpler heuristic for those candidates. If the requirement is strictly to find neighbors for all candidates, however, this approximation is probably not the best choice.
- There are quite a few other KNN approximations that are generally easier to model and more predictable in their runtime. They often make more assumptions about the data, though, and don't necessarily achieve linear storage and/or runtime.
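The suggestion to look at a sampled histogram of distances can be sketched in Python as follows. This assumes the purchases of all users are available as sets and simply samples random user pairs; it bins Jaccard similarity (the Jaccard distance is one minus that value). The function name and the bin layout are hypothetical choices.

```python
import random
from collections import Counter
from typing import Dict, Set


def sampled_similarity_histogram(
    products: Dict[int, Set[int]], num_samples: int, num_bins: int = 10, seed: int = 0
) -> Counter:
    """Histogram of Jaccard similarities over randomly sampled user pairs.

    Bin i counts pairs with similarity in [i / num_bins, (i + 1) / num_bins);
    a similarity of exactly 1.0 is put into the last bin.
    """
    rng = random.Random(seed)
    users = sorted(products)
    histogram: Counter = Counter()
    for _ in range(num_samples):
        u, v = rng.sample(users, 2)  # two distinct users
        a, b = products[u], products[v]
        union = a | b
        similarity = len(a & b) / len(union) if union else 0.0
        histogram[min(int(similarity * num_bins), num_bins - 1)] += 1
    return histogram


# Hypothetical example data: user id -> set of purchased product ids.
products = {1: {1, 2}, 2: {1, 2}, 3: {3}, 4: {4}, 5: {1, 5}}
histogram = sampled_similarity_histogram(products, num_samples=1000)
for bin_index in sorted(histogram):
    print(f"similarity >= {bin_index / 10:.1f}: {histogram[bin_index]} sampled pairs")
```

If most sampled pairs fall into bins well below the intended threshold, the pre-processing can filter out a lot; if they cluster around the threshold, it will not help much.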
References: