Mar 2019 Rational Exuberance

Distributed k-means clustering in SQL

Clustering data points into groups is a pretty common use case. With an increasing number of dimensions this quickly becomes infeasible to do manually. A good first step for computing a clustering is the relatively simple k-means algorithm (which, technically, is a heuristic for the k-means problem)[1]. While it does not always deliver good results, it often works quite well.

The interactive demo below gives an idea of how the algorithm works. One caveat is that the number of clusters k must be specified as a parameter of the algorithm. The algorithm then follows these steps:

  1. Pick k random points as cluster centers.
  2. Assign each point to the cluster of its nearest cluster center.
  3. Update the cluster centers to be the average of all points now assigned to the cluster.
Steps 2 and 3 are repeated until the algorithm converges, i.e. the clusters do not change anymore.
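As a reference point before moving to SQL, here is a minimal in-memory sketch of these three steps in Python (Lloyd's algorithm on 2-D points). The function name kmeans, the fixed default seed, and keeping the old center when a cluster ends up empty are illustrative choices, not part of the original text. Like the description above, it stops once the assignments no longer change.

```python
import math
import random


def kmeans(points, k, seed=0, max_iterations=100):
    """Plain in-memory k-means (Lloyd's algorithm) on 2-D points.

    Returns (centers, labels) where labels[i] is the cluster index of points[i].
    """
    rng = random.Random(seed)
    # Step 1: pick k distinct data points as the initial cluster centers.
    centers = rng.sample(points, k)
    labels = None
    for _ in range(max_iterations):
        # Step 2: assign each point to the cluster of its nearest center.
        new_labels = [
            min(range(k), key=lambda j: math.dist(p, centers[j]))
            for p in points
        ]
        if new_labels == labels:
            break  # assignments unchanged: the algorithm has converged
        labels = new_labels
        # Step 3: move each center to the mean of the points assigned to it.
        for j in range(k):
            members = [p for p, label in zip(points, labels) if label == j]
            if members:  # keep the old center if its cluster became empty
                centers[j] = tuple(sum(c) / len(members) for c in zip(*members))
    return centers, labels


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
centers, labels = kmeans(points, k=2)
print(sorted(centers))  # [(0.0, 0.5), (10.0, 10.5)]
```

On less clearly separated data, different seeds can lead to different (and sometimes worse) final clusterings, which is the behavior the demo below lets you explore.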

Interactive simulation of the k-means algorithm. You can change the parameters of the normally distributed actual clusters (mean and variance in x and y, number of points) and then see if the algorithm can find them. Reset the clustering to try different initial clusters, and you will probably also see cases where the algorithm does not find the optimal solution.

While the loop around steps 2 and 3 of the algorithm is naturally a bit problematic for an implementation in SQL, each of the steps is very simple to implement. And in a distributed database the computations should parallelize pretty much automatically if the data is distributed in a way that allows most computations to be local.

Assuming there is a large number of data points, we make sure the table for them is evenly distributed in some deterministic way.

CREATE TABLE points (
  point_id INTEGER,
  x1       FLOAT,
  x2       FLOAT
);
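How the distribution is declared depends on the database. As a hypothetical illustration of "even and deterministic", the sketch below maps each point_id to one of n nodes by taking it modulo the node count. This works well for roughly consecutive ids; for arbitrary keys, hashing the key first would be the usual choice. The helper names are made up for this example.

```python
def node_for(point_id: int, n_nodes: int) -> int:
    # Deterministic: the same point_id always lands on the same node,
    # and consecutive ids are spread round-robin across the nodes.
    return point_id % n_nodes


def distribute(point_ids, n_nodes):
    """Group point ids into one partition (list) per node."""
    partitions = [[] for _ in range(n_nodes)]
    for point_id in point_ids:
        partitions[node_for(point_id, n_nodes)].append(point_id)
    return partitions


partitions = distribute(range(1, 1001), n_nodes=4)
print([len(p) for p in partitions])  # [250, 250, 250, 250]
```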

We can then initialize the cluster centers by picking k random points from the original table as cluster centers of the first iteration (step 1 of the algorithm). Assuming that the number of clusters is relatively small, this table can be replicated to all nodes at low cost.

CREATE TABLE cluster_centers (
  iteration  INTEGER,
  cluster_id INTEGER,
  x1         FLOAT,
  x2         FLOAT
);

INSERT INTO cluster_centers
  SELECT
    0 AS iteration,
    row_number() OVER () AS cluster_id, -- numbered after sampling, so ids are 1..k
    x1,
    x2
  FROM (
    SELECT x1, x2
    FROM points
    ORDER BY random()
    LIMIT 3 -- k = 3
  ) initial_points;
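Step 1 can be sketched in plain Python with the same row layout, assuming points are (point_id, x1, x2) tuples. Here random.Random.sample plays the role of ORDER BY random() LIMIT k, and ids are assigned only to the sampled rows so they run from 1 to k. initial_centers is a hypothetical helper name.

```python
import random


def initial_centers(points, k, seed=None):
    """Step 1: pick k random points as the cluster centers of iteration 0.

    points: list of (point_id, x1, x2) rows.
    Returns rows shaped like cluster_centers: (iteration, cluster_id, x1, x2).
    """
    rng = random.Random(seed)
    sample = rng.sample(points, k)  # k distinct rows, like ORDER BY random() LIMIT k
    return [
        (0, cluster_id, x1, x2)
        for cluster_id, (_, x1, x2) in enumerate(sample, start=1)
    ]


pts = [(i, float(i), float(2 * i)) for i in range(1, 11)]
rows = initial_centers(pts, k=3, seed=42)
```

Numbering only the sampled rows also avoids computing a window function over the whole points table.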

For step 2 of the algorithm, we compute distances for each point to each cluster center and pick the cluster with the smallest distance. Since the cluster_centers table is distributed to all nodes, this should be an entirely local operation.

CREATE TABLE clusters (
  point_id   INTEGER,
  cluster_id INTEGER
);

INSERT INTO clusters
  SELECT point_id, cluster_id
  FROM (
    SELECT
      p.point_id,
      c.cluster_id,
      row_number() OVER (PARTITION BY p.point_id ORDER BY sqrt((p.x1 - c.x1)^2 + (p.x2 - c.x2)^2) ASC) AS rank
    FROM points p
    CROSS JOIN cluster_centers c
    WHERE c.iteration = (SELECT max(iteration) FROM cluster_centers)
  ) rankings
  WHERE rank = 1;
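The assignment query can be mirrored in Python to make its logic explicit: pair every point with every current center, rank the centers by distance per point, and keep rank 1. This is an illustrative sketch with hypothetical names. Note that the distance has to be the square root of the sum of the squared coordinate differences; in the example below, a product of the squared differences would be zero for center 1 (because the x difference is zero) and pick the wrong cluster. Since the square root is monotonic, ranking by the squared distance alone would give the same result.

```python
import math


def assign_clusters(points, centers):
    """Step 2: map each point_id to the cluster_id of its nearest center.

    points: (point_id, x1, x2) rows; centers: (cluster_id, x1, x2) rows of the
    latest iteration. Mirrors CROSS JOIN + row_number() ... WHERE rank = 1.
    """
    clusters = {}
    for point_id, px1, px2 in points:
        # Rank every center by Euclidean distance (sqrt of the SUM of squares)
        # and keep the first; sorted() is stable, so ties go to list order.
        ranked = sorted(centers, key=lambda c: math.hypot(px1 - c[1], px2 - c[2]))
        clusters[point_id] = ranked[0][0]
    return clusters


centers = [(1, 0.0, 0.0), (2, 3.0, 3.0)]
# (0, 5) is 5.0 away from center 1 and about 3.6 away from center 2.
print(assign_clusters([(7, 0.0, 5.0)], centers))  # {7: 2}
```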

Updating the cluster centers in step 3 of the algorithm is now a simple average aggregation grouped by the previously calculated clusters. This operation is not entirely local, but should still be relatively cheap: a good query planner would first compute partial averages (together with their point counts) locally on each node, and only transfer those small results over the network to compute the actual averages as count-weighted averages. If that is not the case, the query can simply be broken into two parts with an intermediate table for the partial results.

INSERT INTO cluster_centers
  SELECT
    i.last_iteration + 1 AS iteration,
    c.cluster_id,
    avg(p.x1) AS x1,
    avg(p.x2) AS x2
  FROM points p
  JOIN clusters c USING (point_id)
  CROSS JOIN (SELECT max(iteration) AS last_iteration FROM cluster_centers) i
  GROUP BY i.last_iteration, c.cluster_id;
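The two-phase aggregation described above can be simulated in plain Python, with one list of (cluster_id, x1, x2) rows per node (points already joined with their cluster assignment). Each node reduces its rows to per-cluster sums and counts, which is equivalent to partial averages weighted by their counts, so only one small tuple per cluster and node has to cross the network. The function names and row layout are assumptions for illustration.

```python
from collections import defaultdict


def local_partials(partition):
    """Phase 1, on each node: per-cluster [sum_x1, sum_x2, count]."""
    partials = defaultdict(lambda: [0.0, 0.0, 0])
    for cluster_id, x1, x2 in partition:
        acc = partials[cluster_id]
        acc[0] += x1
        acc[1] += x2
        acc[2] += 1
    return dict(partials)


def combine(all_partials):
    """Phase 2, after the network transfer: merge partials into averages."""
    totals = defaultdict(lambda: [0.0, 0.0, 0])
    for partials in all_partials:
        for cluster_id, (s1, s2, n) in partials.items():
            t = totals[cluster_id]
            t[0] += s1
            t[1] += s2
            t[2] += n
    return {cid: (s1 / n, s2 / n) for cid, (s1, s2, n) in totals.items()}


node_a = [(1, 0.0, 0.0), (1, 2.0, 2.0), (2, 10.0, 10.0)]
node_b = [(1, 4.0, 4.0), (2, 12.0, 14.0)]
print(combine([local_partials(node_a), local_partials(node_b)]))
# {1: (2.0, 2.0), 2: (11.0, 12.0)}
```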

And that was one iteration of k-means clustering in SQL. Note that we store the cluster centers of each iteration, so that we can determine whether the algorithm has converged. Technically the algorithm has already converged when the cluster assignments don't change in step 2, but here it is easier to check whether the cluster centers changed.
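Checking convergence on the stored centers only needs the last two iterations. A minimal sketch over (iteration, cluster_id, x1, x2) rows follows; the small default tolerance is an assumption, motivated by the fact that floating-point sums computed in parallel, in varying order, are not necessarily reproducible bit for bit.

```python
def converged(center_rows, tolerance=1e-9):
    """Return True if the last two iterations in center_rows have the same centers.

    center_rows: (iteration, cluster_id, x1, x2) rows, as in cluster_centers.
    """
    last = max(row[0] for row in center_rows)
    if last == 0:
        return False  # only the initial centers exist so far
    previous, current = {}, {}
    for iteration, cluster_id, x1, x2 in center_rows:
        if iteration == last - 1:
            previous[cluster_id] = (x1, x2)
        elif iteration == last:
            current[cluster_id] = (x1, x2)
    if previous.keys() != current.keys():
        return False  # a cluster appeared or disappeared
    return all(
        abs(previous[cid][0] - current[cid][0]) <= tolerance
        and abs(previous[cid][1] - current[cid][1]) <= tolerance
        for cid in current
    )


rows = [(0, 1, 0.0, 0.0), (0, 2, 10.0, 10.0),
        (1, 1, 0.0, 0.5), (1, 2, 10.0, 10.5)]
print(converged(rows))  # False
rows += [(2, 1, 0.0, 0.5), (2, 2, 10.0, 10.5)]
print(converged(rows))  # True
```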

Now, while it's possible to implement the loop around steps 2 and 3 with a recursive query in some databases, it is probably easier to wrap a few lines of iterative code around them:

db.connect() # Make sure not every run call has to wait for a new connection.
A working example of this can be found on GitHub[2]. While it's often not a good idea to run a lot of queries from a code loop, it can be reasonable here, assuming that there won't be too many iterations and that the computations in each iteration take long compared to the overhead of issuing the queries.
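For a self-contained illustration of such a loop, here is a minimal sketch using Python's built-in sqlite3 module as a stand-in database (so nothing is actually distributed), assuming SQLite 3.25 or newer for window functions. It is not the GitHub example; run_kmeans and the query constants are hypothetical names. The queries are adapted to SQLite: ranking by squared distance instead of sqrt and ^, and an explicit DELETE FROM clusters before each assignment so that only the current assignment is kept, which is an assumption about how the loop resets its state.

```python
import sqlite3

# Step 2: rank centers per point by squared distance (same order as Euclidean).
ASSIGN = """
INSERT INTO clusters
  SELECT point_id, cluster_id
  FROM (
    SELECT
      p.point_id AS point_id,
      c.cluster_id AS cluster_id,
      row_number() OVER (
        PARTITION BY p.point_id
        ORDER BY (p.x1 - c.x1) * (p.x1 - c.x1) + (p.x2 - c.x2) * (p.x2 - c.x2)
      ) AS rnk
    FROM points p
    CROSS JOIN cluster_centers c
    WHERE c.iteration = (SELECT max(iteration) FROM cluster_centers)
  ) rankings
  WHERE rnk = 1
"""

# Step 3: new centers as averages, stored as the next iteration.
UPDATE = """
INSERT INTO cluster_centers
  SELECT i.last_iteration + 1, c.cluster_id, avg(p.x1), avg(p.x2)
  FROM points p
  JOIN clusters c USING (point_id)
  CROSS JOIN (SELECT max(iteration) AS last_iteration FROM cluster_centers) i
  GROUP BY i.last_iteration, c.cluster_id
"""

# Number of latest-iteration centers that are new or moved since the previous one.
CHANGED = """
SELECT count(*)
FROM cluster_centers cur
LEFT JOIN cluster_centers prev
  ON prev.cluster_id = cur.cluster_id AND prev.iteration = cur.iteration - 1
WHERE cur.iteration = (SELECT max(iteration) FROM cluster_centers)
  AND (prev.cluster_id IS NULL
       OR abs(prev.x1 - cur.x1) > 1e-9
       OR abs(prev.x2 - cur.x2) > 1e-9)
"""


def run_kmeans(conn, k, max_iterations=50):
    """Run steps 2 and 3 until the centers stop changing; return the round count."""
    conn.execute("DELETE FROM cluster_centers")
    # Step 1: k random points, numbered after sampling so ids are 1..k.
    conn.execute(
        "INSERT INTO cluster_centers "
        "SELECT 0, row_number() OVER (), x1, x2 "
        "FROM (SELECT x1, x2 FROM points ORDER BY random() LIMIT ?)",
        (k,),
    )
    for iteration in range(1, max_iterations + 1):
        conn.execute("DELETE FROM clusters")  # keep only the current assignment
        conn.execute(ASSIGN)
        conn.execute(UPDATE)
        (changed,) = conn.execute(CHANGED).fetchone()
        if changed == 0:
            break
    conn.commit()
    return iteration


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE points (point_id INTEGER, x1 FLOAT, x2 FLOAT);
CREATE TABLE cluster_centers (iteration INTEGER, cluster_id INTEGER, x1 FLOAT, x2 FLOAT);
CREATE TABLE clusters (point_id INTEGER, cluster_id INTEGER);
""")
conn.executemany(
    "INSERT INTO points VALUES (?, ?, ?)",
    [(1, 0.0, 0.0), (2, 0.0, 1.0), (3, 10.0, 10.0), (4, 10.0, 11.0)],
)
rounds = run_kmeans(conn, k=2)
```

Keeping the SQL in string constants and the loop in a few lines of Python mirrors the split described above: the database does the heavy lifting, and the code only decides whether to run another round.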

Some additional notes: