Source code for cr.sparse._src.cluster.kmeans

# Copyright 2021 CR-Suite Development Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The implementation in this file is based on the example provided by
Sabrina J. Mielke.
"""

from typing import NamedTuple

from jax import lax, jit, vmap, random
import jax.numpy as jnp
from jax.numpy.linalg import norm

class KMeansState(NamedTuple):
    """Loop state carried from one k-means iteration to the next
    """
    centroids: jnp.ndarray
    """Centroids at the current iteration"""
    assignment: jnp.ndarray
    """Index of the centroid each point is currently assigned to"""
    distortion: float
    """Mean point-to-centroid distance for the current assignment"""
    prev_distortion: float
    """Mean point-to-centroid distance at the previous iteration"""
    iterations: int
    """Number of iterations performed so far"""
class KMeansSolution(NamedTuple):
    """Final result reported by the k-means algorithm
    """
    centroids: jnp.ndarray
    """Centroids of the reported run"""
    assignment: jnp.ndarray
    """Index of the centroid each point is assigned to"""
    distortion: float
    """Mean point-to-centroid distance of the reported run"""
    key: jnp.ndarray
    """PRNG key that seeded the run with the least distortion"""
    iterations: int
    """Number of iterations the reported run took to complete"""
def find_nearest(point, centroids):
    """Locates the centroid closest to a given point

    Args:
        point (jax.numpy.ndarray): The query point
        centroids (jax.numpy.ndarray): Candidate centroids, one per row

    Returns:
        (int): Index of the row of ``centroids`` whose norm-distance
        to ``point`` is smallest
    """
    # broadcasting subtracts the point from every centroid row
    offsets = centroids - point
    # one norm per row of offsets
    distances = vmap(norm)(offsets)
    return jnp.argmin(distances)
# JIT-compiled variant of find_nearest
find_nearest_jit = jit(find_nearest)
def find_assignment(points, centroids):
    """Assigns every point to its nearest centroid

    Args:
        points (jax.numpy.ndarray): The data points, one per row
        centroids (jax.numpy.ndarray): The centroids, one per row

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A pair made of

        #. the index of the nearest centroid for each point
        #. the distance from each point to that centroid
    """
    # map over the rows of points while sharing the same centroids
    nearest_for_all = vmap(find_nearest, in_axes=(0, None))
    assignment = nearest_for_all(points, centroids)
    # centroid selected for each point, aligned row by row with points
    chosen = centroids[assignment, :]
    distances = vmap(norm)(chosen - points)
    return assignment, distances
# JIT-compiled variant of find_assignment
find_assignment_jit = jit(find_assignment)


def assignment_counts(assignment, k):
    """Counts the points assigned to each of the ``k`` clusters

    A cluster without any points is reported with a count of 1
    instead of 0. For a 1-D ``assignment`` the result has shape
    ``(k, 1)``.
    """
    cluster_ids = jnp.arange(k)
    # membership[j, i] is True when point i is assigned to cluster j
    membership = cluster_ids[:, jnp.newaxis] == assignment[jnp.newaxis, :]
    counts = membership.sum(axis=1, keepdims=True)
    # never report fewer than one point per cluster
    return counts.clip(min=1)
def find_new_centroids(assignment, points, k):
    """Recomputes each centroid as the mean of the points assigned to it

    Args:
        assignment (jax.numpy.ndarray): Index of the cluster each point
            currently belongs to
        points (jax.numpy.ndarray): The data points, one per row
        k (int): The number of clusters

    Returns:
        (jax.numpy.ndarray): The updated centroids, one per row. A
        cluster with no assigned points gets an all-zero centroid.
    """
    # membership[i, j] is True when point i is assigned to cluster j
    membership = assignment[:, jnp.newaxis] == jnp.arange(k)[jnp.newaxis, :]
    # points per cluster as a column; empty clusters count as 1 so the
    # division below stays well defined
    counts = membership.sum(axis=0)[:, jnp.newaxis].clip(min=1)
    # axes: (data points, clusters, data dimension)
    contributions = jnp.where(
        membership[:, :, jnp.newaxis],
        points[:, jnp.newaxis, :],
        0.,
    )
    totals = jnp.sum(contributions, axis=0)
    return totals / counts
# JIT-compiled variant of find_new_centroids; positional argument 2 (k)
# is marked static
find_new_centroids_jit = jit(find_new_centroids, static_argnums=(2,))
def kmeans_with_seed(key, points, k, thresh=1e-5, max_iters=100):
    """Runs the k-means algorithm for a specific random initialization

    Args:
        key: a PRNG key used as the random key for choosing initial centroids
        points (jax.numpy.ndarray): Each row of the points matrix is a point.
        k (int): The number of clusters
        thresh (float): Convergence threshold on change in distortion
        max_iters (int): Maximum number of iterations for k-means algorithm

    Returns:
        (KMeansState): A named tuple consisting of:
        centroids for each cluster,
        assignment of each point to a cluster,
        current distortion,
        previous distortion,
        number of iterations for convergence.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than the number
            of points.
    """
    # number of points
    n = points.shape[0]
    # Fewer than k initial centroids would make the initial state and the
    # loop body disagree on the centroid shape, which surfaces as an opaque
    # carry-type error from lax.while_loop. Fail early with a clear message.
    if not 1 <= k <= n:
        raise ValueError(
            f"k must be between 1 and the number of points ({n}), got {k}")

    def init():
        # select k points as initial centroids randomly
        indices = random.permutation(key, jnp.arange(n))[:k]
        # the initial centroids
        centroids = points[indices, :]
        # assign all points to centroids and compute distances
        assignment, distances = find_assignment(points, centroids)
        distortion = jnp.mean(distances)
        # algorithm state; prev_distortion starts at infinity so that the
        # first convergence check always sees a gap larger than thresh
        return KMeansState(centroids=centroids, assignment=assignment,
            distortion=distortion, prev_distortion=jnp.inf, iterations=0)

    def body(state):
        # update centroids
        centroids = find_new_centroids(state.assignment, points, k)
        # update assignment
        assignment, distances = find_assignment(points, centroids)
        # mean distance
        distortion = jnp.mean(distances)
        # algorithm state
        return KMeansState(centroids=centroids, assignment=assignment,
            distortion=distortion, prev_distortion=state.distortion,
            iterations=state.iterations+1)

    def cond(state):
        # keep going while the mean distance still improves by more than
        # thresh and the iteration budget is not exhausted
        gap = state.prev_distortion - state.distortion
        return jnp.logical_and(gap > thresh, state.iterations < max_iters)

    # equivalent to: state = init(); while cond(state): state = body(state)
    state = lax.while_loop(cond, body, init())
    return state
# JIT-compiled variant of kmeans_with_seed; positional arguments 2 and 3
# (k, thresh) are marked static
kmeans_with_seed_jit = jit(kmeans_with_seed, static_argnums=(2,3))
def kmeans(key, points, k, iter=20, thresh=1e-5, max_iters=100):
    r"""Clusters points with k-means, keeping the best of several restarts

    Args:
        key: a PRNG key from which one key per restart is derived
        points (jax.numpy.ndarray): Each row of the points matrix is a point.
            From the statistical point of view, each row is an observation
            vector and each column is a feature.
        k (int): The number of clusters
        iter (int): The number of times k-means is restarted with a
            different seed. The run with the least distortion is returned.
        thresh (float): Convergence threshold on change in distortion
        max_iters (int): Maximum number of iterations for each restart
            of the k-means algorithm

    Returns:
        (KMeansSolution): A named tuple consisting of:

        * centroids : centroid for each cluster
        * assignment: assignment of each point to a cluster
        * distortion: distortion after current assignment
        * key: The PRNG key seed for the k-means run with the least distortion
        * iterations: number of iterations taken in convergence

    Let the k centroids be :math:`m_1, m_2, \dots, m_k`.
    Let the n points be :math:`x_1, x_2, \dots, x_n`.
    Let the assignment of i-th point to j-th cluster be given by
    :math:`a_1, a_2, \dots, a_n` where :math:`1 \leq a_i = j \leq k`.

    Then the distance of i-th point from its centroid is given by:

    .. math::

        d_i = \| x_i - m_{a_i} \|_2

    The distortion is given by the mean of all the distances.
    """
    # one PRNG key per restart
    restart_keys = random.split(key, iter)

    def single_run(restart_key):
        # one complete k-means run seeded by restart_key
        return kmeans_with_seed(restart_key, points, k,
            thresh=thresh, max_iters=max_iters)

    # run every restart in one batched call; each field of `runs` gains a
    # leading axis indexed by restart
    runs = vmap(single_run)(restart_keys)
    # restart that ended with the least distortion
    best = jnp.argmin(runs.distortion)
    return KMeansSolution(
        centroids=runs.centroids[best],
        assignment=runs.assignment[best],
        distortion=runs.distortion[best],
        key=restart_keys[best],
        iterations=runs.iterations[best],
    )
# JIT-compiled variant of kmeans; positional arguments 2 through 5
# (k, iter, thresh, max_iters) are marked static
kmeans_jit = jit(kmeans, static_argnums=(2,3,4,5))