osc.models.embeds.torch_kmeans_cosine
- torch_kmeans_cosine(x, num_clusters, seed=None, max_iters=10, tol=1e-08)
Batched k-means for PyTorch with cosine distance.
- Parameters
x (Tensor) – tensor containing B batches of N C-dimensional tensors each, shape [B N C]. The tensors should be L2-normalized along the C dimension.
num_clusters (int) – number of clusters K
seed (Optional[int]) – int seed for centroid initialization; leave empty to use the NumPy default generator
max_iters – max number of iterations
tol – tolerance for early termination
- Returns
A [B K C] tensor of centroids.
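To illustrate the shapes and the role of each parameter, here is a minimal NumPy sketch of batched k-means with cosine distance. It is not the library's implementation: the function name `kmeans_cosine_sketch`, the initialization (sampling K distinct input points per batch), the empty-cluster handling, and the meaning of `tol` (largest absolute change in any centroid coordinate) are all assumptions made for the example.

```python
import numpy as np


def kmeans_cosine_sketch(x, num_clusters, seed=None, max_iters=10, tol=1e-8):
    """Illustrative batched cosine k-means (hypothetical helper, not the osc code).

    x has shape [B, N, C] and is assumed to be L2-normalized along C.
    Returns centroids of shape [B, K, C].
    """
    rng = np.random.default_rng(seed)  # seed=None falls back to fresh OS entropy
    B, N, C = x.shape

    # Assumed initialization: K distinct input points per batch.
    idx = np.stack(
        [rng.choice(N, size=num_clusters, replace=False) for _ in range(B)]
    )  # [B, K]
    centroids = x[np.arange(B)[:, None], idx]  # [B, K, C]

    for _ in range(max_iters):
        # For unit vectors, cosine similarity is just a dot product.
        sim = np.einsum("bnc,bkc->bnk", x, centroids)  # [B, N, K]
        assign = sim.argmax(axis=-1)  # [B, N]
        one_hot = np.eye(num_clusters, dtype=x.dtype)[assign]  # [B, N, K]

        # Sum the members of each cluster, then renormalize to unit length.
        sums = np.einsum("bnk,bnc->bkc", one_hot, x)  # [B, K, C]
        norms = np.linalg.norm(sums, axis=-1, keepdims=True)  # [B, K, 1]

        # Assumed policy: keep the previous centroid if the cluster sum is zero
        # (for example, when no point was assigned to it).
        new_centroids = np.where(
            norms > 0, sums / np.maximum(norms, 1e-12), centroids
        )

        # Assumed stopping rule: largest absolute centroid change below tol.
        shift = np.abs(new_centroids - centroids).max()
        centroids = new_centroids
        if shift < tol:
            break

    return centroids


# Usage: 2 batches of 50 unit vectors in 8 dimensions, clustered into K=4.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 50, 8))
x /= np.linalg.norm(x, axis=-1, keepdims=True)  # L2-normalize along C
centroids = kmeans_cosine_sketch(x, num_clusters=4, seed=0)
print(centroids.shape)  # (2, 4, 8)
```

Normalizing the inputs first is what lets the assignment step use a plain dot product in place of an explicit cosine distance, which is presumably why the documented function asks for L2-normalized tensors.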