cr.sparse.pdist_sqr_l2_cw

cr.sparse.pdist_sqr_l2_cw(A)

Computes the pairwise squared Euclidean distances between the points in A, where each point is a column vector.

Parameters

A (jax.numpy.ndarray) – A KxN matrix holding N K-dimensional points, one point per column

Returns

An NxN matrix D of squared Euclidean distances between the points in A

Return type

(jax.numpy.ndarray)

  • Let the ambient space of points be \(\mathbb{F}^K\).

  • \(A\) contains the points \(a_i\) with \(1 \leq i \leq N\) and each point maps to a column of \(A\).

Then the distance matrix \(D\) is of size \(N \times N\) and consists of:

\[d_{i, j} = \| a_i - a_j \|_2^2 = \langle a_i - a_j , a_i - a_j \rangle \tag{1}\]
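
The definition in equation (1) can be sketched directly with broadcasting in JAX. This is a minimal reference sketch for checking results, not the library's implementation: the name `pdist_sqr_l2_cw_reference` is hypothetical, and the example assumes `A` has shape (K, N) with one point per column.

```python
import jax.numpy as jnp


def pdist_sqr_l2_cw_reference(A):
    """Hypothetical reference: squared Euclidean distances between columns of A."""
    # diffs[:, i, j] = a_i - a_j, giving an array of shape (K, N, N)
    diffs = A[:, :, None] - A[:, None, :]
    # Sum |a_i - a_j|^2 over the K coordinates; abs also covers complex entries
    return jnp.sum(jnp.abs(diffs) ** 2, axis=0)


# Three 2-dimensional points stored as columns: (0, 0), (3, 4), (6, 8)
A = jnp.array([[0.0, 3.0, 6.0],
               [0.0, 4.0, 8.0]])

D = pdist_sqr_l2_cw_reference(A)
print(D)
# [[  0.  25. 100.]
#  [ 25.   0.  25.]
#  [100.  25.   0.]]
```

The result is symmetric with a zero diagonal, as equation (1) requires. The broadcasting form materializes a (K, N, N) intermediate, so for large N it is usually cheaper to expand the inner product as \(\| a_i \|_2^2 + \| a_j \|_2^2 - 2 \operatorname{Re} \langle a_i, a_j \rangle\) and work from the Gram matrix of \(A\) instead.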