gmm_gpu.gmm
Provides a GMM class for fitting multiple instances of Gaussian Mixture Models.
This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. You can create a single large 3D tensor (three-dimensional matrix) with the data for all your instances (i.e. a batch) and then send the tensor to the GPU and process the whole batch in parallel. This works best if all the instances have roughly the same number of points.
If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe Pomegranate would be a better option.
Example usage:
Import pytorch and the GMM class
>>> from gmm_gpu.gmm import GMM
>>> import torch
Generate some test data: We create a batch of 1000 instances, each with 200 random points. Half of the points are sampled from a distribution centered at the origin (0, 0) and the other half from a distribution centered at (1.5, 1.5).
>>> X1 = torch.randn(1000, 100, 2)
>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
>>> X = torch.cat([X1, X2], dim=1)
Fit the model
>>> gmm = GMM(n_components=2, device='cuda')
>>> gmm.fit(X)
Predict the components: This will return a matrix with shape (1000, 200) where each value is the predicted component for the point.
>>> gmm.predict(X)
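The fitted model also records, for every instance, the EM iteration at which it converged (the convergence_iters attribute set by fit(); instances that were never marked as converged by the tolerance check keep the value -1):

>>> gmm.convergence_iters[:5]      # one entry per instance, stored on the fitting device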
1""" 2Provides a GMM class for fitting multiple instances of `Gaussian Mixture Models <https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model>`_. 3 4This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. 5You can create a single large 3D tensor (three dimensional matrix) with the data for all your instances (i.e. a batch) and then 6send the tensor to the GPU and process the whole batch in parallel. This would work best if all the instances have roughly the same number of points. 7 8If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe `Pomegranate <https://github.com/jmschrei/pomegranate>`_ would be a better option. 9 10### Example usage: 11Import pytorch and the GMM class 12>>> from gmm_gpu.gmm import GMM 13>>> import torch 14 15Generate some test data: 16We create a batch of 1000 instances, each 17with 200 random points. Half of the points 18are sampled from distribution centered at 19the origin (0, 0) and the other half from 20a distribution centered at (1.5, 1.5). 21>>> X1 = torch.randn(1000, 100, 2) 22>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5]) 23>>> X = torch.cat([X1, X2], dim=1) 24 25Fit the model 26>>> gmm = GMM(n_components=2, device='cuda') 27>>> gmm.fit(X) 28 29Predict the components: 30This will return a matrix with shape (1000, 200) where 31each value is the predicted component for the point. 32>>> gmm.predict(X) 33""" 34 35import torch 36import numpy as np 37 38 39class GMM: 40 def __init__(self, 41 n_components, 42 max_iter=100, 43 device='cuda', 44 tol=0.001, 45 reg_covar=1e-6, 46 means_init=None, 47 weights_init=None, 48 precisions_init=None, 49 dtype=torch.float32, 50 random_seed=None): 51 """ 52 Initialize a Gaussian Mixture Models instance to fit. 53 54 Parameters 55 ---------- 56 n_components : int 57 Number of components (gaussians) in the model. 58 max_iter : int 59 Maximum number of EM iterations to perform. 60 device : torch.device 61 Which device to be used for the computations 62 during the fitting (e.g `'cpu'`, `'cuda'`, `'cuda:0'`). 63 tol : float 64 The convergence threshold. 65 reg_covar : float 66 Non-negative regularization added to the diagonal of covariance. 67 Allows to assure that the covariance matrices are all positive. 68 means_init : torch.tensor 69 User provided initialization means for all instances. The 70 tensor should have shape (Batch, Components, Dimensions). 71 If None (default) the means are going to be initialized 72 with modified kmeans++ and then refined with kmeans. 73 weights_init : torch.tensor 74 The user-provided initial weights. The tensor should have shape 75 (Batch, Components). If it is None, weights are initialized 76 depending on the kmeans++ & kmeans initialization. 77 precisions_init : torch.tensor 78 The user-provided initial precisions (inverse of the covariance matrices). 79 The tensor should have shape (Batch, Components, Dimension, Dimension). 80 If it is None, precisions are initialized depending on the kmeans++ & kmeans 81 initialization. 82 dtype : torch.dtype 83 Data type that will be used in the GMM instance. 84 random_seed : int 85 Controls the random seed that will be used 86 when initializing the model parameters. 
87 """ 88 self._n_components = n_components 89 self._max_iter = max_iter 90 self._device = device 91 self._tol = tol 92 self._reg_covar = reg_covar 93 self._means_init = means_init 94 self._weights_init = weights_init 95 self._precisions_init = precisions_init 96 self._dtype = dtype 97 self._rand_generator = torch.Generator(device=device) 98 if random_seed: 99 self._rand_seed = random_seed 100 self._rand_generator.manual_seed(random_seed) 101 else: 102 self._rand_seed = None 103 104 105 def fit(self, X): 106 """ 107 Fit the GMM on the given tensor data. 108 109 Parameters 110 ---------- 111 X : torch.tensor 112 A tensor with shape (Batch, N-points, Dimensions) 113 """ 114 X = X.to(device=self._device, dtype=self._dtype) 115 116 B, N, D = X.shape 117 118 self._init_parameters(X) 119 component_mask = self._init_clusters(X) 120 121 r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype) 122 for k in range(self._n_components): 123 r[:, :, k][component_mask == k] = 1 124 125 # This gives us the amount of points per component 126 # for each instance in the batch. It's necessary 127 # in order to handle missing points (with nan values). 128 N_actual = r.nansum(1) 129 N_actual_total = N_actual.sum(1) 130 131 converged = torch.full((B,), False, device=self._device) 132 133 # If at least one of the parameters is missing 134 # we calculate all parameters with the M-step. 135 if (self._means_init is None or 136 self._weights_init is None or 137 self._precisions_init is None): 138 self._m_step(X, r, N_actual, N_actual_total, converged) 139 140 # If any of the parameters have been provided by the 141 # user, we overwrite it with the provided value. 142 if self._means_init is not None: 143 self.means = [self._means_init[:, c, :] 144 for c in range(self._n_components)] 145 if self._weights_init is not None: 146 self._pi = [self._weights_init[:, c] 147 for c in range(self._n_components)] 148 if self._precisions_init is not None: 149 self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :]) 150 for c in range(self._n_components)] 151 152 self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device) 153 mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device) 154 155 iteration = 1 156 while iteration <= self._max_iter and not converged.all(): 157 prev_mean_log_prob_norm = mean_log_prob_norm.clone() 158 159 # === E-STEP === 160 161 for k in range(self._n_components): 162 r[~converged, :, k] = torch.add( 163 _estimate_gaussian_prob( 164 X[~converged], 165 self.means[k][~converged], 166 self._precisions_cholesky[k][~converged], 167 self._dtype).log(), 168 self._pi[k][~converged].unsqueeze(1).log() 169 ) 170 log_prob_norm = r[~converged].logsumexp(2) 171 r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp() 172 mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1) 173 N_actual = r.nansum(1) 174 175 # If we have less than 2 points in a component it produces 176 # bad covariance matrices. Hence, we stop the iterations 177 # for the affected instances and continue with the rest. 178 unprocessable_instances = (N_actual < 2).any(1) 179 converged[unprocessable_instances] = True 180 181 # === M-STEP === 182 183 self._m_step(X, r, N_actual, N_actual_total, converged) 184 185 change = mean_log_prob_norm - prev_mean_log_prob_norm 186 187 # If the change for some instances in the batch 188 # are small enough, we mark those instances as 189 # converged and do not process them anymore. 
190 small_change = change.abs() < self._tol 191 newly_converged = small_change & ~converged 192 converged[newly_converged] = True 193 self.convergence_iters[newly_converged] = iteration 194 195 iteration += 1 196 197 198 def predict_proba(self, X, force_cpu_result=True): 199 """ 200 Estimate the components' density for all samples 201 in all instances. 202 203 Parameters 204 ---------- 205 X : torch.tensor 206 A tensor with shape (Batch, N-points, Dimensions) 207 force_cpu_result : bool 208 Make sure that the resulting tensor is loaded on 209 the CPU regardless of the device used for the 210 computations (default: True). 211 212 Returns 213 ---------- 214 torch.tensor 215 tensor of shape (B, N, n_clusters) with probabilities. 216 The values at positions [I, S, :] will be the probabilities 217 of sample S in instance I to be assigned to each component. 218 """ 219 X = X.to(device=self._device, dtype=self._dtype) 220 B, N, D = X.shape 221 log_probs = torch.zeros(B, N, self._n_components, device=X.device) 222 for k in range(self._n_components): 223 # Calculate weighted log probabilities 224 log_probs[:, :, k] = torch.add( 225 self._pi[k].log().unsqueeze(1), 226 _estimate_gaussian_prob(X, 227 self.means[k], 228 self._precisions_cholesky[k], 229 self._dtype).log()) 230 log_prob_norm = log_probs.logsumexp(2) 231 log_resp = log_probs - log_prob_norm.unsqueeze(2) 232 233 if force_cpu_result: 234 return log_resp.exp().cpu() 235 return log_resp.exp() 236 237 238 def predict(self, X, force_cpu_result=True): 239 """ 240 Predict the component assignment for the given tensor data. 241 242 Parameters 243 ---------- 244 X : torch.tensor 245 A tensor with shape (Batch, N-points, Dimensions) 246 force_cpu_result : bool 247 Make sure that the resulting tensor is loaded on 248 the CPU regardless of the device used for the 249 computations (default: True). 250 251 Returns 252 ---------- 253 torch.tensor 254 tensor of shape (B, N) with component ids as values. 255 """ 256 X = X.to(device=self._device, dtype=self._dtype) 257 B, N, D = X.shape 258 probs = torch.zeros(B, N, self._n_components, device=X.device) 259 for k in range(self._n_components): 260 probs[:, :, k] = _estimate_gaussian_prob(X, 261 self.means[k], 262 self._precisions_cholesky[k], 263 self._dtype) 264 if force_cpu_result: 265 torch.where(probs.isnan().any(2), np.nan, probs.argmax(2)).cpu() 266 return torch.where(probs.isnan().any(2), np.nan, probs.argmax(2)) 267 268 269 def score_samples(self, X, force_cpu_result=True): 270 """ 271 Compute the log-likelihood of each point across all instances in the batch. 272 273 Parameters 274 ---------- 275 X : torch.tensor 276 A tensor with shape (Batch, N-points, Dimensions) 277 force_cpu_result : bool 278 Make sure that the resulting tensor is loaded on 279 the CPU regardless of the device used for the 280 computations (default: True). 281 282 Returns 283 ---------- 284 torch.tensor 285 tensor of shape (B, N) with the score for each point in the batch. 
286 """ 287 X = X.to(device=self._device, dtype=self._dtype) 288 B, N, D = X.shape 289 log_probs = torch.zeros(B, N, self._n_components, device=X.device) 290 for k in range(self._n_components): 291 # Calculate weighted log probabilities 292 log_probs[:, :, k] = torch.add( 293 self._pi[k].log().unsqueeze(1), 294 _estimate_gaussian_prob(X, 295 self.means[k], 296 self._precisions_cholesky[k], 297 self._dtype).log()) 298 if force_cpu_result: 299 return log_probs.logsumexp(2).cpu() 300 return log_probs.logsumexp(2) 301 302 303 def score(self, X, force_cpu_result=True): 304 """ 305 Compute the per-sample average log-likelihood of each instance in the batch. 306 307 Parameters 308 ---------- 309 X : torch.tensor 310 A tensor with shape (Batch, N-points, Dimensions) 311 force_cpu_result : bool 312 Make sure that the resulting tensor is loaded on 313 the CPU regardless of the device used for the 314 computations (default: True). 315 316 Returns 317 ---------- 318 torch.tensor 319 tensor of shape (B,) with the log-likelihood for each instance in the batch. 320 """ 321 X = X.to(device=self._device, dtype=self._dtype) 322 if force_cpu_result: 323 return self.score_samples(X).nanmean(1).cpu() 324 return self.score_samples(X, force_cpu_result=False).nanmean(1) 325 326 327 def bic(self, X, force_cpu_result=True): 328 """ 329 Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X. 330 331 Parameters 332 ---------- 333 X : torch.tensor 334 A tensor with shape (Batch, N-points, Dimensions) 335 force_cpu_result : bool 336 Make sure that the resulting tensor is loaded on 337 the CPU regardless of the device used for the 338 computations (default: True). 339 340 Returns 341 ---------- 342 torch.tensor 343 tensor of shape (B,) with the BIC value for each instance in the Batch. 344 """ 345 X = X.to(device=self._device, dtype=self._dtype) 346 scores = self.score(X, force_cpu_result=False) 347 valid_points = (~X.isnan()).all(2).sum(1) 348 result = -2 * scores * valid_points + self.n_parameters() * valid_points.log() 349 if force_cpu_result: 350 return result.cpu() 351 return result 352 353 354 def n_parameters(self): 355 """ 356 Returns the number of free parameters in the model for a single instance of the batch. 357 358 Returns 359 ---------- 360 int 361 number of parameters in the model 362 """ 363 n_features = self.means[0].shape[1] 364 cov_params = self._n_components * n_features * (n_features + 1) / 2.0 365 mean_params = n_features * self._n_components 366 return int(cov_params + mean_params + self._n_components - 1) 367 368 369 def _init_clusters(self, X): 370 """ 371 Init the assignment component (cluster) assignment for B sets of N D-dimensional points. 372 373 Parameters 374 ---------- 375 X : torch.tensor 376 A tensor with shape (Batch, N-points, Dimensions) 377 """ 378 # If the assignment produced by kmeans has a component 379 # with less than two points, we rerun it to get a different 380 # assignment (up to 3 times). Having less than 2 points leads 381 # to bad covariance matrices that produce errors when trying 382 # to decompose/invert them. 
383 retries = 0 384 while retries < 3: 385 _, assignment = self._kmeans(X, self._n_components) 386 _, counts = assignment.unique(return_counts=True) 387 if not torch.any(counts <= 2): 388 return assignment 389 retries += 1 390 return assignment 391 392 393 def _init_parameters(self, X): 394 B, N, D = X.shape 395 self.means = [torch.empty(B, D, dtype=self._dtype, device=self._device) 396 for _ in range(self._n_components)] 397 self.covs = [torch.empty(B, D, D, dtype=self._dtype, device=self._device) 398 for _ in range(self._n_components)] 399 self._precisions_cholesky = [torch.empty(B, D, D, 400 dtype=self._dtype, 401 device=self._device) 402 for _ in range(self._n_components)] 403 self._pi = [torch.empty(B, dtype=self._dtype, device=self._device) 404 for _ in range(self._n_components)] 405 406 407 def _m_step(self, X, r, N_actual, N_actual_total, converged): 408 B, N, D = X.shape 409 # We update the means, covariances and weights 410 # for all instances in the batch that still 411 # have not converged. 412 for k in range(self._n_components): 413 self.means[k][~converged] = torch.div( 414 # the nominator is sum(r*X) 415 (r[~converged, :, k].unsqueeze(2) * X[~converged]).nansum(1), 416 # the denominator is normalizing by the number of valid points 417 N_actual[~converged, k].unsqueeze(1)) 418 419 self.covs[k][~converged] = self._get_covs(X[~converged], 420 self.means[k][~converged], 421 r[~converged, :, k], 422 N_actual[~converged, k]) 423 424 # We need to calculate the Cholesky decompositions of 425 # the precision matrices (the precision is the inverse 426 # of the covariance). However, due to numerical errors 427 # the covariance may lose its positive-definite property 428 # (which mathematically is guarenteed to have). Whenever 429 # that happens, we can no longer calculate the Cholesky 430 # decomposition. As a workaround, we substitute the cov 431 # matrix with a near covariance matrix that is positive 432 # definite. 433 covs_cholesky, errors = torch.linalg.cholesky_ex(self.covs[k][~converged]) 434 bad_covs = errors > 0 435 if bad_covs.any(): 436 eigvals, eigvecs = torch.linalg.eigh(self.covs[k][~converged][bad_covs]) 437 # Theoretically, we should be able to use much smaller 438 # min value here, but for some reason smaller ones sometimes 439 # fail to force the covariance matrix to be positive-definite. 440 new_eigvals = torch.clamp(eigvals, min=1e-5) 441 new_covs = eigvecs @ torch.diag_embed(new_eigvals) @ eigvecs.transpose(-1, -2) 442 self.covs[k][~converged][bad_covs] = new_covs 443 covs_cholesky[bad_covs] = torch.linalg.cholesky(new_covs) 444 self._precisions_cholesky[k][~converged] = self._get_precisions_cholesky(covs_cholesky) 445 446 self._pi[k][~converged] = N_actual[~converged, k]/N_actual_total[~converged] 447 448 449 def _kmeans(self, X, n_clusters=2, max_iter=10, tol=0.001): 450 """ 451 Clusters the points in each instance of the batch using k-means. 452 Points with nan values are assigned with value -1. 453 454 Parameters 455 ---------- 456 X : torch.tensor 457 A tensor with shape (Batch, N-points, Dimensions) 458 n_clusters : int 459 Number of clusters to find. 460 max_iter : int 461 Maximum number of iterations to perform. 462 tol : float 463 The convergence threshold. 
464 """ 465 B, N, D = X.shape 466 C = n_clusters 467 valid_points = ~X.isnan().any(dim=2) 468 invalid_points_count = (~valid_points).sum(1) 469 centers = self._kmeans_pp(X, C, valid_points) 470 471 i = 0 472 diff = np.inf 473 while i < max_iter and diff > tol: 474 # Calculate the squared distance between each point and cluster centers 475 distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1) 476 assignment = distances.argmin(dim=2) 477 478 # Compute the new cluster center 479 cluster_sums = torch.zeros_like(centers) 480 cluster_counts = torch.zeros((B, C, 1), dtype=torch.float32, device=X.device) 481 # The nans are assigned to the first cluster. We want to ignore them. 482 # Hence, we use nat_to_num() to replace them with 0s and then we subtract 483 # the number of invalid points from the counts for the first cluster. 484 cluster_sums.scatter_add_(1, assignment.unsqueeze(-1).expand(-1, -1, D), X.nan_to_num()) 485 cluster_counts.scatter_add_(1, assignment.unsqueeze(-1), torch.ones_like(X[:, :, :1])) 486 cluster_counts[:, 0] -= invalid_points_count 487 new_centers = cluster_sums / cluster_counts.clamp_min(1e-8) 488 489 # Estimate how much change we get in the centers 490 diff = torch.norm(new_centers - centers, dim=(1, 2)).max() 491 492 centers = new_centers.nan_to_num() 493 i += 1 494 495 # Final assignment with updated centers 496 distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1) 497 assignment = torch.where(valid_points, distances.argmin(dim=2), -1) 498 499 return centers, assignment 500 501 502 def _select_random_valid_points(self, X, valid_mask): 503 B, N, D = X.shape 504 505 _, point_idx = valid_mask.nonzero(as_tuple=True) 506 counts = valid_mask.sum(1) 507 508 # Select random valid index. 509 # This is efficient, but quite tricky: 510 # nonzero(as_tuple=True) returns a list of the batch indices and corresponding 511 # point indices of valid points. For each instance in the batch, we get a 512 # random integer between 0 and the maximum possible number of valid points. 513 # To make sure that the selected integer is not larger than the number of 514 # valid points for each instance we mod that integer by counts. 515 # This basically gives us a random offset to select a point from a list 516 # of valid points for a given batch index. 517 rand_offsets = torch.randint(0, counts.max(), (B,), 518 generator=self._rand_generator, 519 device=X.device) % counts 520 521 # Here, cumsum(counts)-counts gives us the starting position of each instance in the batch 522 # in point_idx. E.g. if we have a batch of 3 instances with [5, 7, 3] valid points respectively, 523 # we would get batch starts = [0, 5, 12]. 
524 batch_starts = torch.cumsum(counts, dim=0) - counts 525 chosen_indices = point_idx[batch_starts + rand_offsets] 526 527 selected_points = X[torch.arange(B, device=X.device), chosen_indices] 528 return selected_points 529 530 531 def _kmeans_pp(self, X, C, valid_points): 532 B, N, D = X.shape 533 device = X.device 534 std = self._nanstd(X) 535 centers = torch.empty(B, C, D, device=device) 536 537 # Randomly select the first center for each batch 538 rand_points = self._select_random_valid_points(X, valid_points) 539 centers[:, 0, :] = std * rand_points / rand_points.norm(dim=-1, keepdim=True) 540 541 # Each subsequent center would be calculated to be distant 542 # from the previous one 543 for k in range(1, C): 544 prev_centers = centers[:, k - 1, :].unsqueeze(1) 545 distances = (X - prev_centers).norm(dim=-1) 546 547 # By default kmeans++ takes as the next center the 548 # point that is furthest away. However, if there are 549 # outliers, they're likely to be selected, so here we 550 # ignore the top 10% of the most distant points. 551 threshold_idx = int(0.9 * N) 552 sorted_distances, sorted_indices = distances.sort(1) 553 554 # The standard kmeans++ algorithm selects an initial 555 # point at random for the first centroid and then for 556 # each cluster selects the point that is furthest away 557 # from the previous one. This is prone to selecting 558 # outliers that are very far away from all other points, 559 # leading to clusters with a single point. In the GMM 560 # fitting these clusters are problematic, because variance 561 # covariance metrics do not make sense anymore. 562 # To ameliorate this, we position the centroid at a point 563 # that is in the direction of the furthest point, 564 # but the length of the vector is equal to the 150% the 565 # standard deviation in the dataset. 566 # First, we get the most distant valid positions (after ignoring 567 # the top 10%). 568 max_valid_idx = _nanmax(sorted_distances[:, :threshold_idx], 1)[1] 569 # Those are indices that point to the sorting and not the original dataset. 570 # We need to map them through sorted_indices to obtain the indices for those points 571 # in the dataset X. 572 orig_indices = sorted_indices[torch.arange(B, device=device), max_valid_idx] 573 selected_points = X[torch.arange(B, device=device), orig_indices] 574 # Once we have the actual points, we calculate the new centers. 575 centers[:, k, :] = 1.5 * std * selected_points / selected_points.norm(dim=-1, keepdim=True) 576 return centers 577 578 579 def _get_covs(self, X, means, r, nums): 580 B, N, D = X.shape 581 # C_k = (1/N_k) * sum(r_nk * (x - mu_k)(x - mu_k)^T) 582 diffs = X - means.unsqueeze(1) 583 summands = r.view(B, N, 1, 1) * torch.matmul(diffs.unsqueeze(3), diffs.unsqueeze(2)) 584 covs = summands.nansum(1) / nums.view(B, 1, 1).add(torch.finfo(self._dtype).eps) 585 return covs 586 587 588 def _get_precisions_cholesky(self, covs_cholesky): 589 B, D, D = covs_cholesky.shape 590 precisions_cholesky = torch.linalg.solve_triangular( 591 covs_cholesky, 592 torch.eye(D, device=self._device).unsqueeze(0).repeat(B, 1, 1), 593 upper=False, 594 left=True).permute(0, 2, 1) 595 return precisions_cholesky.to(self._dtype) 596 597 598 def _nanstd(self, X): 599 valid = torch.sum(~X.isnan().any(2), 1) 600 return (((X - X.nanmean(1).unsqueeze(1)) ** 2).nansum(1) / valid.unsqueeze(1)) ** 0.5 601 602 603def _nanmax(T, dim): 604 """ 605 Compute the max along a given axis while ignoring NaNs. 
606 """ 607 nan_mask = T.isnan() 608 T = torch.where(nan_mask, float('-inf'), T) 609 max_values, indices = T.max(dim=dim) 610 return max_values, indices 611 612 613def _estimate_gaussian_prob(X, mean, precisions_chol, dtype): 614 """ 615 Compute the probability of a batch of points X under 616 a batch of multivariate normal distributions. 617 618 Parameters 619 ---------- 620 X : torch.tensor 621 A tensor with shape (Batch, N-points, Dimensions). 622 Represents a batch of points. 623 mean : torch.tensor 624 The means of the distributions. Shape: (B, D) 625 precisions_chol : torch.tensor 626 Cholesky decompositions of the precisions matrices. Shape: (B, D, D) 627 dtype : torch.dtype 628 Data type of the result 629 630 Returns 631 ---------- 632 torch.tensor 633 tensor of shape (B, N) with probabilities 634 """ 635 B, N, D = X.shape 636 y = torch.bmm(X, precisions_chol) - torch.bmm(mean.unsqueeze(1), precisions_chol) 637 log_prob = y.pow(2).sum(2) 638 log_det = torch.diagonal(precisions_chol, dim1=1, dim2=2).log().sum(1) 639 return torch.exp( 640 -0.5 * (D * np.log(2 * np.pi) + log_prob) + log_det.unsqueeze(1))
class GMM
def __init__(self, n_components, max_iter=100, device='cuda', tol=0.001, reg_covar=1e-6, means_init=None, weights_init=None, precisions_init=None, dtype=torch.float32, random_seed=None)
Initialize a Gaussian Mixture Model instance to fit.
Parameters
- n_components (int): Number of components (gaussians) in the model.
- max_iter (int): Maximum number of EM iterations to perform.
- device (torch.device): Which device to use for the computations during fitting (e.g. 'cpu', 'cuda', 'cuda:0').
- tol (float): The convergence threshold.
- reg_covar (float): Non-negative regularization added to the diagonal of covariance. Helps ensure that the covariance matrices are all positive.
- means_init (torch.tensor): User-provided initialization means for all instances. The tensor should have shape (Batch, Components, Dimensions). If None (default), the means are initialized with a modified kmeans++ and then refined with kmeans.
- weights_init (torch.tensor): The user-provided initial weights. The tensor should have shape (Batch, Components). If it is None, weights are initialized depending on the kmeans++ & kmeans initialization.
- precisions_init (torch.tensor): The user-provided initial precisions (inverse of the covariance matrices). The tensor should have shape (Batch, Components, Dimension, Dimension). If it is None, precisions are initialized depending on the kmeans++ & kmeans initialization.
- dtype (torch.dtype): Data type that will be used in the GMM instance.
- random_seed (int): Controls the random seed that will be used when initializing the model parameters.
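For illustration, here is a minimal sketch of supplying user-provided initial parameters (assuming a batch of B=1000 instances with K=2 components in D=2 dimensions; the identity precisions and the means used here are arbitrary placeholders, not recommended values):

>>> B, K, D = 1000, 2, 2
>>> means_init = torch.zeros(B, K, D)
>>> means_init[:, 1, :] = 1.5                                   # put the second component roughly at (1.5, 1.5)
>>> weights_init = torch.full((B, K), 1 / K)                    # equal mixing weights
>>> precisions_init = torch.eye(D).expand(B, K, D, D).clone()   # identity precisions (unit covariances)
>>> gmm = GMM(n_components=K, device='cuda',
...           means_init=means_init,
...           weights_init=weights_init,
...           precisions_init=precisions_init)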
def fit(self, X)
Fit the GMM on the given tensor data.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
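The implementation treats points with NaN coordinates as missing, so one way to batch instances with different numbers of points is to pad the shorter ones with NaN rows up to a common length (a sketch, not a dedicated padding API; the variable names are illustrative):

>>> a = torch.randn(150, 2)                     # instance with 150 points
>>> b = torch.randn(200, 2) + 1.5               # instance with 200 points
>>> pad = torch.full((200 - 150, 2), float('nan'))
>>> X = torch.stack([torch.cat([a, pad]), b])   # shape (2, 200, 2)
>>> gmm = GMM(n_components=2, device='cuda')
>>> gmm.fit(X)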
def predict_proba(self, X, force_cpu_result=True)
Estimate the components' density for all samples in all instances.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
- force_cpu_result (bool): Make sure that the resulting tensor is loaded on the CPU regardless of the device used for the computations (default: True).
Returns
- torch.tensor: tensor of shape (B, N, n_clusters) with probabilities. The values at positions [I, S, :] will be the probabilities of sample S in instance I to be assigned to each component.
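As a quick sanity check (a sketch reusing the fitted gmm and X from the usage example above), the returned responsibilities for each point should sum to approximately 1 across components:

>>> proba = gmm.predict_proba(X)   # shape (1000, 200, 2), on the CPU by default
>>> proba.sum(dim=2)               # each entry is ~1.0 for points without NaN values
>>> proba.argmax(dim=2)            # the hard assignment predict() would produce for valid points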
def predict(self, X, force_cpu_result=True)
Predict the component assignment for the given tensor data.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
- force_cpu_result (bool): Make sure that the resulting tensor is loaded on the CPU regardless of the device used for the computations (default: True).
Returns
- torch.tensor: tensor of shape (B, N) with component ids as values.
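Note that points with NaN values receive NaN instead of a component id, so the result is a floating-point tensor. A sketch of counting the points assigned to each component while skipping such missing points:

>>> labels = gmm.predict(X)        # shape (1000, 200); NaN marks missing points
>>> valid = ~labels.isnan()
>>> ((labels == 0) & valid).sum(), ((labels == 1) & valid).sum()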
def score_samples(self, X, force_cpu_result=True)
Compute the log-likelihood of each point across all instances in the batch.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
- force_cpu_result (bool): Make sure that the resulting tensor is loaded on the CPU regardless of the device used for the computations (default: True).
Returns
- torch.tensor: tensor of shape (B, N) with the score for each point in the batch.
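For example (a sketch reusing the fitted model from the usage example), the per-point log-likelihoods can be used to flag the least likely points within each instance:

>>> ll = gmm.score_samples(X)                                    # shape (1000, 200)
>>> threshold = torch.nanquantile(ll, 0.01, dim=1, keepdim=True)
>>> outliers = ll < threshold                                    # roughly the 1% least likely points per instance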
def score(self, X, force_cpu_result=True)
Compute the per-sample average log-likelihood of each instance in the batch.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
- force_cpu_result (bool): Make sure that the resulting tensor is loaded on the CPU regardless of the device used for the computations (default: True).
Returns
- torch.tensor: tensor of shape (B,) with the log-likelihood for each instance in the batch.
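A sketch of using the per-instance scores, for example to find the instances that the fitted model explains worst:

>>> avg_ll = gmm.score(X)          # shape (1000,)
>>> worst = avg_ll.argsort()[:10]  # indices of the 10 instances with the lowest average log-likelihood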
def bic(self, X, force_cpu_result=True)
Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
- force_cpu_result (bool): Make sure that the resulting tensor is loaded on the CPU regardless of the device used for the computations (default: True).
Returns
- torch.tensor: tensor of shape (B,) with the BIC value for each instance in the Batch.
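Since the BIC is returned per instance, one possible use (a sketch; the candidate component counts are arbitrary) is to select the number of components for each instance separately by fitting one model per candidate and taking the per-instance minimum:

>>> candidates = [2, 3, 4]
>>> bics = []
>>> for k in candidates:
...     m = GMM(n_components=k, device='cuda')
...     m.fit(X)
...     bics.append(m.bic(X))
>>> best = torch.stack(bics).argmin(dim=0)   # shape (1000,); index into candidates for each instance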
def n_parameters(self)
Returns the number of free parameters in the model for a single instance of the batch.
Returns
- int: number of parameters in the model
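For the 2-component, 2-dimensional model from the usage example this works out to 2*(2*3/2) = 6 covariance parameters, 2*2 = 4 mean parameters and 2 - 1 = 1 free weight:

>>> gmm.n_parameters()
11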