osc.models.embeds.torch_kmeans_cosine
- torch_kmeans_cosine(x, num_clusters, seed=None, max_iters=10, tol=1e-08)
Batched k-means for PyTorch with cosine distance.
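Batched k-means with cosine distance on L2-normalized inputs is usually called spherical k-means. The block below is a minimal sketch of that technique, not the osc implementation. The name `kmeans_cosine_sketch`, the random-point initialization, the empty-cluster handling and the centroid-shift stopping rule are all assumptions. It only mirrors the documented parameters: `seed` feeds a NumPy generator, and `max_iters`/`tol` bound the loop.

```python
import numpy as np
import torch
import torch.nn.functional as F


def kmeans_cosine_sketch(x: torch.Tensor, num_clusters: int, seed=None,
                         max_iters: int = 10, tol: float = 1e-8) -> torch.Tensor:
    """Hypothetical batched spherical k-means.

    x: [B, N, C], assumed L2-normalized along C (and N >= num_clusters).
    Returns a [B, K, C] tensor of unit-norm centroids.
    """
    B, N, C = x.shape
    # Assumption: initialize with K distinct input points per batch, drawn with NumPy.
    rng = np.random.default_rng(seed)
    idx = np.stack([rng.choice(N, num_clusters, replace=False) for _ in range(B)])
    idx = torch.as_tensor(idx, device=x.device)                        # [B, K]
    centroids = torch.gather(x, 1, idx[..., None].expand(-1, -1, C))  # [B, K, C]

    for _ in range(max_iters):
        # For unit vectors, cosine similarity is a dot product: [B, N, K].
        sim = torch.bmm(x, centroids.transpose(1, 2))
        assign = sim.argmax(dim=2)                                     # [B, N]
        one_hot = F.one_hot(assign, num_clusters).to(x.dtype)          # [B, N, K]
        # Sum the members of each cluster, then project back onto the unit sphere.
        sums = torch.bmm(one_hot.transpose(1, 2), x)                   # [B, K, C]
        new_centroids = F.normalize(sums, dim=2)
        # Assumption: a cluster that received no points keeps its old centroid.
        empty = one_hot.sum(dim=1) == 0                                # [B, K]
        new_centroids = torch.where(empty[..., None], centroids, new_centroids)
        # Stop early once no centroid moves by more than tol.
        shift = (new_centroids - centroids).norm(dim=2).max()
        centroids = new_centroids
        if shift < tol:
            break
    return centroids
```

Because every centroid is renormalized after each update, maximizing the dot product is equivalent to minimizing cosine distance.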
- Parameters
  x (Tensor) – tensor containing B batches of N C-dimensional vectors each, shape [B, N, C]. The vectors should be L2-normalized along the C dimension.
  num_clusters (int) – number of clusters K.
  seed (Optional[int]) – integer seed for centroid initialization; leave as None to use NumPy's default generator.
  max_iters – maximum number of iterations.
  tol – tolerance for early termination.
- Returns
  A [B, K, C] tensor of centroids.
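Since x must be L2-normalized along C and the result holds K centroids per batch, a typical caller normalizes its embeddings first and then labels each point by its most cosine-similar centroid. The helper below is a hypothetical sketch of that step using only standard PyTorch calls; `assign_to_centroids` is not part of osc, and it assumes the returned centroids are unit-norm.

```python
import torch
import torch.nn.functional as F


def assign_to_centroids(x: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the most cosine-similar centroid per point.

    x: [B, N, C] raw embeddings; centroids: [B, K, C], assumed unit-norm.
    Returns a [B, N] tensor of cluster indices.
    """
    # Match the documented input contract: unit length along the C dimension.
    x = F.normalize(x, p=2, dim=-1)
    # With unit vectors, the batched dot product is the cosine similarity: [B, N, K].
    sim = torch.einsum("bnc,bkc->bnk", x, centroids)
    return sim.argmax(dim=-1)
```

For example, after `centroids = torch_kmeans_cosine(F.normalize(x, dim=-1), num_clusters=8)`, calling `assign_to_centroids(x, centroids)` would give a [B, N] tensor of per-point labels.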