cr.sparse.cluster.vq.kmeans_jit

cr.sparse.cluster.vq.kmeans_jit(key, points, k, iter=20, thresh=1e-05, max_iters=100)

Clusters points using the k-means algorithm.

Parameters
  • key – A JAX PRNG key used to generate the random seeds for the k-means restarts

  • points (jax.numpy.ndarray) – Each row of the points matrix is a point. From the statistical point of view, each row is an observation vector and each column is a feature.

  • k (int) – The number of clusters

  • iter (int) – The number of times k-means is restarted with different seeds. The result with the least distortion is returned.

  • thresh (float) – Convergence threshold on the change in distortion between iterations

  • max_iters (int) – Maximum number of iterations for each replicate of the k-means algorithm
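The parameters above describe a multi-restart version of Lloyd's k-means iteration. The following is a minimal JAX sketch under stated assumptions, not the cr.sparse implementation. The names `kmeans_sketch`, `kmeans_sketch_jit`, `SketchSolution`, `assign_points` and `update_centroids` are hypothetical. Initializing each replicate with k distinct randomly chosen points, and letting an empty cluster keep its previous centroid, are also assumptions. `iters` stands in for `iter` to avoid shadowing Python's builtin. Each replicate runs until the distortion changes by no more than `thresh` or `max_iters` is reached, and the replicate with the lowest distortion is returned.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class SketchSolution(NamedTuple):
    # Hypothetical container with the same field names as the documented result
    centroids: jnp.ndarray
    assignment: jnp.ndarray
    distortion: jnp.ndarray
    key: jnp.ndarray
    iterations: jnp.ndarray


def assign_points(points, centroids):
    # Squared Euclidean distances from every point to every centroid: shape (n, k)
    d2 = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    assignment = jnp.argmin(d2, axis=1)
    # Distortion = mean Euclidean distance of each point to its own centroid
    own = jnp.take_along_axis(d2, assignment[:, None], axis=1)[:, 0]
    return assignment, jnp.mean(jnp.sqrt(own))


def update_centroids(points, assignment, centroids, k):
    # Each centroid becomes the mean of its points (assumption: empty clusters keep the old one)
    one_hot = jax.nn.one_hot(assignment, k, dtype=points.dtype)  # (n, k)
    counts = one_hot.sum(axis=0)                                  # (k,)
    means = (one_hot.T @ points) / jnp.maximum(counts, 1.0)[:, None]
    return jnp.where(counts[:, None] > 0, means, centroids)


def kmeans_sketch(key, points, k, iters=20, thresh=1e-5, max_iters=100):
    n = points.shape[0]

    def one_run(run_key):
        # Assumed initialization: k distinct points picked at random
        idx = jax.random.choice(run_key, n, shape=(k,), replace=False)
        centroids = points[idx]
        assignment, distortion = assign_points(points, centroids)

        def cond(state):
            _, _, prev, cur, it = state
            # Continue while the distortion still changes and the budget is not exhausted
            return (it < max_iters) & (jnp.abs(prev - cur) > thresh)

        def body(state):
            centroids, assignment, _, cur, it = state
            centroids = update_centroids(points, assignment, centroids, k)
            assignment, new = assign_points(points, centroids)
            return centroids, assignment, cur, new, it + 1

        init = (centroids, assignment, jnp.asarray(jnp.inf, points.dtype),
                distortion, jnp.asarray(0, jnp.int32))
        centroids, assignment, _, distortion, it = jax.lax.while_loop(cond, body, init)
        return SketchSolution(centroids, assignment, distortion, run_key, it)

    # Run `iters` independent replicates, each with its own key, then keep the best one
    runs = jax.vmap(one_run)(jax.random.split(key, iters))
    best = jnp.argmin(runs.distortion)
    return jax.tree_util.tree_map(lambda x: x[best], runs)


# k and iters fix array shapes, so they must be static under jit
kmeans_sketch_jit = jax.jit(kmeans_sketch, static_argnames=("k", "iters"))
```

Running the replicates with `jax.vmap` instead of a Python loop keeps the whole search inside one compiled function; `thresh` and `max_iters` can remain traced values because they are only used inside the loop condition.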

Returns

A named tuple consisting of:

  • centroids: the centroid of each cluster

  • assignment: the cluster assigned to each point

  • distortion: the distortion of the returned assignment

  • key: the PRNG key (seed) for the k-means run with the least distortion

  • iterations: the number of iterations taken to converge

Return type

(KMeansSolution)

Let the k centroids be \(m_1, m_2, \dots, m_k\) and the n points be \(x_1, x_2, \dots, x_n\). Let the cluster assignments be \(a_1, a_2, \dots, a_n\), where \(a_i = j\) with \(1 \leq j \leq k\) means that the i-th point is assigned to the j-th cluster.
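In standard k-means, each assignment \(a_i\) is the index of the centroid closest to \(x_i\). A minimal JAX sketch of that assignment step, assuming Euclidean distance; the helper name `nearest_centroid` is hypothetical, and indices are 0-based in code rather than 1-based as in the text.

```python
import jax.numpy as jnp


def nearest_centroid(points, centroids):
    # points: (n, d), rows are x_i; centroids: (k, d), rows are m_j
    # Pairwise Euclidean distances ||x_i - m_j||_2, shape (n, k)
    dists = jnp.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=-1)
    # a_i = index of the closest centroid (0-based)
    return jnp.argmin(dists, axis=1)


x = jnp.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [4.8, 5.3]])
m = jnp.array([[0.0, 0.0], [5.0, 5.0]])
print(nearest_centroid(x, m))  # [0 0 1 1]
```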

Then the distance of the i-th point from its centroid is:

\[d_i = \| x_i - m_{a_i} \|_2 \tag{1}\]

The distortion is the mean of these distances over all n points:

\[\frac{1}{n} \sum_{i=1}^{n} d_i\]
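As a worked check of the definition, the per-point distances \(d_i\) and their mean can be computed directly from given centroids and assignments. This is an illustrative sketch; the function name `distortion` is hypothetical and assignments are 0-based here.

```python
import jax.numpy as jnp


def distortion(points, centroids, assignment):
    # d_i = ||x_i - m_{a_i}||_2 for every point
    d = jnp.linalg.norm(points - centroids[assignment], axis=1)
    # The distortion is the mean of the d_i
    return jnp.mean(d)


x = jnp.array([[0.0, 0.0], [3.0, 4.0], [10.0, 0.0]])
m = jnp.array([[0.0, 0.0], [10.0, 0.0]])
a = jnp.array([0, 0, 1])
# d = [0, 5, 0], so the distortion is 5 / 3
print(distortion(x, m, a))
```

Note that this is the mean of plain Euclidean distances, not of squared distances, matching the definition above.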