Distributed k-means clustering in SQL
Clustering data points into groups is a pretty common use case. With an increasing number of dimensions this quickly becomes infeasible to do manually. A first step to compute a clustering may be the relatively simple k-means algorithm (which technically is a heuristic for the k-means problem)[1]. While it does not always deliver good results, it often works quite well.
The interactive demo below gives an idea of how the algorithm works. As one of the caveats, the number of clusters k needs to be specified as a parameter for the algorithm. The algorithm then follows these steps:
1. Pick k random points as cluster centers.
2. Assign each point to the cluster of its nearest cluster center.
3. Update the cluster centers to be the average of all points now assigned to the cluster.
Steps 2 and 3 are repeated until the cluster assignments no longer change.
[Interactive demo: for each cluster, the demo reports mean x, variance x, mean y, variance y, and the number of points.]
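Before moving to SQL, the steps can be sketched in a few lines of plain Python. This is only an illustration under assumptions: the function name `kmeans`, the `seed` and `max_iterations` parameters, and the choice to keep the old center when a cluster becomes empty are not part of the original description.

```python
import random


def squared_distance(p, q):
    """Squared Euclidean distance between two 2D points."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2


def kmeans(points, k, max_iterations=100, seed=0):
    """Return (centers, assignments) for a list of (x, y) tuples."""
    rng = random.Random(seed)
    # Step 1: pick k random points as the initial cluster centers.
    centers = rng.sample(points, k)
    assignments = None
    for _ in range(max_iterations):
        # Step 2: assign each point to the index of its nearest center.
        new_assignments = [
            min(range(k), key=lambda c: squared_distance(p, centers[c]))
            for p in points
        ]
        if new_assignments == assignments:
            break  # assignments are stable, so the centers are too
        assignments = new_assignments
        # Step 3: move each center to the mean of its assigned points.
        for c in range(k):
            members = [p for p, a in zip(points, assignments) if a == c]
            if members:  # assumption: an empty cluster keeps its old center
                centers[c] = (
                    sum(x for x, _ in members) / len(members),
                    sum(y for _, y in members) / len(members),
                )
    return centers, assignments
```

For two well-separated groups of points and k = 2, the returned assignments give each group its own label.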
While the loop around steps 2 and 3 of the algorithm is naturally a bit problematic for an implementation in SQL, each of the steps is very simple to implement. And in a distributed database the computations should parallelize pretty much automatically if the data is distributed in a way that allows most computations to be local.
Assuming there is a large number of data points, we make sure the table for them is evenly distributed in some deterministic way.
CREATE TABLE points (
point_id INTEGER,
x1 FLOAT,
x2 FLOAT
)
DISTRIBUTED BY HASH(point_id);
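To get a feeling for what hash distribution means, here is a rough Python sketch of how a row could be mapped to a node. The number of nodes, the helper name `node_for`, and the use of CRC32 as the hash are assumptions for illustration; an actual database uses its own hash function and placement logic.

```python
import zlib
from collections import Counter

NUM_NODES = 4  # hypothetical cluster size


def node_for(point_id, num_nodes=NUM_NODES):
    """Deterministically map a point_id to a node index."""
    digest = zlib.crc32(str(point_id).encode("ascii"))
    return digest % num_nodes


# Count how many of 10,000 consecutive ids would land on each node.
rows_per_node = Counter(node_for(point_id) for point_id in range(10_000))
```

The same id always maps to the same node, which is what lets a join on point_id between two tables distributed this way stay local.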
We can then initialize the cluster centers by picking k random points from the original table as the cluster centers of the first iteration (step 1 of the algorithm). Assuming that the number of clusters is relatively small, this table can be replicated to all nodes at relatively low cost.
CREATE TABLE cluster_centers (
iteration INTEGER,
cluster_id INTEGER,
x1 FLOAT,
x2 FLOAT
)
DISTRIBUTED ON ALL NODES;
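INSERT INTO cluster_centers
SELECT
0 AS iteration,
row_number() OVER () AS cluster_id, -- numbers the sampled rows 1..k
x1,
x2
FROM (
-- Sample first, so that row_number() only sees the k chosen rows.
SELECT x1, x2
FROM points
ORDER BY random()
LIMIT 3 -- k = 3
) sampled_points;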
For step 2 of the algorithm, we compute the distance from each point to each cluster center and pick the cluster with the smallest distance. Since the cluster_centers table is distributed to all nodes, this should be an entirely local operation.
CREATE TABLE clusters (
point_id INTEGER,
cluster_id INTEGER
)
DISTRIBUTED BY HASH(point_id);
-- When iterating, remove the previous assignments first (e.g. DELETE FROM clusters).
INSERT INTO clusters
SELECT
point_id,
cluster_id
FROM (
SELECT
p.point_id,
c.cluster_id,
row_number() OVER (PARTITION BY p.point_id ORDER BY sqrt((p.x1 - c.x1)^2 + (p.x2 - c.x2)^2) ASC) AS rank
FROM points p
CROSS JOIN cluster_centers c
WHERE c.iteration = (SELECT max(iteration) FROM cluster_centers)
) rankings
WHERE rank = 1;
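As a small worked example of the assignment step, the following Python sketch pairs every point with every center and keeps the nearest one per point. The function name `assign_clusters` and the sample values are made up for illustration.

```python
import math


def assign_clusters(points, centers):
    """Map each point_id to the cluster_id of its nearest center.

    points:  {point_id: (x1, x2)}
    centers: {cluster_id: (x1, x2)}
    """
    return {
        point_id: min(centers, key=lambda cid: math.dist(p, centers[cid]))
        for point_id, p in points.items()
    }


points = {1: (0.0, 0.0), 2: (4.0, 4.0), 3: (1.0, 2.0)}
centers = {1: (0.0, 1.0), 2: (5.0, 5.0)}
assignments = assign_clusters(points, centers)  # {1: 1, 2: 2, 3: 1}
```

Since the square root is monotonic, ranking by the squared distance would pick the same center; the root only matters if the actual distance value is needed.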
Updating the cluster centers in step 3 of the algorithm is now a simple average aggregation grouped by the previously calculated clusters. This operation is not entirely local, but should still be relatively cheap: a good query planner would first locally compute weighted averages, and only transfer those results via the network to compute the actual averages. If that is not the case, the query can simply be broken into two parts with an intermediate table for the weighted averages.
INSERT INTO cluster_centers
SELECT
i.last_iteration + 1 AS iteration,
c.cluster_id,
avg(p.x1) AS x1,
avg(p.x2) AS x2
FROM points p
JOIN clusters c USING (point_id)
CROSS JOIN (SELECT max(iteration) AS last_iteration FROM cluster_centers) i
GROUP BY i.last_iteration, c.cluster_id;
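The two-phase aggregation a planner would ideally perform can be sketched in Python as follows. Each node reduces its rows to per-cluster sums and counts, and only these small partial results are merged. The function names and the row layout `(cluster_id, x1, x2)` are assumptions for illustration.

```python
from collections import defaultdict


def local_partials(rows):
    """Per-node step: reduce (cluster_id, x1, x2) rows to sums and counts."""
    partials = defaultdict(lambda: [0.0, 0.0, 0])
    for cluster_id, x1, x2 in rows:
        acc = partials[cluster_id]
        acc[0] += x1
        acc[1] += x2
        acc[2] += 1
    return dict(partials)


def merge_partials(partials_per_node):
    """Global step: combine the partial results and compute the averages."""
    totals = defaultdict(lambda: [0.0, 0.0, 0])
    for partials in partials_per_node:
        for cluster_id, (sum_x1, sum_x2, count) in partials.items():
            acc = totals[cluster_id]
            acc[0] += sum_x1
            acc[1] += sum_x2
            acc[2] += count
    return {cid: (s1 / n, s2 / n) for cid, (s1, s2, n) in totals.items()}


node_a = [(1, 0.0, 0.0), (1, 2.0, 2.0), (2, 10.0, 10.0)]
node_b = [(1, 4.0, 1.0), (2, 12.0, 14.0)]
new_centers = merge_partials([local_partials(node_a), local_partials(node_b)])
```

Only one small record per cluster and node crosses the network, independent of the number of points.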
And that was one iteration of k-means clustering in SQL. Note that we store the cluster centers of each iteration, so that we can determine if the algorithm converged. Technically the algorithm already converges when the cluster assignments don't change in step 2, but here it is easier to check if the cluster centers did not change.
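A convergence check along these lines could look as follows in Python. The data layout, the function name `centers_changed`, and the small tolerance for floating-point noise are assumptions; the stored centers of the last two iterations are simply compared.

```python
import math


def centers_changed(centers_by_iteration, tolerance=1e-9):
    """Return True if any center moved between the last two iterations.

    centers_by_iteration: {iteration: {cluster_id: (x1, x2)}}
    """
    if len(centers_by_iteration) < 2:
        return True  # only the initial centers exist, so keep iterating
    last = max(centers_by_iteration)
    current = centers_by_iteration[last]
    previous = centers_by_iteration[last - 1]
    return any(
        cid not in previous or math.dist(current[cid], previous[cid]) > tolerance
        for cid in current
    )
```

Returning True while only iteration 0 exists ensures that the loop body runs at least once.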
Now while it's possible to implement the loop around steps 2 and 3 with a recursive query in some databases, it is probably easier to wrap them in a few lines of iterative code:
db.connect()  # Make sure not every run call has to wait for a new connection.
db.run(CREATE_TABLES)
db.run(INIT_CENTERS)
while db.run(CLUSTER_CENTERS_CHANGED).fetch():
    db.run(ASSIGN_CLUSTERS)
    db.run(RECOMPUTE_CLUSTER_CENTERS)
db.close_connection()
A working example of this can be found on GitHub[2]. While it's often not a good idea to run a lot of queries from a code loop, this can be reasonable here, assuming that there won't be too many iterations and that the computations in each iteration will be comparatively slow.
Some additional notes:
- While the algorithm is simple in SQL, this is also true for most actual programming languages. The implementation here is not especially efficient, not least because it needs an additional pass over the data to update the cluster centers in each iteration. Still it can be a nice quickstart in case your data happens to be in the database.
- The queries above of course become more annoying if the data points have more than two dimensions. Depending on the database it may be best to auto-generate the queries, use arrays, or switch the table design to (point_id, dimension_id, value) tuples.
- How to determine the number of clusters is an important question in practice, which the k-means algorithm doesn't really answer.
- Since the algorithm is a heuristic, it does not always find the optimal solution. One way to improve the heuristic is to spend a bit more effort on choosing the initial cluster centers, instead of picking them randomly.
- An interesting extension for the k-means algorithm is to add knowledge via constraints, which define for some data point combinations if they should be within the same cluster or not [3].
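To illustrate the (point_id, dimension_id, value) design from the notes above, here is a rough Python sketch of the distance computation in that layout. It assumes the cluster centers are stored in the same long format as (cluster_id, dimension_id, value); the function name is made up. In SQL this would correspond to a join on dimension_id followed by a sum grouped by point and cluster.

```python
from collections import defaultdict


def squared_distances(point_rows, center_rows):
    """Sum the squared per-dimension differences for each (point, cluster) pair.

    point_rows:  iterable of (point_id, dimension_id, value)
    center_rows: iterable of (cluster_id, dimension_id, value)
    """
    totals = defaultdict(float)
    for point_id, p_dim, p_value in point_rows:
        for cluster_id, c_dim, c_value in center_rows:
            if p_dim == c_dim:  # the join condition on dimension_id
                totals[(point_id, cluster_id)] += (p_value - c_value) ** 2
    return dict(totals)


point_rows = [(1, 1, 0.0), (1, 2, 0.0), (1, 3, 0.0)]
center_rows = [(1, 1, 1.0), (1, 2, 2.0), (1, 3, 2.0),
               (2, 1, 0.0), (2, 2, 0.0), (2, 3, 1.0)]
distances = squared_distances(point_rows, center_rows)
```

The code stays the same for any number of dimensions, at the price of one row per point and dimension.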
References: