{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Sparse Subspace Clustering - OMP\n\nThis example demonstrates the sparse subspace clustering algorithm via orthogonal matching pursuit.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Configure JAX to work with 64-bit floating point precision. \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n\n# Use the top-level ``jax.config`` object: the ``from jax.config import config``\n# import path is deprecated and has been removed in newer JAX releases.\njax.config.update(\"jax_enable_x64\", True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from jax import random\nimport jax.numpy as jnp\nimport cr.nimble as cnb\nimport cr.sparse.data as crdata\n# imported explicitly; ``cnb.subspaces`` is used further below\nimport cr.nimble.subspaces\n# clustering related\nimport cr.sparse.cluster.spectral as spectral\nimport cr.sparse.cluster.ssc as ssc\n# Plotting\nimport matplotlib.pyplot as plt\n# evaluation\nimport sklearn.metrics\n# Some PRNGKeys for later use: one independent key per random step\nkey = random.PRNGKey(0)\nkeys = random.split(key, 10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Problem configuration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# ambient space dimension\nN = 40\n# Subspace dimension\nD = 5\n# Number of subspaces\nK = 5\n# Number of points per subspace\nS = 50"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Test data preparation\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Construct orthonormal bases for K subspaces\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "bases = crdata.random_subspaces_jit(keys[0], N, D, K)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Measure the smallest principal angle between each pair of subspaces, in degrees\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "angles = cnb.subspaces.smallest_principal_angles_deg(bases)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Print the minimum angle between any pair of subspaces\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(cnb.off_diagonal_min(angles))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate uniformly distributed points on each subspace\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = crdata.uniform_points_on_subspaces(keys[1], bases, S)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Assign each point a true label equal to the index of\nthe subspace it belongs to\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "true_labels = jnp.repeat(jnp.arange(K), S)\nprint(true_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Total number of data points\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "total = len(true_labels)\nprint(total)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sparse Subspace Clustering Algorithm\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build a representation of each point in terms of the other points\nusing the Orthogonal Matching Pursuit algorithm\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Z, I, R = ssc.build_representation_omp_jit(X, D)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Combine values and indices to form full representation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Z_full = ssc.sparse_to_full_rep(Z, I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build the affinity matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Symmetrize the coefficient magnitudes: affinity = |Z| + |Z|^T\n# (the magnitude is computed once and reused for both terms)\nabs_rep = abs(Z_full)\naffinity = abs_rep + abs_rep.T\nplt.imshow(affinity, cmap='gray')\n# the trailing semicolon keeps the matplotlib object repr out of the cell output\nplt.title('Affinity matrix');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Perform spectral clustering on the affinity matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "res = spectral.unnormalized_k_jit(keys[2], affinity, K)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Predicted cluster labels\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pred_labels = res.assignment\nprint(pred_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Evaluate the clustering performance\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(sklearn.metrics.rand_score(true_labels, pred_labels))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## SSC-OMP with shuffled data\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Choose a random permutation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "perm = random.permutation(keys[3], total)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Randomly permute the data points\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X = X[:, perm]\n# Permute the true labels accordingly\ntrue_labels = true_labels[perm]\nprint(true_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build a representation of each point in terms of the other points\nusing the Orthogonal Matching Pursuit algorithm\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Z, I, R = ssc.build_representation_omp_jit(X, D)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Combine values and indices to form full representation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Z_full = ssc.sparse_to_full_rep(Z, I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build the affinity matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Symmetrize the coefficient magnitudes: affinity = |Z| + |Z|^T\n# (the magnitude is computed once and reused for both terms)\nabs_rep = abs(Z_full)\naffinity = abs_rep + abs_rep.T\nplt.imshow(affinity, cmap='gray')\n# the trailing semicolon keeps the matplotlib object repr out of the cell output\nplt.title('Affinity matrix (shuffled data)');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Perform spectral clustering on the affinity matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "res = spectral.unnormalized_k_jit(keys[4], affinity, K)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Predicted cluster labels\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pred_labels = res.assignment\nprint(pred_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Evaluate the clustering performance\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(sklearn.metrics.rand_score(true_labels, pred_labels))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}