.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/2000_cluster/kmeans.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_2000_cluster_kmeans.py:


K-means Clustering
============================

CR-Sparse includes a K-means implementation as part of its sparse subspace
clustering module. The functions used below come from
``cr.sparse.cluster.vq`` (vector quantization).

Configure JAX to work with 64-bit floating point precision.

.. code-block:: default


    import jax

    # Enable 64-bit floats (``jax.config`` replaces the old ``jax.config.config`` import)
    jax.config.update("jax_enable_x64", True)


Let's import the necessary libraries.

.. code-block:: default


    from jax import random
    import jax.numpy as jnp
    import cr.sparse as crs
    # Vector quantization (K-means lives here)
    import cr.sparse.cluster.vq as vq
    # Plotting
    import matplotlib.pyplot as plt
    # Some PRNG keys for later use
    key = random.PRNGKey(0)
    keys = random.split(key, 10)


Prepare sample data: two Gaussian clusters of 50 points each in 2-D.

.. code-block:: default


    # Number of points for each cluster
    pts = 50
    # Mean vector for the first cluster
    mu_a = jnp.array([0., 0.])
    # Covariance matrix for the first cluster
    cov_a = jnp.array([[4., 1.], [1., 4.]])
    # Sampled points for the first cluster
    a = random.multivariate_normal(keys[0], mu_a, cov_a, shape=(pts,))
    # Mean vector for the second cluster
    mu_b = jnp.array([30., 10.])
    # Covariance matrix for the second cluster
    cov_b = jnp.array([[10., 2.], [2., 1.]])
    # Sampled points for the second cluster
    b = random.multivariate_normal(keys[1], mu_b, cov_b, shape=(pts,))
    # Combined points
    features = jnp.concatenate((a, b))
    # Plot the points
    plt.scatter(features[:, 0], features[:, 1])

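The two clusters above are drawn with ``random.multivariate_normal``. For
reference, one standard way to sample from a multivariate normal distribution
with mean ``mean`` and covariance ``cov`` is the Cholesky transform: factor
``cov = L @ L.T`` and map standard normal vectors ``z`` to ``mean + L @ z``.
The following is a minimal sketch of that idea in plain JAX. The helper
``sample_gaussian`` is hypothetical, and we do not claim that it reproduces
JAX's own sampler draw for draw.

.. code-block:: python

    import jax.numpy as jnp
    from jax import random


    # Hypothetical helper for illustration; not part of JAX or cr.sparse.
    def sample_gaussian(key, mean, cov, n):
        """Draw n samples from N(mean, cov) with the Cholesky transform.

        If z ~ N(0, I) and cov = L @ L.T, then mean + L @ z ~ N(mean, cov).
        """
        mean = jnp.asarray(mean, dtype=jnp.float32)
        cov = jnp.asarray(cov, dtype=jnp.float32)
        L = jnp.linalg.cholesky(cov)                  # lower-triangular factor of cov
        z = random.normal(key, (n, mean.shape[0]))    # rows are standard normal vectors
        return mean + z @ L.T                         # row i equals mean + L @ z[i]


    samples = sample_gaussian(random.PRNGKey(1), [30.0, 10.0],
                              [[10.0, 2.0], [2.0, 1.0]], 100_000)
    print(samples.mean(axis=0))            # should be close to [30, 10]
    print(jnp.cov(samples, rowvar=False))  # should be close to [[10, 2], [2, 1]]

The covariance must be symmetric positive definite for the Cholesky factor to
exist; both covariance matrices in this example are. The scatter plot that
follows shows ``features`` from the main example, not these samples.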
.. image-sg:: /gallery/2000_cluster/images/sphx_glr_kmeans_001.png
   :alt: kmeans
   :srcset: /gallery/2000_cluster/images/sphx_glr_kmeans_001.png
   :class: sphx-glr-single-img


K-means clustering
------------------------------

.. code-block:: default


    # Number of clusters
    k = 2
    # Perform K-means clustering
    result = vq.kmeans_jit(keys[3], features, k)
    centroids = result.centroids
    assignment = result.assignment
    for i in range(k):
        # Points assigned to the i-th cluster
        cluster = features[assignment == i]
        plt.plot(cluster[:, 0], cluster[:, 1], "o", alpha=0.4)
    # Plot the centroids
    plt.scatter(centroids[:, 0], centroids[:, 1], c='r')


.. image-sg:: /gallery/2000_cluster/images/sphx_glr_kmeans_002.png
   :alt: kmeans
   :srcset: /gallery/2000_cluster/images/sphx_glr_kmeans_002.png
   :class: sphx-glr-single-img


Let's print the assignment to verify that each point has been assigned to the
cluster it was sampled from: the first 50 points (cluster ``a``) are labeled 0
and the last 50 points (cluster ``b``) are labeled 1.

.. code-block:: default


    print(assignment)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


Let's print the number of points in each cluster.

.. code-block:: default


    print(vq.assignment_counts(assignment, k))


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [[50]
     [50]]


Let's print the number of iterations the algorithm took to converge.

.. code-block:: default


    print(result.iterations)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2

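To make the notion of an iteration concrete, here is a minimal sketch of
Lloyd's algorithm, the classic K-means iteration, in plain JAX. It is an
illustration under our own assumptions, not CR-Sparse's implementation: the
helpers ``assign``, ``update`` and ``lloyd_kmeans`` are hypothetical,
initialization picks ``k`` distinct random data points, and convergence is
declared once the assignment stops changing. Each iteration alternates an
update step (move every centroid to the mean of its points) with an
assignment step (send every point to its nearest centroid).

.. code-block:: python

    import jax.numpy as jnp
    from jax import random


    # Hypothetical helpers for illustration; this is not CR-Sparse's implementation.
    def assign(points, centroids):
        """Index of the nearest centroid (squared Euclidean distance) for each point."""
        d2 = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)  # (n, k)
        return jnp.argmin(d2, axis=1)


    def update(points, labels, centroids):
        """Move each centroid to the mean of its points; keep empty clusters in place."""
        k = centroids.shape[0]
        one_hot = (labels[:, None] == jnp.arange(k)[None, :]).astype(points.dtype)  # (n, k)
        counts = one_hot.sum(axis=0)                                     # points per cluster
        means = (one_hot.T @ points) / jnp.maximum(counts, 1)[:, None]   # (k, d)
        return jnp.where(counts[:, None] > 0, means, centroids)


    def lloyd_kmeans(key, points, k, max_iters=100):
        """Lloyd iterations starting from k distinct, randomly chosen data points."""
        idx = random.choice(key, points.shape[0], shape=(k,), replace=False)
        centroids = points[idx]
        labels = assign(points, centroids)
        for it in range(1, max_iters + 1):
            centroids = update(points, labels, centroids)   # update step
            new_labels = assign(points, centroids)          # assignment step
            if bool(jnp.all(new_labels == labels)):         # no change: converged
                return centroids, new_labels, it
            labels = new_labels
        return centroids, labels, max_iters


    # Two well-separated toy clusters (not the data from the example above).
    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
    toy_a = random.normal(k1, (50, 2))                             # around (0, 0)
    toy_b = random.normal(k2, (50, 2)) + jnp.array([30.0, 10.0])   # around (30, 10)
    toy_points = jnp.concatenate([toy_a, toy_b])
    toy_centroids, toy_labels, toy_iters = lloyd_kmeans(k3, toy_points, 2)
    print(toy_iters, jnp.bincount(toy_labels, length=2))

The toy variables use distinct names so that running this sketch does not
overwrite ``centroids`` or ``assignment`` from the main example.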
Given the points and the centroids, we can compute the assignments directly.
``find_assignment_jit`` also returns a ``distances`` array alongside the
assignment.

.. code-block:: default


    assignment, distances = vq.find_assignment_jit(features, centroids)
    print(assignment)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


For a new point, ``find_nearest_jit`` returns the index of the nearest
centroid. The point (1, 4) is closest to the centroid of the first cluster,
which lies near the origin.

.. code-block:: default


    pt = jnp.array([1, 4])
    idx = vq.find_nearest_jit(pt, centroids)
    print(centroids[idx])


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [-0.39618243 -0.09914168]

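Both lookups above amount to finding the nearest centroid of each point. As a
rough, independent sketch (not CR-Sparse's code), that search can be written
with broadcasting, and the same squared distances give the within-cluster sum
of squares, the objective K-means tries to reduce. The helpers
``squared_distances``, ``nearest_centroid`` and ``within_cluster_ss`` are
hypothetical, and we assume Euclidean distance; the ``distances`` returned by
``find_assignment_jit`` may be defined differently.

.. code-block:: python

    import jax.numpy as jnp


    # Hypothetical helpers for illustration; they are not part of cr.sparse.
    def squared_distances(points, centroids):
        """(n, k) matrix of squared Euclidean distances via broadcasting."""
        diff = points[:, None, :] - centroids[None, :, :]   # shape (n, k, d)
        return jnp.sum(diff ** 2, axis=-1)                   # shape (n, k)


    def nearest_centroid(points, centroids):
        """Index of, and Euclidean distance to, the nearest centroid per point."""
        d2 = squared_distances(points, centroids)
        return jnp.argmin(d2, axis=1), jnp.sqrt(jnp.min(d2, axis=1))


    def within_cluster_ss(points, centroids):
        """Sum over points of the squared distance to the nearest centroid."""
        return jnp.sum(jnp.min(squared_distances(points, centroids), axis=1))


    demo_centroids = jnp.array([[0.0, 0.0], [30.0, 10.0]])
    demo_points = jnp.array([[1.0, 4.0], [29.0, 10.0], [0.0, -3.0]])
    demo_idx, demo_dist = nearest_centroid(demo_points, demo_centroids)
    print(demo_idx)                                          # [0 1 0]
    print(within_cluster_ss(demo_points, demo_centroids))   # 27.0, i.e. 17 + 1 + 9

For a single point ``p`` of shape ``(d,)``, pass ``p[None, :]`` and take
element 0 of the returned index array.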