gmm_gpu.gmm
Provides a GMM class for fitting multiple instances of Gaussian Mixture Models.
This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. You can create a single large 3D tensor (a three-dimensional matrix) with the data for all your instances (i.e. a batch), then send the tensor to the GPU and process the whole batch in parallel. This works best if all the instances have roughly the same number of points.
If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, Pomegranate may be a better option.
Example usage:
Import PyTorch and the GMM class
>>> from gmm_gpu.gmm import GMM
>>> import torch
Generate some test data: We create a batch of 1000 instances, each with 200 random points. Half of the points are sampled from a distribution centered at the origin (0, 0) and the other half from a distribution centered at (1.5, 1.5).
>>> X1 = torch.randn(1000, 100, 2)
>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
>>> X = torch.cat([X1, X2], dim=1)
Fit the model
>>> gmm = GMM(n_components=2, device='cuda')
>>> gmm.fit(X)
Predict the components: This returns a tensor with shape (1000, 200) where each value is the predicted component for the corresponding point.
>>> gmm.predict(X)
1""" 2Provides a GMM class for fitting multiple instances of `Gaussian Mixture Models <https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model>`_. 3 4This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. 5You can create a single large 3D tensor (three dimensional matrix) with the data for all your instances (i.e. a batch) and then 6send the tensor to the GPU and process the whole batch in parallel. This would work best if all the instances have roughly the same number of points. 7 8If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe `Pomegranate <https://github.com/jmschrei/pomegranate>`_ would be a better option. 9 10### Example usage: 11Import pytorch and the GMM class 12>>> from gmm_gpu.gmm import GMM 13>>> import torch 14 15Generate some test data: 16We create a batch of 1000 instances, each 17with 200 random points. Half of the points 18are sampled from distribution centered at 19the origin (0, 0) and the other half from 20a distribution centered at (1.5, 1.5). 21>>> X1 = torch.randn(1000, 100, 2) 22>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5]) 23>>> X = torch.cat([X1, X2], dim=1) 24 25Fit the model 26>>> gmm = GMM(n_components=2, device='cuda') 27>>> gmm.fit(X) 28 29Predict the components: 30This will return a matrix with shape (1000, 200) where 31each value is the predicted component for the point. 32>>> gmm.predict(X) 33""" 34 35import math 36 37import torch 38import numpy as np 39 40 41class GMM: 42 def __init__(self, 43 n_components, 44 max_iter=100, 45 device='cuda', 46 tol=0.001, 47 reg_covar=1e-6, 48 means_init=None, 49 weights_init=None, 50 precisions_init=None, 51 dtype=torch.float32, 52 random_seed=None): 53 """ 54 Initialize a Gaussian Mixture Models instance to fit. 55 56 Parameters 57 ---------- 58 n_components : int 59 Number of components (gaussians) in the model. 60 max_iter : int 61 Maximum number of EM iterations to perform. 62 device : torch.device 63 Which device to be used for the computations 64 during the fitting (e.g `'cpu'`, `'cuda'`, `'cuda:0'`). 65 tol : float 66 The convergence threshold. 67 reg_covar : float 68 Non-negative regularization added to the diagonal of covariance. 69 Allows to assure that the covariance matrices are all positive. 70 means_init : torch.tensor 71 User provided initialization means for all instances. The 72 tensor should have shape (Batch, Components, Dimensions). 73 If None (default) the means are going to be initialized 74 with modified kmeans++ and then refined with kmeans. 75 weights_init : torch.tensor 76 The user-provided initial weights. The tensor should have shape 77 (Batch, Components). If it is None, weights are initialized 78 depending on the kmeans++ & kmeans initialization. 79 precisions_init : torch.tensor 80 The user-provided initial precisions (inverse of the covariance matrices). 81 The tensor should have shape (Batch, Components, Dimension, Dimension). 82 If it is None, precisions are initialized depending on the kmeans++ & kmeans 83 initialization. 84 dtype : torch.dtype 85 Data type that will be used in the GMM instance. 86 random_seed : int 87 Controls the random seed that will be used 88 when initializing the model parameters. 
89 """ 90 self._n_components = n_components 91 self._max_iter = max_iter 92 self._device = device 93 self._tol = tol 94 self._reg_covar = reg_covar 95 self._means_init = means_init 96 self._weights_init = weights_init 97 self._precisions_init = precisions_init 98 self._dtype = dtype 99 self._rand_generator = torch.Generator(device=device) 100 if random_seed: 101 self._rand_seed = random_seed 102 self._rand_generator.manual_seed(random_seed) 103 else: 104 self._rand_seed = None 105 106 107 def fit(self, X): 108 """ 109 Fit the GMM on the given tensor data. 110 111 Parameters 112 ---------- 113 X : torch.tensor 114 A tensor with shape (Batch, N-points, Dimensions) 115 """ 116 X = X.to(self._dtype) 117 if X.device.type != self._device: 118 X = X.to(self._device) 119 120 B, N, D = X.shape 121 122 self._init_parameters(X) 123 component_mask = self._init_clusters(X) 124 125 r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype) 126 for k in range(self._n_components): 127 r[:, :, k][component_mask == k] = 1 128 129 # This gives us the amount of points per component 130 # for each instance in the batch. It's necessary 131 # in order to handle missing points (with nan values). 132 N_actual = r.nansum(1) 133 N_actual_total = N_actual.sum(1) 134 135 converged = torch.full((B,), False, device=self._device) 136 137 # If we have less than 2 points in a component it produces 138 # bad covariance matrices. Hence, we stop the iterations 139 # for the affected instances and continue with the rest. 140 single_component = (N_actual >= N_actual_total.unsqueeze(1) - 1).any(1) 141 converged[single_component] = True 142 143 # If at least one of the parameters is missing 144 # we calculate all parameters with the M-step. 145 if (self._means_init is None or 146 self._weights_init is None or 147 self._precisions_init is None): 148 self._m_step(X, r, N_actual, N_actual_total, converged) 149 150 # If any of the parameters have been provided by the 151 # user, we overwrite it with the provided value. 152 if self._means_init is not None: 153 self.means = [self._means_init[:, c, :] 154 for c in range(self._n_components)] 155 if self._weights_init is not None: 156 self._pi = [self._weights_init[:, c] 157 for c in range(self._n_components)] 158 if self._precisions_init is not None: 159 self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :]) 160 for c in range(self._n_components)] 161 162 self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device) 163 mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device) 164 165 iteration = 1 166 while iteration <= self._max_iter and not converged.all(): 167 prev_mean_log_prob_norm = mean_log_prob_norm.clone() 168 169 # === E-STEP === 170 171 for k in range(self._n_components): 172 r[~converged, :, k] = torch.add( 173 _estimate_gaussian_prob( 174 X[~converged], 175 self.means[k][~converged], 176 self._precisions_cholesky[k][~converged], 177 self._dtype).log(), 178 self._pi[k][~converged].unsqueeze(1).log() 179 ) 180 log_prob_norm = r[~converged].logsumexp(2) 181 r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp() 182 mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1) 183 N_actual = r.nansum(1) 184 185 # If we have less than 2 points in a component it produces 186 # bad covariance matrices. Hence, we stop the iterations 187 # for the affected instances and continue with the rest. 
188 single_component = (N_actual >= N_actual_total.unsqueeze(1) - 1).any(1) 189 converged[single_component] = True 190 191 # === M-STEP === 192 193 self._m_step(X, r, N_actual, N_actual_total, converged) 194 195 change = mean_log_prob_norm - prev_mean_log_prob_norm 196 197 # If the change for some instances in the batch 198 # are small enough, we mark those instances as 199 # converged and do not process them anymore. 200 small_change = change.abs() < self._tol 201 newly_converged = small_change & ~converged 202 converged[newly_converged] = True 203 self.convergence_iters[newly_converged] = iteration 204 205 iteration += 1 206 207 208 def predict(self, X): 209 """ 210 Predict the component assignment for the given tensor data. 211 212 Parameters 213 ---------- 214 X : torch.tensor 215 A tensor with shape (Batch, N-points, Dimensions) 216 217 Returns 218 ---------- 219 torch.tensor 220 tensor of shape (B, N) with component ids as values. 221 """ 222 if X.dtype == self._dtype: 223 X = X.to(self._dtype) 224 if X.device.type != self._device: 225 X = X.to(self._device) 226 B, N, D = X.shape 227 probs = torch.zeros(B, N, self._n_components, device=X.device) 228 for k in range(self._n_components): 229 probs[:, :, k] = _estimate_gaussian_prob(X, 230 self.means[k], 231 self._precisions_cholesky[k], 232 self._dtype) 233 return probs.argmax(2).cpu() 234 235 236 def score_samples(self, X): 237 """ 238 Compute the log-likelihood of each point across all instances in the batch. 239 240 Parameters 241 ---------- 242 X : torch.tensor 243 A tensor with shape (Batch, N-points, Dimensions) 244 245 Returns 246 ---------- 247 torch.tensor 248 tensor of shape (B, N) with the score for each point in the batch. 249 """ 250 if X.device.type != self._device: 251 X = X.to(self._device) 252 X = X.to(self._dtype) 253 B, N, D = X.shape 254 log_probs = torch.zeros(B, N, self._n_components, device=X.device) 255 for k in range(self._n_components): 256 # Calculate weighted log probabilities 257 log_probs[:, :, k] = torch.add( 258 self._pi[k].log().unsqueeze(1), 259 _estimate_gaussian_prob(X, 260 self.means[k], 261 self._precisions_cholesky[k], 262 self._dtype).log() 263 ) 264 return log_probs.logsumexp(2).cpu() 265 266 267 def score(self, X): 268 """ 269 Compute the per-sample average log-likelihood of each instance in the batch. 270 271 Parameters 272 ---------- 273 X : torch.tensor 274 A tensor with shape (Batch, N-points, Dimensions) 275 276 Returns 277 ---------- 278 torch.tensor 279 tensor of shape (B,) with the log-likelihood for each instance in the batch. 280 """ 281 return self.score_samples(X).nanmean(1).cpu() 282 283 284 def bic(self, X): 285 """ 286 Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X. 287 288 Parameters 289 ---------- 290 X : torch.tensor 291 A tensor with shape (Batch, N-points, Dimensions) 292 293 Returns 294 ---------- 295 torch.tensor 296 tensor of shape (B,) with the BIC value for each instance in the Batch. 297 """ 298 scores = self.score(X) 299 valid_points = (~X.isnan()).all(2).sum(1) 300 return -2 * scores * valid_points + self.n_parameters() * np.log(valid_points) 301 302 303 def n_parameters(self): 304 """ 305 Returns the number of free parameters in the model for a single instance of the batch. 
306 307 Returns 308 ---------- 309 int 310 number of parameters in the model 311 """ 312 n_features = self.means[0].shape[1] 313 cov_params = self._n_components * n_features * (n_features + 1) / 2.0 314 mean_params = n_features * self._n_components 315 return int(cov_params + mean_params + self._n_components - 1) 316 317 318 def _init_clusters(self, X): 319 """ 320 Init the assignment component (cluster) assignment for B sets of N D-dimensional points. 321 322 Parameters 323 ---------- 324 X : torch.tensor 325 A tensor with shape (Batch, N-points, Dimensions) 326 """ 327 # If the assignment produced by kmeans has a component 328 # with less than two points, we rerun it to get a different 329 # assignment (up to 3 times). Having less than 2 points leads 330 # to bad covariance matrices that produce errors when trying 331 # to decompose/invert them. 332 retries = 0 333 while retries < 3: 334 seed = self._rand_seed + retries if self._rand_seed else None 335 _, assignment = self._kmeans(X, 336 self._n_components, 337 random_seed=seed) 338 _, counts = assignment.unique(return_counts=True) 339 if not torch.any(counts <= 2): 340 return assignment 341 retries += 1 342 return assignment 343 344 345 def _init_parameters(self, X): 346 B, N, D = X.shape 347 self.means = [torch.empty(B, D, dtype=self._dtype, device=self._device) 348 for _ in range(self._n_components)] 349 self.covs = [torch.empty(B, D, D, dtype=self._dtype, device=self._device) 350 for _ in range(self._n_components)] 351 self._precisions_cholesky = [torch.empty(B, D, D, 352 dtype=self._dtype, 353 device=self._device) 354 for _ in range(self._n_components)] 355 self._pi = [torch.empty(B, dtype=self._dtype, device=self._device) 356 for _ in range(self._n_components)] 357 358 359 def _m_step(self, X, r, N_actual, N_actual_total, converged): 360 B, N, D = X.shape 361 # We update the means, covariances and weights 362 # for all instances in the batch that still 363 # have not converged. 364 for k in range(self._n_components): 365 self.means[k][~converged] = torch.div( 366 # the nominator is sum(r*X) 367 (r[~converged, :, k].unsqueeze(2) * X[~converged]).nansum(1), 368 # the denominator is normalizing by the number of valid points 369 N_actual[~converged, k].unsqueeze(1)) 370 371 self.covs[k][~converged] = self._get_covs(X[~converged], 372 self.means[k][~converged], 373 r[~converged, :, k], 374 N_actual[~converged, k]) 375 376 # We need to calculate the Cholesky decompositions of 377 # the precision matrices (the precision is the inverse 378 # of the covariance). However, due to numerical errors 379 # the covariance may lose its positive-definite property 380 # (which mathematically is guarenteed to have). Whenever 381 # that happens, we can no longer calculate the Cholesky 382 # decomposition. As a workaround, we substitute the cov 383 # matrix with a near covariance matrix that is positive 384 # definite. 385 covs_cholesky, errors = torch.linalg.cholesky_ex(self.covs[k][~converged]) 386 bad_covs = errors > 0 387 if bad_covs.any(): 388 eigvals, eigvecs = torch.linalg.eigh(self.covs[k][~converged][bad_covs]) 389 # Theoretically, we should be able to use much smaller 390 # min value here, but for some reason smaller ones sometimes 391 # fail to force the covariance matrix to be positive-definite. 
392 new_eigvals = torch.clamp(eigvals, min=1e-5) 393 new_covs = eigvecs @ torch.diag_embed(new_eigvals) @ eigvecs.transpose(-1, -2) 394 self.covs[k][~converged][bad_covs] = new_covs 395 covs_cholesky[bad_covs] = torch.linalg.cholesky(new_covs) 396 self._precisions_cholesky[k][~converged] = self._get_precisions_cholesky(covs_cholesky) 397 398 self._pi[k][~converged] = N_actual[~converged, k]/N_actual_total[~converged] 399 400 401 def _kmeans(self, X, n_clusters=2, max_iter=20, tol=0.001, random_seed=None): 402 """ 403 Clusters the points in each instance of the batch using k-means. 404 Points with nan values are assigned with value -1. 405 406 Parameters 407 ---------- 408 X : torch.tensor 409 A tensor with shape (Batch, N-points, Dimensions) 410 n_clusters : int 411 Number of clusters to find. 412 max_iter : int 413 Maximum number of iterations to perform. 414 tol : float 415 The convergence threshold. 416 """ 417 B, N, D = X.shape 418 C = n_clusters 419 valid_points = ~X.isnan().any(2) 420 centers = self._kmeans_pp(X, C, valid_points, random_seed=random_seed) 421 distances = torch.empty(B, N, C, device=self._device) 422 i = 0 423 diff = np.inf 424 while i < max_iter and diff > tol: 425 # Calculate the distance between each point and cluster centers 426 for c in range(C): 427 distances[:, :, c] = ((X - centers[:, c, :].unsqueeze(1)) ** 2).sum(2) ** 0.5 428 # Assign each point to the cluster with closest center 429 assignment = distances.argmin(2) 430 # Recalculate cluster centers 431 new_centers = torch.empty(B, C, D, device=self._device) 432 for c in range(C): 433 cluster_mask = (assignment == c).unsqueeze(2).repeat(1, 1, D) 434 new_centers[:, c, :] = torch.where(cluster_mask, X, np.nan).nanmedian(1).values 435 # Estimate how much change we get in the centers 436 diff = (new_centers - centers).mean(1).max() 437 centers = new_centers 438 i += 1 439 for c in range(C): 440 distances[:, :, c] = ((X - centers[:, c, :].unsqueeze(1)) ** 2).sum(2) ** 0.5 441 # Assign each point to the cluster with closest center. 442 # Invalid points are assigned -1. 443 assignment = torch.where(valid_points, distances.argmin(2), -1) 444 return centers, assignment 445 446 447 def _kmeans_pp(self, X, C, valid_points, random_seed=None): 448 valid_mask = valid_points.clone() 449 B, N, D = X.shape 450 centers = torch.empty(B, C, D, device=self._device) 451 indices = torch.arange(0, N, device=self._device) 452 rand = np.random.default_rng(random_seed) 453 std = self._nanstd(X) 454 for b in range(B): 455 point_index = rand.choice(indices[valid_mask[b]].cpu().numpy()) 456 point = X[b, point_index, :] 457 point_len = (X[b, point_index, :].pow(2).sum(-1) ** 0.5) 458 centers[b, 0, :] = std[b]*point/point_len 459 valid_mask[b, point_index] = False 460 for k in range(1, C): 461 prev_center = centers[b, k-1, :] 462 distances = ((X[b, valid_mask[b], :] - prev_center) ** 2).sum(-1) ** 0.5 463 # By default kmeans++ takes as the next center the 464 # point that is furthest away. However, if there are 465 # outliers, they're likely to be selected, so here we 466 # ignore the top 10% of the most distant points. 467 max_dist_index = torch.argsort(distances)[math.floor(0.9*distances.shape[0])] 468 469 # The distances are calculated on a subset of the points 470 # so we need to convert the index of the furthest point 471 # to the index of the point in the whole dataset. 
472 point_index = indices[valid_mask[b]][max_dist_index] 473 474 # The standard kmeans++ algorithm selects an initial 475 # point at random for the first centroid and then for 476 # each cluster selects the point that is furthest away 477 # from the previous one. This is prone to selecting 478 # outliers that are very far away from all other points, 479 # leading to clusters with a single point. In the GMM 480 # fitting these clusters are problematic, because variance 481 # covariance metrics do not make sense anymore. 482 # To ameliorate this, I position the centroid at a point 483 # that pointing in the direction of the furthest point, 484 # but the length of the vector is equal to the standard 485 # deviation in the dataset. 486 centers[b, k, :] = std[b] * X[b, point_index, :] / distances[max_dist_index] 487 return centers 488 489 490 def _get_covs(self, X, means, r, nums): 491 B, N, D = X.shape 492 # C_k = (1/N_k) * sum(r_nk * (x - mu_k)(x - mu_k)^T) 493 diffs = X - means.unsqueeze(1) 494 summands = r.view(B, N, 1, 1) * torch.matmul(diffs.unsqueeze(3), diffs.unsqueeze(2)) 495 covs = summands.nansum(1) / nums.view(B, 1, 1).add(torch.finfo(self._dtype).eps) 496 return covs 497 498 499 def _get_precisions_cholesky(self, covs_cholesky): 500 B, D, D = covs_cholesky.shape 501 precisions_cholesky = torch.linalg.solve_triangular( 502 covs_cholesky, 503 torch.eye(D, device=self._device).unsqueeze(0).repeat(B, 1, 1), 504 upper=False, 505 left=True).permute(0, 2, 1) 506 return precisions_cholesky.to(self._dtype) 507 508 509 def _nanstd(self, X): 510 valid = torch.sum(~X.isnan().any(2), 1) 511 return (((X - X.nanmean(1).unsqueeze(1)) ** 2).nansum(1) / valid.unsqueeze(1)) ** 0.5 512 513 514def _estimate_gaussian_prob(X, mean, precisions_chol, dtype): 515 """ 516 Compute the probability of a batch of points X under 517 a batch of multivariate normal distributions. 518 519 Parameters 520 ---------- 521 X : torch.tensor 522 A tensor with shape (Batch, N-points, Dimensions). 523 Represents a batch of points. 524 mean : torch.tensor 525 The means of the distributions. Shape: (B, D) 526 precisions_chol : torch.tensor 527 Cholesky decompositions of the precisions matrices. Shape: (B, D, D) 528 dtype : torch.dtype 529 Data type of the result 530 531 Returns 532 ---------- 533 torch.tensor 534 tensor of shape (B, N) with probabilities 535 """ 536 B, N, D = X.shape 537 y = torch.bmm(X, precisions_chol) - torch.bmm(mean.unsqueeze(1), precisions_chol) 538 log_prob = y.pow(2).sum(2) 539 log_det = torch.diagonal(precisions_chol, dim1=1, dim2=2).log().sum(1) 540 return torch.exp( 541 -0.5 * (D * np.log(2 * np.pi) + log_prob) + log_det.unsqueeze(1))
class GMM
GMM(n_components, max_iter=100, device='cuda', tol=0.001, reg_covar=1e-6, means_init=None, weights_init=None, precisions_init=None, dtype=torch.float32, random_seed=None)
Initialize a Gaussian Mixture Model instance to fit.
Parameters
- n_components (int): Number of components (gaussians) in the model.
- max_iter (int): Maximum number of EM iterations to perform.
- device (torch.device): Which device to be used for the computations during the fitting (e.g. 'cpu', 'cuda', 'cuda:0').
- tol (float): The convergence threshold.
- reg_covar (float): Non-negative regularization added to the diagonal of covariance. Ensures that the covariance matrices are all positive.
- means_init (torch.tensor): User-provided initial means for all instances. The tensor should have shape (Batch, Components, Dimensions). If None (default), the means are initialized with a modified kmeans++ and then refined with kmeans.
- weights_init (torch.tensor): The user-provided initial weights. The tensor should have shape (Batch, Components). If it is None, weights are initialized from the kmeans++ & kmeans initialization.
- precisions_init (torch.tensor): The user-provided initial precisions (inverses of the covariance matrices). The tensor should have shape (Batch, Components, Dimensions, Dimensions). If it is None, precisions are initialized from the kmeans++ & kmeans initialization.
- dtype (torch.dtype): Data type that will be used in the GMM instance.
- random_seed (int): Controls the random seed that will be used when initializing the model parameters.
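If you already have good guesses for some of the parameters, they can be passed at construction time; the remaining ones are still estimated from the data. A minimal sketch, reusing the shapes from the example above (the seed value is arbitrary):
>>> means = torch.zeros(1000, 2, 2)  # (Batch, Components, Dimensions)
>>> means[:, 1, :] = 1.5             # place the second component near (1.5, 1.5)
>>> gmm = GMM(n_components=2, means_init=means, random_seed=42)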
fit(self, X)
Fit the GMM on the given tensor data.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
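The batch can also mix instances with different numbers of points: points whose values are NaN are treated as missing, so shorter instances can be NaN-padded to a common length. A minimal sketch with two toy instances of 150 and 200 points:
>>> a = torch.randn(150, 2)               # instance with 150 points
>>> b = torch.randn(200, 2)               # instance with 200 points
>>> pad = torch.full((50, 2), torch.nan)  # pad the shorter instance
>>> X = torch.stack([torch.cat([a, pad]), b])  # shape (2, 200, 2)
>>> gmm = GMM(n_components=2)
>>> gmm.fit(X)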
predict(self, X)
Predict the component assignment for the given tensor data.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
Returns
- torch.tensor: tensor of shape (B, N) with component ids as values.
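Continuing the module example, the returned tensor is on the CPU and can be used directly for masking or plotting:
>>> labels = gmm.predict(X)
>>> labels.shape
torch.Size([1000, 200])
>>> first_cluster = X[0][labels[0] == 0]  # points of instance 0 assigned to component 0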
score_samples(self, X)
Compute the log-likelihood of each point across all instances in the batch.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
Returns
- torch.tensor: tensor of shape (B, N) with the score for each point in the batch.
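The per-point scores make it easy to flag unlikely points, e.g. as potential outliers. A sketch reusing the fitted model from above (the 0.01 quantile is an arbitrary cutoff for illustration):
>>> log_lik = gmm.score_samples(X)      # shape (1000, 200)
>>> threshold = log_lik.quantile(0.01)  # arbitrary cutoff
>>> outliers = log_lik < threshold      # boolean mask of unlikely points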
score(self, X)
Compute the per-sample average log-likelihood of each instance in the batch.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
Returns
- torch.tensor: tensor of shape (B,) with the log-likelihood for each instance in the batch.
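score is the per-instance average of score_samples over the valid points, so the two are consistent:
>>> per_point = gmm.score_samples(X)  # (1000, 200)
>>> per_instance = gmm.score(X)       # (1000,)
>>> torch.allclose(per_instance, per_point.nanmean(1))
True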
bic(self, X)
Calculate the BIC (Bayesian Information Criterion) for the model on the dataset X.
Parameters
- X (torch.tensor): A tensor with shape (Batch, N-points, Dimensions)
Returns
- torch.tensor: tensor of shape (B,) with the BIC value for each instance in the Batch.
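Since BIC penalizes the number of parameters, it can be used to choose the number of components per instance; lower is better. A sketch (the candidate range 1-3 is arbitrary):
>>> bics = []
>>> for k in [1, 2, 3]:
...     gmm_k = GMM(n_components=k)
...     gmm_k.fit(X)
...     bics.append(gmm_k.bic(X))
>>> best_k = torch.stack(bics).argmin(0) + 1  # per-instance best component count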
n_parameters(self)
Returns the number of free parameters in the model for a single instance of the batch.
Returns
- int: number of parameters in the model
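The count is K*D*(D+1)/2 covariance parameters, K*D mean parameters, and K-1 free weights. For the two-component, two-dimensional model from the module example that is 6 + 4 + 1:
>>> gmm.n_parameters()
11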