Sparse Subspace Clustering - OMP

This example demonstrates the sparse subspace clustering algorithm via orthogonal matching pursuit.

Configure JAX to work with 64-bit floating point precision.

import jax
jax.config.update("jax_enable_x64", True)

Import the necessary libraries

from jax import random
import jax.numpy as jnp
import cr.nimble as cnb
import cr.sparse.data as crdata
import cr.nimble.subspaces
# clustering related
import cr.sparse.cluster.spectral as spectral
import cr.sparse.cluster.ssc as ssc
# Plotting
import matplotlib.pyplot as plt
# evaluation
import sklearn.metrics
# Some PRNGKeys for later use
key = random.PRNGKey(0)
keys = random.split(key, 10)

Problem configuration

# ambient space dimension
N = 40
# Subspace dimension
D = 5
# Number of subspaces
K = 5
# Number of points per subspace
S = 50

Test data preparation

Construct orthonormal bases for K subspaces

bases = crdata.random_subspaces_jit(keys[0], N, D, K)

Measure the smallest principal angle between each pair of subspaces, in degrees

angles = cnb.subspaces.smallest_principal_angles_deg(bases)

Print the minimum angle between any pair of subspaces

print(cnb.off_diagonal_min(angles))
47.44974475121892
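
The principal-angle computation can be illustrated with a short NumPy sketch. It assumes both bases have orthonormal columns and uses the standard fact that the singular values of A^T B are the cosines of the principal angles. The helper name smallest_principal_angle_deg is hypothetical and is not the cr.nimble implementation.

```python
import numpy as np

def smallest_principal_angle_deg(A, B):
    """Smallest principal angle, in degrees, between span(A) and span(B).

    A and B are assumed to have orthonormal columns.
    """
    # Singular values of A^T B are the cosines of the principal angles,
    # returned in descending order.
    sigma = np.linalg.svd(A.T @ B, compute_uv=False)
    # The largest cosine corresponds to the smallest angle; clip guards
    # against values slightly above 1 caused by rounding.
    cos_theta = np.clip(sigma[0], -1.0, 1.0)
    return np.degrees(np.arccos(cos_theta))

# Two lines in the plane that are 30 degrees apart.
A = np.array([[1.0], [0.0]])
B = np.array([[np.cos(np.pi / 6)], [np.sin(np.pi / 6)]])
angle = smallest_principal_angle_deg(A, B)
```

A large minimum angle, as in the output above, means the subspaces are well separated, which makes the clustering problem easier.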

Generate uniformly distributed points on each subspace

X = crdata.uniform_points_on_subspaces(keys[1], bases, S)
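
One plausible way to generate such data is sketched below in NumPy. It assumes that points are stored as columns of X, grouped by subspace (consistent with the column indexing used later in this example), and that each point has unit norm. The library function may differ in these details.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, K, S = 40, 5, 5, 50

# Orthonormal basis for each subspace: Q factor of a Gaussian matrix.
bases = [np.linalg.qr(rng.standard_normal((N, D)))[0] for _ in range(K)]

blocks = []
for U in bases:
    # Gaussian coefficients normalized per column give directions that are
    # uniformly distributed on the unit sphere of the D-dimensional subspace.
    C = rng.standard_normal((D, S))
    C /= np.linalg.norm(C, axis=0, keepdims=True)
    blocks.append(U @ C)

# Columns 0..S-1 belong to subspace 0, the next S to subspace 1, and so on.
X = np.hstack(blocks)
```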

Assign each point a true label equal to the index of its subspace

true_labels = jnp.repeat(jnp.arange(K), S)
print(true_labels)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]

Total number of data points

total = len(true_labels)
print(total)
250

Sparse Subspace Clustering Algorithm

Build a sparse representation of each point in terms of the other points using the Orthogonal Matching Pursuit (OMP) algorithm

Z, I, R = ssc.build_representation_omp_jit(X, D)
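
The self-representation step can be sketched in plain NumPy as follows. This is a minimal illustration of OMP applied column by column, not the library's implementation. The function name omp_self_representation and the output shapes (k x n values, k x n indices, n residual norms) are assumptions made for this sketch.

```python
import numpy as np

def omp_self_representation(X, k):
    """Represent each column of X with k other columns chosen by OMP."""
    n = X.shape[1]
    # Normalize columns so that correlations are comparable across atoms.
    Xn = X / np.linalg.norm(X, axis=0, keepdims=True)
    Z = np.zeros((k, n))             # coefficient values
    I = np.zeros((k, n), dtype=int)  # indices of the selected columns
    R = np.zeros(n)                  # residual norms
    for j in range(n):
        x = Xn[:, j]
        r = x.copy()
        available = np.ones(n, dtype=bool)
        available[j] = False         # a point may not represent itself
        support = []
        for _ in range(k):
            corr = np.abs(Xn.T @ r)
            corr[~available] = -np.inf
            idx = int(np.argmax(corr))
            support.append(idx)
            available[idx] = False
            # Least-squares fit on the current support, then update residual.
            A = Xn[:, support]
            coef, *_ = np.linalg.lstsq(A, x, rcond=None)
            r = x - A @ coef
        Z[:, j] = coef
        I[:, j] = support
        R[j] = np.linalg.norm(r)
    return Z, I, R
```

When the subspaces are well separated, the columns selected for a point tend to come from its own subspace, which is what makes the resulting affinity matrix nearly block diagonal.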

Combine the coefficient values and their indices into a full representation matrix

Z_full = ssc.sparse_to_full_rep(Z, I)
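
Conceptually, this step scatters the k coefficients of each point into an n x n matrix. The sketch below assumes that column j of Z and I holds the values and row indices for point j. That layout is an assumption for illustration and may not match the library's internal convention.

```python
import numpy as np

def sparse_to_full(Z, I):
    """Scatter (k, n) values Z at row indices I into an (n, n) matrix."""
    k, n = Z.shape
    C = np.zeros((n, n))
    # Column index of every entry: entry (t, j) belongs to point j.
    cols = np.broadcast_to(np.arange(n), (k, n))
    C[I, cols] = Z
    return C

# Three points, two coefficients each.
Z = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
I = np.array([[1, 0, 0],
              [2, 2, 1]])
C = sparse_to_full(Z, I)
```

Because a point is never used to represent itself, the diagonal of the full matrix stays zero.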

Build the affinity matrix

affinity = abs(Z_full) + abs(Z_full).T
plt.imshow(affinity, cmap='gray')
<matplotlib.image.AxesImage object at 0x7f27aa75b220>

Perform spectral clustering on the affinity matrix

res = spectral.unnormalized_k_jit(keys[2], affinity, K)
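
For intuition, unnormalized spectral clustering can be sketched in NumPy: form the graph Laplacian L = D - W, take the eigenvectors for the K smallest eigenvalues, and run k-means on their rows. The function names and the farthest-point initialization below are choices made for this sketch; the library routine takes a PRNG key and may initialize k-means differently.

```python
import numpy as np

def unnormalized_spectral_embedding(W, k):
    """Return the k smallest Laplacian eigenvalues and their eigenvectors."""
    degrees = W.sum(axis=1)
    L = np.diag(degrees) - W              # unnormalized graph Laplacian
    eigvals, eigvecs = np.linalg.eigh(L)  # eigenvalues in ascending order
    return eigvals[:k], eigvecs[:, :k]

def kmeans(points, k, n_iter=20):
    """Lloyd iterations with deterministic farthest-point initialization."""
    centers = [points[0]]
    for _ in range(1, k):
        d = np.min([np.sum((points - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(points[np.argmax(d)])
    centers = np.array(centers)
    for _ in range(n_iter):
        dist = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dist.argmin(axis=1)
        for c in range(k):
            if np.any(labels == c):
                centers[c] = points[labels == c].mean(axis=0)
    return labels

# Toy affinity: three disconnected cliques of sizes 4, 3 and 5.
sizes = [4, 3, 5]
n = sum(sizes)
W = np.zeros((n, n))
start = 0
for s in sizes:
    W[start:start + s, start:start + s] = 1.0
    start += s
np.fill_diagonal(W, 0.0)

eigvals, embedding = unnormalized_spectral_embedding(W, k=3)
labels = kmeans(embedding, k=3)
```

With a perfectly block-diagonal affinity, the Laplacian has one zero eigenvalue per block and the embedded rows are identical within a block, so k-means separates the blocks exactly.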

Predicted cluster labels

pred_labels = res.assignment
print(pred_labels)
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

Evaluate the clustering performance with the Rand index, which does not depend on how the clusters are numbered

print(sklearn.metrics.rand_score(true_labels, pred_labels))
1.0
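
The predicted labels above use different cluster numbers than the true labels, yet the score is 1.0. The following sketch shows why: the Rand index only counts pairs of points on which two labelings agree about "same cluster" versus "different cluster". The function rand_index is a plain pair-counting illustration, not the scikit-learn implementation.

```python
import numpy as np

def rand_index(a, b):
    """Fraction of point pairs on which labelings a and b agree."""
    a = np.asarray(a)
    b = np.asarray(b)
    same_a = a[:, None] == a[None, :]
    same_b = b[:, None] == b[None, :]
    # Consider each unordered pair once (strict upper triangle).
    iu = np.triu_indices(len(a), k=1)
    return float(np.mean(same_a[iu] == same_b[iu]))

true = np.repeat(np.arange(5), 50)
# Cluster renaming visible in the printed predictions: 0->1, 1->4, 2->2, 3->3, 4->0.
relabel = np.array([1, 4, 2, 3, 0])
pred = relabel[true]
score = rand_index(true, pred)  # 1.0: renaming clusters changes no pair relation
```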

SSC-OMP with shuffled data

Choose a random permutation

perm = random.permutation(keys[3], total)

Randomly permute the data points

X = X[:, perm]
# Permute the true labels accordingly
true_labels = true_labels[perm]
print(true_labels)
[0 3 2 2 1 4 0 0 3 3 0 3 4 2 4 3 3 0 2 1 3 4 1 3 0 2 4 1 0 4 2 2 3 0 2 4 3
 3 0 0 1 0 4 3 1 1 4 4 1 1 2 4 1 0 3 4 1 4 0 1 1 2 0 1 3 0 3 4 0 0 4 1 1 1
 3 2 4 0 2 3 3 1 2 3 2 1 1 4 3 4 2 0 0 4 4 1 4 0 2 2 0 4 4 2 0 1 1 1 4 1 2
 2 4 2 0 0 4 2 1 0 2 3 3 3 1 3 4 3 0 4 3 4 2 1 3 3 4 3 4 3 1 4 4 2 0 0 1 1
 1 1 1 3 3 4 2 1 2 1 4 4 3 2 3 0 0 1 4 1 0 4 3 1 1 2 3 1 2 0 0 2 4 2 2 2 2
 0 3 3 1 0 3 2 0 4 0 0 4 4 0 1 0 1 1 4 3 3 1 3 2 2 3 0 4 2 4 3 2 3 0 3 2 3
 3 0 4 2 0 2 2 4 0 2 1 2 1 0 3 0 3 2 1 0 2 1 4 4 2 0 2 4]
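
The key point of the shuffling step is that the same permutation is applied to the columns of X and to the labels, so each point keeps its label. A small NumPy sketch with toy data (the array contents are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 4)

# Toy data: each column stores its own label, so alignment is easy to check.
X = np.vstack([labels, labels]).astype(float)

perm = rng.permutation(labels.size)
X_shuffled = X[:, perm]          # permute the points (columns)
labels_shuffled = labels[perm]   # permute the labels the same way

# argsort of a permutation gives its inverse, which restores the order.
inverse = np.argsort(perm)
X_restored = X_shuffled[:, inverse]
```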

Build a sparse representation of each point in terms of the other points using the Orthogonal Matching Pursuit (OMP) algorithm

Z, I, R = ssc.build_representation_omp_jit(X, D)

Combine the coefficient values and their indices into a full representation matrix

Z_full = ssc.sparse_to_full_rep(Z, I)

Build the affinity matrix

affinity = abs(Z_full) + abs(Z_full).T
plt.imshow(affinity, cmap='gray')
<matplotlib.image.AxesImage object at 0x7f27a3a565e0>

Perform spectral clustering on the affinity matrix

res = spectral.unnormalized_k_jit(keys[4], affinity, K)

Predicted cluster labels

pred_labels = res.assignment
print(pred_labels)
[3 0 1 1 2 4 3 3 0 0 3 0 4 1 4 0 0 3 1 2 0 4 2 0 3 1 4 2 3 4 1 1 0 3 1 4 0
 0 3 3 2 3 4 0 2 2 4 4 2 2 1 4 2 3 0 4 2 4 3 2 2 1 3 2 0 3 0 4 3 3 4 2 2 2
 0 1 4 3 1 0 0 2 1 0 1 2 2 4 0 4 1 3 3 4 4 2 4 3 1 1 3 4 4 1 3 2 2 2 4 2 1
 1 4 1 3 3 4 1 2 3 1 0 0 0 2 0 4 0 3 4 0 4 1 2 0 0 4 0 4 0 2 4 4 1 3 3 2 2
 2 2 2 0 0 4 1 2 1 2 4 4 0 1 0 3 3 2 4 2 3 4 0 2 2 1 0 2 1 3 3 1 4 1 1 1 1
 3 0 0 2 3 0 1 3 4 3 3 4 4 3 2 3 2 2 4 0 0 2 0 1 1 0 3 4 1 4 0 1 0 3 0 1 0
 0 3 4 1 3 1 1 4 3 1 2 1 2 3 0 3 0 1 2 3 1 2 4 4 1 3 1 4]

Evaluate the clustering performance with the Rand index, which does not depend on how the clusters are numbered

print(sklearn.metrics.rand_score(true_labels, pred_labels))
1.0

Total running time of the script: (0 minutes 5.554 seconds)
