osc.models.embeds.torch_kmeans_cosine

torch_kmeans_cosine(x, num_clusters, seed=None, max_iters=10, tol=1e-08)

Batched k-means for PyTorch with cosine distance.

Parameters
  • x (Tensor) – tensor of shape [B N C] holding B batches of N C-dimensional vectors each. The vectors should be L2-normalized along the C dimension.

  • num_clusters (int) – number of clusters K

  • seed (Optional[int]) – integer seed for centroid initialization; if None (the default), the NumPy default generator is used

  • max_iters (int) – maximum number of iterations

  • tol (float) – convergence tolerance for terminating before max_iters is reached

Returns

A [B K C] tensor of centroids.
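The reference above does not state how centroids are initialized or updated. The following is a minimal sketch, assuming spherical k-means: each unit vector is assigned to the centroid with the highest dot product (equal to cosine similarity for unit vectors), and each centroid is replaced by the re-normalized sum of its members. The name `kmeans_cosine_sketch`, the random-point initialization through `numpy.random.default_rng`, the empty-cluster handling, and the use of `tol` as a centroid-shift threshold are all hypothetical. This is not the `osc` implementation.

```python
import numpy as np
import torch
import torch.nn.functional as F


def kmeans_cosine_sketch(x, num_clusters, seed=None, max_iters=10, tol=1e-8):
    """Hypothetical batched spherical k-means (not the osc source).

    x: [B, N, C] tensor, assumed L2-normalized along C.
    Returns a [B, K, C] tensor of unit-norm centroids.
    """
    B, N, C = x.shape
    # Assumption: seed a NumPy generator (None -> fresh OS entropy) and pick
    # K distinct data points per batch as the initial centroids.
    rng = np.random.default_rng(seed)
    idx = np.stack([rng.choice(N, size=num_clusters, replace=False) for _ in range(B)])
    idx = torch.from_numpy(idx).to(x.device)  # [B, K], int64
    centroids = torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, C))  # [B, K, C]

    for _ in range(max_iters):
        # For unit vectors, cosine similarity is a plain dot product: [B, N, K]
        sim = torch.bmm(x, centroids.transpose(1, 2))
        assign = sim.argmax(dim=-1)  # [B, N] index of the most similar centroid
        one_hot = F.one_hot(assign, num_clusters).to(x.dtype)  # [B, N, K]
        sums = torch.bmm(one_hot.transpose(1, 2), x)  # [B, K, C] member sums
        # Project the member sums back onto the unit sphere.
        new = F.normalize(sums, dim=-1)
        # Assumption: an empty cluster keeps its previous centroid.
        empty = one_hot.sum(dim=1) == 0  # [B, K]
        new = torch.where(empty.unsqueeze(-1), centroids, new)
        # Assumption: stop early once no centroid moves by more than tol.
        shift = (new - centroids).norm(dim=-1).max().item()
        centroids = new
        if shift < tol:
            break
    return centroids


if __name__ == "__main__":
    torch.manual_seed(0)
    # Normalize the inputs along C, as the reference requires.
    x = F.normalize(torch.randn(4, 128, 32), dim=-1)
    centroids = kmeans_cosine_sketch(x, num_clusters=8, seed=0)
    print(centroids.shape)  # torch.Size([4, 8, 32])
```

Because both the inputs and the centroids stay on the unit sphere, a single batched matrix product (`torch.bmm`) computes every point-to-centroid cosine similarity. Maximizing that similarity is equivalent to minimizing the cosine distance 1 − cos θ.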