Sep 2018 Rational Exuberance

k-nearest neighbors approximation with locality sensitive hashing in SQL

Nearest neighbors are difficult to calculate at scale. For one thing because of the quadratic runtime, but also because an exact solution has quadratic storage requirements even for a small k (or needs multiple passes over the data, making the runtime even worse). Neither approach is usually feasible in practice. There are quite a few different ways to approximate the exact solution, however. And since the available distance metric is often only an approximation anyway, approximating can be a good way to speed up the calculation.

One very interesting technique for an approximation via locality sensitive hashing is described in chapter 3 of MMDS[1]. It is interesting because it can correctly approximate the chosen distance metric without any assumptions about the distribution of the data, while often having only (near) linear storage and runtime requirements. It does require modeling the problem specifically for the chosen distance metric, however, and some careful tuning of a few parameters.

Unfortunately, while the approximation quality is less dependent on the data distribution, both the exact runtime and memory of this algorithm depend heavily on the data distribution (and on the required accuracy of the results). In practice both linear runtime and memory can be achieved, albeit with a potentially significant constant factor. The algorithm can be parallelized, though that is trivial for the exact solution as well.

The high-level idea of the algorithm is the following: choose a random but deterministic lossy representation of the data points. Only data points that are equal within this representation (i.e. hashed into the same bucket) are candidates for nearest neighbors. The more lossy the chosen representation is, the more candidates we'll find. To avoid a representation that is too lossy, we can choose multiple different ones and merge candidate pairs from different representations.

Simplified example: we want to find customers who bought the same products. We pick five random products, and all customers who bought these products are candidates for being nearest neighbors. To find more candidates, we pick another five random products. This will of course not find all similar customers, and neither does it guarantee that customers who bought these five products are really similar (according to our metric). Both errors can be significantly reduced though, by varying the number of random products we pick and by increasing the number of picks we make.
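
A rough SQL sketch of one such pick (purely illustrative: it uses the knn.user_products example table that is only created further below, and on a small random data set a single pick may well produce no candidates at all):

-- Pick 5 random products and treat every pair of users who bought all 5 of
-- them as a candidate pair. This corresponds to one "pick" in the text above.
WITH sample_products AS (
  SELECT product_id
  FROM (SELECT DISTINCT product_id FROM knn.user_products) all_products
  ORDER BY random()
  LIMIT 5
),
matching_users AS (
  SELECT up.user_id
  FROM knn.user_products up
  JOIN sample_products sp USING (product_id)
  GROUP BY up.user_id
  -- Keep only users who bought all 5 sampled products.
  HAVING count(DISTINCT up.product_id) = 5
)
SELECT a.user_id AS "user", b.user_id AS neighbor
FROM matching_users a
JOIN matching_users b ON a.user_id < b.user_id;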

This can be seen as a pre-processing step for an exact KNN algorithm that filters out a large number of candidate pairs which don't need to be considered by the exact algorithm. The parameters, i.e. how many products we pick for each representation and how many of these sample representations we use, need to be chosen carefully though. Increasing the number of products decreases the number of candidates, but can also filter out genuinely similar candidates. Increasing the number of samples generally improves the results, but also increases the runtime.

To see how the number of products and the number of samples influence the results, we can take a look at the following chart. The line shows the likelihood of two candidates being considered as nearest neighbors, given their actual similarity. The higher the actual similarity, the more likely they will also be considered nearest neighbors. Increasing the number of products means fewer candidates will be considered; increasing the number of samples means more candidates will be considered. Usually the curve is quite steep, almost like a step function. This allows us to filter out (almost) all candidate pairs with a similarity of less than 50%, while including most candidate pairs with a higher similarity. If most candidate pairs have a higher similarity anyway though, this will not help much.

[Interactive chart: probability of two items becoming candidates as a function of their actual similarity, adjustable via the number of products per sample and the number of samples.]
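
For the signature scheme described below, this curve has a simple closed form under the usual independence assumptions: with r signatures per sample (band) and b samples, a pair with actual similarity s becomes a candidate with probability 1 - (1 - s^r)^b. With the 8 signatures per band and 16 bands used further below, that works out to roughly 6% at a similarity of 0.5 and roughly 95% at 0.8. A minimal sketch to tabulate the curve directly in PostgreSQL:

-- Probability that a pair with similarity s ends up as a candidate,
-- for 16 samples (bands) of 8 signatures each: 1 - (1 - s^8)^16.
SELECT
  s :: NUMERIC(3, 2)               AS similarity,
  round(1 - (1 - s ^ 8) ^ 16, 3)   AS candidate_probability
FROM generate_series(0.1, 0.9, 0.1) AS s;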

For the implementation of the actual algorithm we first have to choose a distance metric. Then we have to calculate an array of signatures for each candidate. The probability that two candidates have the same signature at a given position of the array should equal their actual similarity.

For the example of customers who buy similar products, we use the Jaccard distance of their product sets as the distance metric. For the signature we calculate a hash of every product a user bought, and use the minimum of these hash values as the signature. This matches the Jaccard similarity: two users end up with the same signature exactly when the overall minimum hash falls on a product both of them bought, and the probability of that is the number of products they have in common divided by the number of distinct products across both lists, which is exactly the Jaccard similarity. Calculating the signatures in this way is better than choosing random products, because it means that we calculate a meaningful signature for every user and not just for those who bought the chosen products.
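
To make this concrete, here is a small self-check (just a sketch, reusing the md5-based hashing of the implementation below): the product sets {1, 2, 3} and {2, 3, 4} have a Jaccard similarity of 2/4 = 0.5, so across many different hash functions their minimum hash values should match roughly half the time.

-- Estimate the probability that the two sets get the same minimum hash value,
-- averaged over 1000 hash functions (parameterized by signature_key).
WITH keys AS (
  SELECT generate_series(1, 1000) AS signature_key
), set_a AS (
  SELECT unnest(ARRAY[1, 2, 3]) AS product_id
), set_b AS (
  SELECT unnest(ARRAY[2, 3, 4]) AS product_id
), minhashes AS (
  SELECT
    k.signature_key,
    (SELECT min(('x' || substring(md5((k.signature_key * product_id) :: TEXT), 1, 4)) :: BIT(32) :: INT)
       FROM set_a) AS minhash_a,
    (SELECT min(('x' || substring(md5((k.signature_key * product_id) :: TEXT), 1, 4)) :: BIT(32) :: INT)
       FROM set_b) AS minhash_b
  FROM keys k
)
-- The result should be close to the Jaccard similarity of 0.5.
SELECT avg((minhash_a = minhash_b) :: INT) AS match_rate
FROM minhashes;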

We can then compute multiple signatures by choosing a different hash function for each of them. To get the threshold shown in the graph above, we would have to calculate 128 signatures for each user, which we then split into 16 groups of 8 signatures each. If all 8 signatures within at least one of the 16 groups are equal for two users, we consider them a candidate pair for being nearest neighbors.

Now this can be implemented quite well directly in SQL. We first generate some example data with users and products they bought.

-- All example tables live in a separate schema.
CREATE SCHEMA IF NOT EXISTS knn;

CREATE TABLE knn.user_products (
  user_id    INT,
  product_id INT
);

-- Generate some random (user_id, product_id) purchases as example data.
INSERT INTO knn.user_products
  SELECT DISTINCT
    ((23 * random() * n) / 10) :: INT AS user_id,
    (random() * 20) :: INT            AS product_id
  FROM generate_series(1, 500) n;

Then we calculate the signatures for each user. For each signature we calculate a random but deterministic order of all products, and take the first product that the user bought from that list as the signature_value. The probability that two users have the same signature_value at the same signature_index is then exactly the Jaccard similarity of the sets of products both users bought.

CREATE TABLE knn.signature_parameters (
  signature_index SMALLINT,
  signature_key   SMALLINT
);

-- Parameters for the different hash functions.
INSERT INTO knn.signature_parameters VALUES
  -- We actually only calculate 6 signatures for each user here.
  (0, 7), (1, 11), (2, 13), (3, 17), (4, 19), (5, 23);

CREATE TABLE knn.user_signature (
  user_id         INT,
  signature_band  SMALLINT,
  signature_index SMALLINT,
  signature_value INTEGER
);

INSERT INTO knn.user_signature
  SELECT
    up.user_id,
    -- And use only 3 bands with 2 signatures each.
    (sp.signature_index / 2) :: INT AS signature_band,
    sp.signature_index,
    -- Hash each product with the signature key (first 16 bits of the md5 as an integer)
    -- and keep the minimum hash value as the signature.
    min(('x' || substring(md5((sp.signature_key * product_id) :: TEXT), 1, 4)) :: BIT(32) :: INT) AS signature_value
  FROM knn.user_products up
  CROSS JOIN knn.signature_parameters sp
  GROUP BY up.user_id, sp.signature_index;
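
As a quick sanity check (just a sketch), every user should now have one signature per signature_index, i.e. 6 signature rows spread over 3 bands:

SELECT
  user_id,
  count(*)                        AS num_signatures,
  count(DISTINCT signature_band)  AS num_bands
FROM knn.user_signature
GROUP BY user_id
ORDER BY user_id
LIMIT 5;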

And in the final step we group users with the same signature values for a set of signature indexes (the signature_band) into buckets.

WITH matching_bands AS (
  SELECT
    base.user_id AS "user",
    other.user_id AS neighbor,
    signature_band
  FROM knn.user_signature base
  -- Pair up all users with equal signature values at a specific index.
  JOIN knn.user_signature other USING (signature_band, signature_index, signature_value)
  WHERE base.user_id != other.user_id
  GROUP BY base.user_id, other.user_id, signature_band
  -- But only keep bands where all signatures match.
  HAVING count(DISTINCT signature_index) = 2
),
num_bands AS (
  SELECT
    "user",
    neighbor,
    count(DISTINCT signature_band) AS num_matching_bands
  FROM matching_bands
  GROUP BY "user", neighbor
)
SELECT
  "user",
  neighbor,
  rank
FROM (
  SELECT
    "user",
    neighbor,
    num_matching_bands,
    row_number() OVER (PARTITION BY "user" ORDER BY num_matching_bands DESC) AS rank
  FROM num_bands
) u
WHERE rank <= 2 -- get top 2 nearest neighbors.
ORDER BY "user", rank ASC, neighbor

References:

[1] Jure Leskovec, Anand Rajaraman, Jeffrey D. Ullman: Mining of Massive Datasets, chapter 3 (Finding Similar Items). http://www.mmds.org/