gmm_gpu.gmm

Provides a GMM class for fitting multiple instances of Gaussian Mixture Models.

This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. You can create a single large 3D tensor (three-dimensional matrix) with the data for all your instances (i.e. a batch) and then send the tensor to the GPU and process the whole batch in parallel. This would work best if all the instances have roughly the same number of points.
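
If your instances have different numbers of points, one way to still batch them together (a sketch, not part of the original documentation) is to pad the shorter instances with NaN values, which the fitting code below treats as missing points:

>>> import torch
>>> a = torch.randn(120, 2)                      # instance with 120 points
>>> b = torch.randn(80, 2)                       # instance with 80 points
>>> n = max(a.shape[0], b.shape[0])
>>> batch = torch.full((2, n, 2), float('nan'))  # (Batch, N-points, Dimensions)
>>> batch[0, :a.shape[0]] = a
>>> batch[1, :b.shape[0]] = b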

If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe Pomegranate would be a better option.

Example usage:

Import PyTorch and the GMM class

>>> from gmm_gpu.gmm import GMM
>>> import torch

Generate some test data: We create a batch of 1000 instances, each with 200 random points. Half of the points are sampled from a distribution centered at the origin (0, 0) and the other half from a distribution centered at (1.5, 1.5).

>>> X1 = torch.randn(1000, 100, 2)
>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
>>> X = torch.cat([X1, X2], dim=1)

Fit the model

>>> gmm = GMM(n_components=2, device='cuda')
>>> gmm.fit(X)

Predict the components: This will return a matrix with shape (1000, 200) where each value is the predicted component for the point.

>>> gmm.predict(X)
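
Beyond hard assignments, the fitted model exposes soft assignments and likelihood-based scores through its predict_proba, score_samples, score and bic methods (all defined in the source below). Continuing the session above, the comments show the expected result shapes for this batch:

>>> resp = gmm.predict_proba(X)  # (1000, 200, 2) per-point component probabilities
>>> ll = gmm.score_samples(X)    # (1000, 200) per-point log-likelihood
>>> avg_ll = gmm.score(X)        # (1000,) per-instance average log-likelihood
>>> bic = gmm.bic(X)             # (1000,) per-instance BIC value
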
  1"""
  2Provides a GMM class for fitting multiple instances of `Gaussian Mixture Models <https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model>`_.
  3
  4This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one.
  5You can create a single large 3D tensor (three dimensional matrix) with the data for all your instances (i.e. a batch) and then
  6send the tensor to the GPU and process the whole batch in parallel. This would work best if all the instances have roughly the same number of points.
  7
  8If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe `Pomegranate <https://github.com/jmschrei/pomegranate>`_ would be a better option.
  9
 10### Example usage:
 11Import pytorch and the GMM class
 12>>> from gmm_gpu.gmm import GMM
 13>>> import torch
 14
 15Generate some test data:
 16We create a batch of 1000 instances, each
 17with 200 random points. Half of the points
 18are sampled from distribution centered at
 19the origin (0, 0) and the other half from
 20a distribution centered at (1.5, 1.5).
 21>>> X1 = torch.randn(1000, 100, 2)
 22>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
 23>>> X = torch.cat([X1, X2], dim=1)
 24
 25Fit the model
 26>>> gmm = GMM(n_components=2, device='cuda')
 27>>> gmm.fit(X)
 28
 29Predict the components:
 30This will return a matrix with shape (1000, 200) where
 31each value is the predicted component for the point.
 32>>> gmm.predict(X)
 33"""

import torch
import numpy as np


class GMM:
    def __init__(self,
                 n_components,
                 max_iter=100,
                 device='cuda',
                 tol=0.001,
                 reg_covar=1e-6,
                 means_init=None,
                 weights_init=None,
                 precisions_init=None,
                 dtype=torch.float32,
                 random_seed=None):
        """
        Initialize a Gaussian Mixture Model instance to fit.

        Parameters
        ----------
        n_components : int
            Number of components (gaussians) in the model.
        max_iter : int
            Maximum number of EM iterations to perform.
        device : torch.device
            Which device to be used for the computations
            during the fitting (e.g. `'cpu'`, `'cuda'`, `'cuda:0'`).
        tol : float
            The convergence threshold.
        reg_covar : float
            Non-negative regularization added to the diagonal of the
            covariance matrices. Helps ensure that they are all positive definite.
        means_init : torch.tensor
            User-provided initialization means for all instances. The
            tensor should have shape (Batch, Components, Dimensions).
            If None (default), the means are initialized
            with a modified kmeans++ and then refined with kmeans.
        weights_init : torch.tensor
            The user-provided initial weights. The tensor should have shape
            (Batch, Components). If it is None, the weights are derived
            from the kmeans++ & kmeans initialization.
        precisions_init : torch.tensor
            The user-provided initial precisions (inverses of the covariance matrices).
            The tensor should have shape (Batch, Components, Dimensions, Dimensions).
            If it is None, the precisions are derived from the kmeans++ & kmeans
            initialization.
        dtype : torch.dtype
            Data type that will be used in the GMM instance.
        random_seed : int
            Controls the random seed that will be used
            when initializing the model parameters.
        """
        self._n_components = n_components
        self._max_iter = max_iter
        self._device = device
        self._tol = tol
        self._reg_covar = reg_covar
        self._means_init = means_init
        self._weights_init = weights_init
        self._precisions_init = precisions_init
        self._dtype = dtype
        self._rand_generator = torch.Generator(device=device)
        # Compare against None (rather than relying on truthiness)
        # so that a random seed of 0 is also honored.
        if random_seed is not None:
            self._rand_seed = random_seed
            self._rand_generator.manual_seed(random_seed)
        else:
            self._rand_seed = None


    def fit(self, X):
        """
        Fit the GMM on the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        """
        X = X.to(device=self._device, dtype=self._dtype)

        B, N, D = X.shape

        self._init_parameters(X)
        component_mask = self._init_clusters(X)

        r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            r[:, :, k][component_mask == k] = 1

        # This gives us the number of points per component
        # for each instance in the batch. It's necessary
        # in order to handle missing points (with nan values).
        N_actual = r.nansum(1)
        N_actual_total = N_actual.sum(1)

        converged = torch.full((B,), False, device=self._device)

        # If at least one of the parameters is missing
        # we calculate all parameters with the M-step.
        if (self._means_init is None or
            self._weights_init is None or
            self._precisions_init is None):
            self._m_step(X, r, N_actual, N_actual_total, converged)

        # If any of the parameters have been provided by the
        # user, we overwrite it with the provided value.
        if self._means_init is not None:
            self.means = [self._means_init[:, c, :]
                          for c in range(self._n_components)]
        if self._weights_init is not None:
            self._pi = [self._weights_init[:, c]
                        for c in range(self._n_components)]
        if self._precisions_init is not None:
            self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :])
                                         for c in range(self._n_components)]

        self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device)
        mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device)

        iteration = 1
        while iteration <= self._max_iter and not converged.all():
            prev_mean_log_prob_norm = mean_log_prob_norm.clone()

            # === E-STEP ===

            # The responsibilities are computed in log space (weighted
            # log-densities normalized with logsumexp) for numerical
            # stability, and exponentiated at the end.
            for k in range(self._n_components):
                r[~converged, :, k] = torch.add(
                        _estimate_gaussian_prob(
                            X[~converged],
                            self.means[k][~converged],
                            self._precisions_cholesky[k][~converged],
                            self._dtype).log(),
                        self._pi[k][~converged].unsqueeze(1).log()
                    )
            log_prob_norm = r[~converged].logsumexp(2)
            r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp()
            mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1)
            N_actual = r.nansum(1)

            # If we have fewer than 2 points in a component it produces
            # bad covariance matrices. Hence, we stop the iterations
            # for the affected instances and continue with the rest.
            unprocessable_instances = (N_actual < 2).any(1)
            converged[unprocessable_instances] = True

            # === M-STEP ===

            self._m_step(X, r, N_actual, N_actual_total, converged)

            change = mean_log_prob_norm - prev_mean_log_prob_norm

            # If the change for some instances in the batch
            # is small enough, we mark those instances as
            # converged and do not process them anymore.
            small_change = change.abs() < self._tol
            newly_converged = small_change & ~converged
            converged[newly_converged] = True
            self.convergence_iters[newly_converged] = iteration

            iteration += 1


    def predict_proba(self, X, force_cpu_result=True):
        """
        Estimate the components' density for all samples
        in all instances.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N, n_components) with probabilities.
            The values at positions [I, S, :] will be the probabilities
            of sample S in instance I to be assigned to each component.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        log_probs = torch.zeros(B, N, self._n_components,
                                device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            # Calculate weighted log probabilities
            log_probs[:, :, k] = torch.add(
                    self._pi[k].log().unsqueeze(1),
                    _estimate_gaussian_prob(X,
                                            self.means[k],
                                            self._precisions_cholesky[k],
                                            self._dtype).log())
        log_prob_norm = log_probs.logsumexp(2)
        log_resp = log_probs - log_prob_norm.unsqueeze(2)

        if force_cpu_result:
            return log_resp.exp().cpu()
        return log_resp.exp()


    def predict(self, X, force_cpu_result=True):
        """
        Predict the component assignment for the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with component ids as values.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        probs = torch.zeros(B, N, self._n_components,
                            device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            probs[:, :, k] = _estimate_gaussian_prob(X,
                                                     self.means[k],
                                                     self._precisions_cholesky[k],
                                                     self._dtype)
        # Missing (nan) points get a nan component id.
        result = torch.where(probs.isnan().any(2), np.nan, probs.argmax(2))
        if force_cpu_result:
            return result.cpu()
        return result


    def score_samples(self, X, force_cpu_result=True):
        """
        Compute the log-likelihood of each point across all instances in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with the score for each point in the batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        log_probs = torch.zeros(B, N, self._n_components,
                                device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            # Calculate weighted log probabilities
            log_probs[:, :, k] = torch.add(
                    self._pi[k].log().unsqueeze(1),
                    _estimate_gaussian_prob(X,
                                            self.means[k],
                                            self._precisions_cholesky[k],
                                            self._dtype).log())
        if force_cpu_result:
            return log_probs.logsumexp(2).cpu()
        return log_probs.logsumexp(2)


    def score(self, X, force_cpu_result=True):
        """
        Compute the per-sample average log-likelihood of each instance in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the log-likelihood for each instance in the batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        # Average on the computation device and move the result
        # to the CPU only at the end, if requested.
        per_instance = self.score_samples(X, force_cpu_result=False).nanmean(1)
        if force_cpu_result:
            return per_instance.cpu()
        return per_instance


    def bic(self, X, force_cpu_result=True):
        """
        Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the BIC value for each instance in the Batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        scores = self.score(X, force_cpu_result=False)
        valid_points = (~X.isnan()).all(2).sum(1)
        result = -2 * scores * valid_points + self.n_parameters() * valid_points.log()
        if force_cpu_result:
            return result.cpu()
        return result


    def n_parameters(self):
        """
        Returns the number of free parameters in the model for a single instance of the batch.

        Returns
        ----------
        int
            number of parameters in the model
        """
        n_features = self.means[0].shape[1]
        cov_params = self._n_components * n_features * (n_features + 1) / 2.0
        mean_params = n_features * self._n_components
        return int(cov_params + mean_params + self._n_components - 1)


    def _init_clusters(self, X):
        """
        Initialize the component (cluster) assignment for B sets of N D-dimensional points.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        """
        # If the assignment produced by kmeans has a component
        # with two or fewer points, we rerun it to get a different
        # assignment (up to 3 times). Having too few points leads
        # to bad covariance matrices that produce errors when trying
        # to decompose/invert them.
        retries = 0
        while retries < 3:
            _, assignment = self._kmeans(X, self._n_components)
            _, counts = assignment.unique(return_counts=True)
            if not torch.any(counts <= 2):
                return assignment
            retries += 1
        return assignment


    def _init_parameters(self, X):
        B, N, D = X.shape
        self.means = [torch.empty(B, D, dtype=self._dtype, device=self._device)
                      for _ in range(self._n_components)]
        self.covs = [torch.empty(B, D, D, dtype=self._dtype, device=self._device)
                     for _ in range(self._n_components)]
        self._precisions_cholesky = [torch.empty(B, D, D,
                                                 dtype=self._dtype,
                                                 device=self._device)
                                     for _ in range(self._n_components)]
        self._pi = [torch.empty(B, dtype=self._dtype, device=self._device)
                    for _ in range(self._n_components)]


    def _m_step(self, X, r, N_actual, N_actual_total, converged):
        B, N, D = X.shape
        # We update the means, covariances and weights
        # for all instances in the batch that still
        # have not converged.
        for k in range(self._n_components):
            self.means[k][~converged] = torch.div(
                    # the numerator is sum(r*X)
                    (r[~converged, :, k].unsqueeze(2) * X[~converged]).nansum(1),
                    # the denominator normalizes by the number of valid points
                    N_actual[~converged, k].unsqueeze(1))

            self.covs[k][~converged] = self._get_covs(X[~converged],
                                                      self.means[k][~converged],
                                                      r[~converged, :, k],
                                                      N_actual[~converged, k])

            # We need to calculate the Cholesky decompositions of
            # the precision matrices (the precision is the inverse
            # of the covariance). However, due to numerical errors
            # the covariance may lose its positive-definite property
            # (which it is mathematically guaranteed to have). Whenever
            # that happens, we can no longer calculate the Cholesky
            # decomposition. As a workaround, we substitute the cov
            # matrix with a nearby covariance matrix that is positive
            # definite.
            covs_cholesky, errors = torch.linalg.cholesky_ex(self.covs[k][~converged])
            bad_covs = errors > 0
            if bad_covs.any():
                eigvals, eigvecs = torch.linalg.eigh(self.covs[k][~converged][bad_covs])
                # Theoretically, we should be able to use a much smaller
                # min value here, but for some reason smaller ones sometimes
                # fail to force the covariance matrix to be positive-definite.
                new_eigvals = torch.clamp(eigvals, min=1e-5)
                new_covs = eigvecs @ torch.diag_embed(new_eigvals) @ eigvecs.transpose(-1, -2)
                # Chained advanced indexing (covs[k][~converged][bad_covs] = ...)
                # would write into a temporary copy, so we go through an
                # intermediate tensor to make the update stick.
                covs_k = self.covs[k][~converged]
                covs_k[bad_covs] = new_covs
                self.covs[k][~converged] = covs_k
                covs_cholesky[bad_covs] = torch.linalg.cholesky(new_covs)
            self._precisions_cholesky[k][~converged] = self._get_precisions_cholesky(covs_cholesky)

            self._pi[k][~converged] = N_actual[~converged, k]/N_actual_total[~converged]


    def _kmeans(self, X, n_clusters=2, max_iter=10, tol=0.001):
        """
        Clusters the points in each instance of the batch using k-means.
        Points with nan values are assigned the value -1.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        n_clusters : int
            Number of clusters to find.
        max_iter : int
            Maximum number of iterations to perform.
        tol : float
            The convergence threshold.
        """
        B, N, D = X.shape
        C = n_clusters
        valid_points = ~X.isnan().any(dim=2)
        invalid_points_count = (~valid_points).sum(1)
        centers = self._kmeans_pp(X, C, valid_points)

        i = 0
        diff = np.inf
        while i < max_iter and diff > tol:
            # Calculate the squared distance between each point and the cluster centers
            distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1)
            assignment = distances.argmin(dim=2)

            # Compute the new cluster centers
            cluster_sums = torch.zeros_like(centers)
            cluster_counts = torch.zeros((B, C, 1), dtype=X.dtype, device=X.device)
            # The nans are assigned to the first cluster. We want to ignore them.
            # Hence, we use nan_to_num() to replace them with 0s and then we subtract
            # the number of invalid points from the counts for the first cluster.
            cluster_sums.scatter_add_(1, assignment.unsqueeze(-1).expand(-1, -1, D), X.nan_to_num())
            cluster_counts.scatter_add_(1, assignment.unsqueeze(-1), torch.ones_like(X[:, :, :1]))
            cluster_counts[:, 0] -= invalid_points_count.unsqueeze(1)
            new_centers = cluster_sums / cluster_counts.clamp_min(1e-8)

            # Estimate how much the centers changed
            diff = torch.norm(new_centers - centers, dim=(1, 2)).max()

            centers = new_centers.nan_to_num()
            i += 1

        # Final assignment with the updated centers
        distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1)
        assignment = torch.where(valid_points, distances.argmin(dim=2), -1)

        return centers, assignment


    def _select_random_valid_points(self, X, valid_mask):
        B, N, D = X.shape

        _, point_idx = valid_mask.nonzero(as_tuple=True)
        counts = valid_mask.sum(1)

        # Select a random valid index for each instance.
        # This is efficient, but quite tricky:
        # nonzero(as_tuple=True) returns a list of the batch indices and corresponding
        # point indices of valid points. For each instance in the batch, we get a
        # random integer between 0 and the maximum possible number of valid points.
        # To make sure that the selected integer is not larger than the number of
        # valid points for each instance we mod that integer by counts.
        # This basically gives us a random offset to select a point from a list
        # of valid points for a given batch index.
        rand_offsets = torch.randint(0, counts.max(), (B,),
                                     generator=self._rand_generator,
                                     device=X.device) % counts

        # Here, cumsum(counts)-counts gives us the starting position of each instance in the batch
        # in point_idx. E.g. if we have a batch of 3 instances with [5, 7, 3] valid points respectively,
        # we would get batch_starts = [0, 5, 12].
        batch_starts = torch.cumsum(counts, dim=0) - counts
        chosen_indices = point_idx[batch_starts + rand_offsets]

        selected_points = X[torch.arange(B, device=X.device), chosen_indices]
        return selected_points


    def _kmeans_pp(self, X, C, valid_points):
        B, N, D = X.shape
        device = X.device
        std = self._nanstd(X)
        centers = torch.empty(B, C, D, dtype=X.dtype, device=device)

        # Randomly select the first center for each batch
        rand_points = self._select_random_valid_points(X, valid_points)
        centers[:, 0, :] = std * rand_points / rand_points.norm(dim=-1, keepdim=True)

        # Each subsequent center is calculated to be distant
        # from the previous one
        for k in range(1, C):
            prev_centers = centers[:, k - 1, :].unsqueeze(1)
            distances = (X - prev_centers).norm(dim=-1)

            # By default kmeans++ takes as the next center the
            # point that is furthest away. However, if there are
            # outliers, they're likely to be selected, so here we
            # ignore the top 10% of the most distant points.
            threshold_idx = int(0.9 * N)
            sorted_distances, sorted_indices = distances.sort(1)

            # The standard kmeans++ algorithm selects an initial
            # point at random for the first centroid and then for
            # each cluster selects the point that is furthest away
            # from the previous one. This is prone to selecting
            # outliers that are very far away from all other points,
            # leading to clusters with a single point. In the GMM
            # fitting these clusters are problematic, because the
            # variance/covariance metrics do not make sense anymore.
            # To ameliorate this, we position the centroid at a point
            # that is in the direction of the furthest point,
            # but the length of the vector is equal to 150% of the
            # standard deviation in the dataset.
            # First, we get the most distant valid positions (after ignoring
            # the top 10%).
            max_valid_idx = _nanmax(sorted_distances[:, :threshold_idx], 1)[1]
            # Those are indices into the sorted order, not the original dataset.
            # We need to map them through sorted_indices to obtain the indices
            # of those points in the dataset X.
            orig_indices = sorted_indices[torch.arange(B, device=device), max_valid_idx]
            selected_points = X[torch.arange(B, device=device), orig_indices]
            # Once we have the actual points, we calculate the new centers.
            centers[:, k, :] = 1.5 * std * selected_points / selected_points.norm(dim=-1, keepdim=True)
        return centers


    def _get_covs(self, X, means, r, nums):
        B, N, D = X.shape
        # C_k = (1/N_k) * sum(r_nk * (x - mu_k)(x - mu_k)^T)
        diffs = X - means.unsqueeze(1)
        summands = r.view(B, N, 1, 1) * torch.matmul(diffs.unsqueeze(3), diffs.unsqueeze(2))
        covs = summands.nansum(1) / nums.view(B, 1, 1).add(torch.finfo(self._dtype).eps)
        # Apply the documented `reg_covar` regularization to the diagonal
        # to help keep the covariance matrices positive definite.
        covs.diagonal(dim1=-2, dim2=-1).add_(self._reg_covar)
        return covs


    def _get_precisions_cholesky(self, covs_cholesky):
        B, D, _ = covs_cholesky.shape
        precisions_cholesky = torch.linalg.solve_triangular(
                covs_cholesky,
                torch.eye(D, device=self._device).unsqueeze(0).repeat(B, 1, 1),
                upper=False,
                left=True).permute(0, 2, 1)
        return precisions_cholesky.to(self._dtype)


    def _nanstd(self, X):
        valid = torch.sum(~X.isnan().any(2), 1)
        return (((X - X.nanmean(1).unsqueeze(1)) ** 2).nansum(1) / valid.unsqueeze(1)) ** 0.5


def _nanmax(T, dim):
    """
    Compute the max along a given axis while ignoring NaNs.
    """
    nan_mask = T.isnan()
    T = torch.where(nan_mask, float('-inf'), T)
    max_values, indices = T.max(dim=dim)
    return max_values, indices


def _estimate_gaussian_prob(X, mean, precisions_chol, dtype):
    """
    Compute the probability of a batch of points X under
    a batch of multivariate normal distributions.

    Parameters
    ----------
    X : torch.tensor
        A tensor with shape (Batch, N-points, Dimensions).
        Represents a batch of points.
    mean : torch.tensor
        The means of the distributions. Shape: (B, D)
    precisions_chol : torch.tensor
        Cholesky decompositions of the precision matrices. Shape: (B, D, D)
    dtype : torch.dtype
        Data type of the result

    Returns
    ----------
    torch.tensor
        tensor of shape (B, N) with probability densities
    """
    B, N, D = X.shape
    # y holds (x - mu)^T P_chol for every point, where P_chol P_chol^T is the
    # precision matrix, so y.pow(2).sum(2) is the squared Mahalanobis distance
    # (x - mu)^T Sigma^-1 (x - mu), and log|P_chol| equals -0.5 * log|Sigma|.
    y = torch.bmm(X, precisions_chol) - torch.bmm(mean.unsqueeze(1), precisions_chol)
    log_prob = y.pow(2).sum(2)
    log_det = torch.diagonal(precisions_chol, dim1=1, dim2=2).log().sum(1)
    return torch.exp(
            -0.5 * (D * np.log(2 * np.pi) + log_prob) + log_det.unsqueeze(1))
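
Since bic() returns one value per instance, it can drive per-instance model selection. The following is an illustrative sketch (not part of the module's documented examples) that continues the session from the top of this page; best_k is a hypothetical variable name:

>>> bics = []
>>> for k in (1, 2, 3):
...     model = GMM(n_components=k, device='cuda', random_seed=0)
...     model.fit(X)
...     bics.append(model.bic(X))
>>> best_k = torch.stack(bics).argmin(0) + 1  # (1000,) selected component count per instance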
class GMM:
 40class GMM:
 41    def __init__(self,
 42                 n_components,
 43                 max_iter=100,
 44                 device='cuda',
 45                 tol=0.001,
 46                 reg_covar=1e-6,
 47                 means_init=None,
 48                 weights_init=None,
 49                 precisions_init=None,
 50                 dtype=torch.float32,
 51                 random_seed=None):
 52        """
 53        Initialize a Gaussian Mixture Models instance to fit.
 54
 55        Parameters
 56        ----------
 57        n_components : int
 58            Number of components (gaussians) in the model.
 59        max_iter : int
 60            Maximum number of EM iterations to perform.
 61        device : torch.device
 62            Which device to be used for the computations
 63            during the fitting (e.g `'cpu'`, `'cuda'`, `'cuda:0'`).
 64        tol : float
 65            The convergence threshold.
 66        reg_covar : float
 67            Non-negative regularization added to the diagonal of covariance.
 68            Allows to assure that the covariance matrices are all positive.
 69        means_init : torch.tensor
 70            User provided initialization means for all instances. The
 71            tensor should have shape (Batch, Components, Dimensions).
 72            If None (default) the means are going to be initialized
 73            with modified kmeans++ and then refined with kmeans.
 74        weights_init : torch.tensor
 75            The user-provided initial weights. The tensor should have shape
 76            (Batch, Components). If it is None, weights are initialized
 77            depending on the kmeans++ & kmeans initialization.
 78        precisions_init : torch.tensor
 79            The user-provided initial precisions (inverse of the covariance matrices).
 80            The tensor should have shape (Batch, Components, Dimension, Dimension).
 81            If it is None, precisions are initialized depending on the kmeans++ & kmeans
 82            initialization.
 83        dtype : torch.dtype
 84            Data type that will be used in the GMM instance.
 85        random_seed : int
 86            Controls the random seed that will be used
 87            when initializing the model parameters.
 88        """
 89        self._n_components = n_components
 90        self._max_iter = max_iter
 91        self._device = device
 92        self._tol = tol
 93        self._reg_covar = reg_covar
 94        self._means_init = means_init
 95        self._weights_init = weights_init
 96        self._precisions_init = precisions_init
 97        self._dtype = dtype
 98        self._rand_generator = torch.Generator(device=device)
 99        if random_seed:
100            self._rand_seed = random_seed
101            self._rand_generator.manual_seed(random_seed)
102        else:
103            self._rand_seed = None
104
105
106    def fit(self, X):
107        """
108        Fit the GMM on the given tensor data.
109
110        Parameters
111        ----------
112        X : torch.tensor
113            A tensor with shape (Batch, N-points, Dimensions)
114        """
115        X = X.to(device=self._device, dtype=self._dtype)
116
117        B, N, D = X.shape
118
119        self._init_parameters(X)
120        component_mask = self._init_clusters(X)
121
122        r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype)
123        for k in range(self._n_components):
124            r[:, :, k][component_mask == k] = 1
125
126        # This gives us the amount of points per component
127        # for each instance in the batch. It's necessary
128        # in order to handle missing points (with nan values).
129        N_actual = r.nansum(1)
130        N_actual_total = N_actual.sum(1)
131
132        converged = torch.full((B,), False, device=self._device)
133
134        # If at least one of the parameters is missing
135        # we calculate all parameters with the M-step.
136        if (self._means_init is None or
137            self._weights_init is None or
138            self._precisions_init is None):
139            self._m_step(X, r, N_actual, N_actual_total, converged)
140
141        # If any of the parameters have been provided by the
142        # user, we overwrite it with the provided value.
143        if self._means_init is not None:
144            self.means = [self._means_init[:, c, :]
145                          for c in range(self._n_components)]
146        if self._weights_init is not None:
147            self._pi = [self._weights_init[:, c]
148                        for c in range(self._n_components)]
149        if self._precisions_init is not None:
150            self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :])
151                                         for c in range(self._n_components)]
152
153        self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device)
154        mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device)
155
156        iteration = 1
157        while iteration <= self._max_iter and not converged.all():
158            prev_mean_log_prob_norm = mean_log_prob_norm.clone()
159
160            # === E-STEP ===
161
162            for k in range(self._n_components):
163                r[~converged, :, k] = torch.add(
164                        _estimate_gaussian_prob(
165                            X[~converged],
166                            self.means[k][~converged],
167                            self._precisions_cholesky[k][~converged],
168                            self._dtype).log(),
169                        self._pi[k][~converged].unsqueeze(1).log()
170                    )
171            log_prob_norm = r[~converged].logsumexp(2)
172            r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp()
173            mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1)
174            N_actual = r.nansum(1)
175
176            # If we have less than 2 points in a component it produces
177            # bad covariance matrices. Hence, we stop the iterations
178            # for the affected instances and continue with the rest.
179            unprocessable_instances = (N_actual < 2).any(1)
180            converged[unprocessable_instances] = True
181
182            # === M-STEP ===
183
184            self._m_step(X, r, N_actual, N_actual_total, converged)
185
186            change = mean_log_prob_norm - prev_mean_log_prob_norm
187
188            # If the change for some instances in the batch
189            # are small enough, we mark those instances as
190            # converged and do not process them anymore.
191            small_change = change.abs() < self._tol
192            newly_converged = small_change & ~converged
193            converged[newly_converged] = True
194            self.convergence_iters[newly_converged] = iteration
195
196            iteration += 1
197
198
199    def predict_proba(self, X, force_cpu_result=True):
200        """
201        Estimate the components' density for all samples
202        in all instances.
203
204        Parameters
205        ----------
206        X : torch.tensor
207            A tensor with shape (Batch, N-points, Dimensions)
208        force_cpu_result : bool
209            Make sure that the resulting tensor is loaded on
210            the CPU regardless of the device used for the
211            computations (default: True).
212
213        Returns
214        ----------
215        torch.tensor
216            tensor of shape (B, N, n_clusters) with probabilities.
217            The values at positions [I, S, :] will be the probabilities
218            of sample S in instance I to be assigned to each component.
219        """
220        X = X.to(device=self._device, dtype=self._dtype)
221        B, N, D = X.shape
222        log_probs = torch.zeros(B, N, self._n_components, device=X.device)
223        for k in range(self._n_components):
224            # Calculate weighted log probabilities
225            log_probs[:, :, k] = torch.add(
226                    self._pi[k].log().unsqueeze(1),
227                    _estimate_gaussian_prob(X,
228                                            self.means[k],
229                                            self._precisions_cholesky[k],
230                                            self._dtype).log())
231        log_prob_norm = log_probs.logsumexp(2)
232        log_resp = log_probs - log_prob_norm.unsqueeze(2)
233
234        if force_cpu_result:
235            return log_resp.exp().cpu()
236        return log_resp.exp()
237
238
239    def predict(self, X, force_cpu_result=True):
240        """
241        Predict the component assignment for the given tensor data.
242
243        Parameters
244        ----------
245        X : torch.tensor
246            A tensor with shape (Batch, N-points, Dimensions)
247        force_cpu_result : bool
248            Make sure that the resulting tensor is loaded on
249            the CPU regardless of the device used for the
250            computations (default: True).
251
252        Returns
253        ----------
254        torch.tensor
255            tensor of shape (B, N) with component ids as values.
256        """
257        X = X.to(device=self._device, dtype=self._dtype)
258        B, N, D = X.shape
259        probs = torch.zeros(B, N, self._n_components, device=X.device)
260        for k in range(self._n_components):
261            probs[:, :, k] = _estimate_gaussian_prob(X,
262                                                     self.means[k],
263                                                     self._precisions_cholesky[k],
264                                                     self._dtype)
265        if force_cpu_result:
266            torch.where(probs.isnan().any(2), np.nan, probs.argmax(2)).cpu()
267        return torch.where(probs.isnan().any(2), np.nan, probs.argmax(2))
268
269
270    def score_samples(self, X, force_cpu_result=True):
271        """
272        Compute the log-likelihood of each point across all instances in the batch.
273
274        Parameters
275        ----------
276        X : torch.tensor
277            A tensor with shape (Batch, N-points, Dimensions)
278        force_cpu_result : bool
279            Make sure that the resulting tensor is loaded on
280            the CPU regardless of the device used for the
281            computations (default: True).
282
283        Returns
284        ----------
285        torch.tensor
286            tensor of shape (B, N) with the score for each point in the batch.
287        """
288        X = X.to(device=self._device, dtype=self._dtype)
289        B, N, D = X.shape
290        log_probs = torch.zeros(B, N, self._n_components, device=X.device)
291        for k in range(self._n_components):
292            # Calculate weighted log probabilities
293            log_probs[:, :, k] = torch.add(
294                    self._pi[k].log().unsqueeze(1),
295                    _estimate_gaussian_prob(X,
296                                            self.means[k],
297                                            self._precisions_cholesky[k],
298                                            self._dtype).log())
299        if force_cpu_result:
300            return log_probs.logsumexp(2).cpu()
301        return log_probs.logsumexp(2)
302
303
304    def score(self, X, force_cpu_result=True):
305        """
306        Compute the per-sample average log-likelihood of each instance in the batch.
307
308        Parameters
309        ----------
310        X : torch.tensor
311            A tensor with shape (Batch, N-points, Dimensions)
312        force_cpu_result : bool
313            Make sure that the resulting tensor is loaded on
314            the CPU regardless of the device used for the
315            computations (default: True).
316
317        Returns
318        ----------
319        torch.tensor
320            tensor of shape (B,) with the log-likelihood for each instance in the batch.
321        """
322        X = X.to(device=self._device, dtype=self._dtype)
323        if force_cpu_result:
324            return self.score_samples(X).nanmean(1).cpu()
325        return self.score_samples(X, force_cpu_result=False).nanmean(1)
326
327
328    def bic(self, X, force_cpu_result=True):
329        """
330        Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X.
331
332        Parameters
333        ----------
334        X : torch.tensor
335            A tensor with shape (Batch, N-points, Dimensions)
336        force_cpu_result : bool
337            Make sure that the resulting tensor is loaded on
338            the CPU regardless of the device used for the
339            computations (default: True).
340
341        Returns
342        ----------
343        torch.tensor
344            tensor of shape (B,) with the BIC value for each instance in the Batch.
345        """
346        X = X.to(device=self._device, dtype=self._dtype)
347        scores = self.score(X, force_cpu_result=False)
348        valid_points = (~X.isnan()).all(2).sum(1)
349        result = -2 * scores * valid_points + self.n_parameters() * valid_points.log()
350        if force_cpu_result:
351            return result.cpu()
352        return result
353
354
355    def n_parameters(self):
356        """
357        Returns the number of free parameters in the model for a single instance of the batch.
358
359        Returns
360        ----------
361        int
362            number of parameters in the model
363        """
364        n_features = self.means[0].shape[1]
365        cov_params = self._n_components * n_features * (n_features + 1) / 2.0
366        mean_params = n_features * self._n_components
367        return int(cov_params + mean_params + self._n_components - 1)
368
369
370    def _init_clusters(self, X):
371        """
372        Init the assignment component (cluster) assignment for B sets of N D-dimensional points.
373
374        Parameters
375        ----------
376        X : torch.tensor
377            A tensor with shape (Batch, N-points, Dimensions)
378        """
379        # If the assignment produced by kmeans has a component
380        # with less than two points, we rerun it to get a different
381        # assignment (up to 3 times). Having less than 2 points leads
382        # to bad covariance matrices that produce errors when trying
383        # to decompose/invert them.
384        retries = 0
385        while retries < 3:
386            _, assignment = self._kmeans(X, self._n_components)
387            _, counts = assignment.unique(return_counts=True)
388            if not torch.any(counts <= 2):
389                return assignment
390            retries += 1
391        return assignment
392
393
394    def _init_parameters(self, X):
395        B, N, D = X.shape
396        self.means = [torch.empty(B, D, dtype=self._dtype, device=self._device)
397                      for _ in range(self._n_components)]
398        self.covs = [torch.empty(B, D, D, dtype=self._dtype, device=self._device)
399                     for _ in range(self._n_components)]
400        self._precisions_cholesky = [torch.empty(B, D, D,
401                                                 dtype=self._dtype,
402                                                 device=self._device)
403                                     for _ in range(self._n_components)]
404        self._pi = [torch.empty(B, dtype=self._dtype, device=self._device)
405                    for _ in range(self._n_components)]
406
407
408    def _m_step(self, X, r, N_actual, N_actual_total, converged):
409        B, N, D = X.shape
410        # We update the means, covariances and weights
411        # for all instances in the batch that still
412        # have not converged.
413        for k in range(self._n_components):
414            self.means[k][~converged] = torch.div(
415                    # the nominator is sum(r*X)
416                    (r[~converged, :, k].unsqueeze(2) * X[~converged]).nansum(1),
417                    # the denominator is normalizing by the number of valid points
418                    N_actual[~converged, k].unsqueeze(1))
419
420            self.covs[k][~converged] = self._get_covs(X[~converged],
421                                                      self.means[k][~converged],
422                                                      r[~converged, :, k],
423                                                      N_actual[~converged, k])
424
425            # We need to calculate the Cholesky decompositions of
426            # the precision matrices (the precision is the inverse
427            # of the covariance). However, due to numerical errors
428            # the covariance may lose its positive-definite property
429            # (which mathematically is guarenteed to have). Whenever
430            # that happens, we can no longer calculate the Cholesky
431            # decomposition. As a workaround, we substitute the cov
432            # matrix with a near covariance matrix that is positive
433            # definite.
434            covs_cholesky, errors = torch.linalg.cholesky_ex(self.covs[k][~converged])
435            bad_covs = errors > 0
436            if bad_covs.any():
437                eigvals, eigvecs = torch.linalg.eigh(self.covs[k][~converged][bad_covs])
438                # Theoretically, we should be able to use much smaller
439                # min value here, but for some reason smaller ones sometimes
440                # fail to force the covariance matrix to be positive-definite.
441                new_eigvals = torch.clamp(eigvals, min=1e-5)
442                new_covs = eigvecs @ torch.diag_embed(new_eigvals) @ eigvecs.transpose(-1, -2)
443                self.covs[k][~converged][bad_covs] = new_covs
444                covs_cholesky[bad_covs] = torch.linalg.cholesky(new_covs)
445            self._precisions_cholesky[k][~converged] = self._get_precisions_cholesky(covs_cholesky)
446
447            self._pi[k][~converged] = N_actual[~converged, k]/N_actual_total[~converged]
448
449
450    def _kmeans(self, X, n_clusters=2, max_iter=10, tol=0.001):
451        """
452        Clusters the points in each instance of the batch using k-means.
453        Points with nan values are assigned with value -1.
454
455        Parameters
456        ----------
457        X : torch.tensor
458            A tensor with shape (Batch, N-points, Dimensions)
459        n_clusters : int
460            Number of clusters to find.
461        max_iter : int
462            Maximum number of iterations to perform.
463        tol : float
464            The convergence threshold.
465        """
466        B, N, D = X.shape
467        C = n_clusters
468        valid_points = ~X.isnan().any(dim=2)
469        invalid_points_count = (~valid_points).sum(1)
470        centers = self._kmeans_pp(X, C, valid_points)
471
472        i = 0
473        diff = np.inf
474        while i < max_iter and diff > tol:
475            # Calculate the squared distance between each point and cluster centers
476            distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1)
477            assignment = distances.argmin(dim=2)
478
479            # Compute the new cluster center
480            cluster_sums = torch.zeros_like(centers)
481            cluster_counts = torch.zeros((B, C, 1), dtype=torch.float32, device=X.device)
482            # The nans are assigned to the first cluster. We want to ignore them.
483            # Hence, we use nat_to_num() to replace them with 0s and then we subtract
484            # the number of invalid points from the counts for the first cluster.
485            cluster_sums.scatter_add_(1, assignment.unsqueeze(-1).expand(-1, -1, D), X.nan_to_num())
486            cluster_counts.scatter_add_(1, assignment.unsqueeze(-1), torch.ones_like(X[:, :, :1]))
487            cluster_counts[:, 0] -= invalid_points_count
488            new_centers = cluster_sums / cluster_counts.clamp_min(1e-8)
489
490            # Estimate how much change we get in the centers
491            diff = torch.norm(new_centers - centers, dim=(1, 2)).max()
492
493            centers = new_centers.nan_to_num()
494            i += 1
495
496        # Final assignment with updated centers
497        distances = (X[:, :, None, :] - centers[:, None, :, :]).square().sum(dim=-1)
498        assignment = torch.where(valid_points, distances.argmin(dim=2), -1)
499
500        return centers, assignment
501
502
503    def _select_random_valid_points(self, X, valid_mask):
504        B, N, D = X.shape
505
506        _, point_idx = valid_mask.nonzero(as_tuple=True)
507        counts = valid_mask.sum(1)
508
509        # Select random valid index.
510        # This is efficient, but quite tricky:
511        # nonzero(as_tuple=True) returns a list of the batch indices and corresponding
512        # point indices of valid points. For each instance in the batch, we get a
513        # random integer between 0 and the maximum possible number of valid points.
514        # To make sure that the selected integer is not larger than the number of
515        # valid points for each instance we mod that integer by counts.
516        # This basically gives us a random offset to select a point from a list
517        # of valid points for a given batch index.
518        rand_offsets = torch.randint(0, counts.max(), (B,),
519                                     generator=self._rand_generator,
520                                     device=X.device) % counts
521
522        # Here, cumsum(counts)-counts gives us the starting position of each instance in the batch
523        # in point_idx. E.g. if we have a batch of 3 instances with [5, 7, 3] valid points respectively,
524        # we would get batch starts = [0, 5, 12].
525        batch_starts = torch.cumsum(counts, dim=0) - counts
526        chosen_indices = point_idx[batch_starts + rand_offsets]
527
528        selected_points = X[torch.arange(B, device=X.device), chosen_indices]
529        return selected_points
530
531
532    def _kmeans_pp(self, X, C, valid_points):
533        B, N, D = X.shape
534        device = X.device
535        std = self._nanstd(X)
536        centers = torch.empty(B, C, D, device=device)
537
538        # Randomly select the first center for each batch
539        rand_points = self._select_random_valid_points(X, valid_points)
540        centers[:, 0, :] = std * rand_points / rand_points.norm(dim=-1, keepdim=True)
541
542        # Each subsequent center would be calculated to be distant
543        # from the previous one
544        for k in range(1, C):
545            prev_centers = centers[:, k - 1, :].unsqueeze(1)
546            distances = (X - prev_centers).norm(dim=-1)
547
548            # By default kmeans++ takes as the next center the
549            # point that is furthest away. However, if there are
550            # outliers, they're likely to be selected, so here we
551            # ignore the top 10% of the most distant points.
552            threshold_idx = int(0.9 * N)
553            sorted_distances, sorted_indices = distances.sort(1)
554
555            # The standard kmeans++ algorithm selects an initial
556            # point at random for the first centroid and then for
557            # each cluster selects the point that is furthest away
558            # from the previous one. This is prone to selecting
559            # outliers that are very far away from all other points,
560            # leading to clusters with a single point. In the GMM
561            # fitting these clusters are problematic, because variance
562            # covariance metrics do not make sense anymore.
563            # To ameliorate this, we position the centroid at a point
564            # that is in the direction of the furthest point,
565            # but the length of the vector is equal to the 150% the
566            # standard deviation in the dataset.
567            # First, we get the most distant valid positions (after ignoring
568            # the top 10%).
569            max_valid_idx = _nanmax(sorted_distances[:, :threshold_idx], 1)[1]
570            # Those are indices that point to the sorting and not the original dataset.
571            # We need to map them through sorted_indices to obtain the indices for those points
572            # in the dataset X.
573            orig_indices = sorted_indices[torch.arange(B, device=device), max_valid_idx]
574            selected_points = X[torch.arange(B, device=device), orig_indices]
575            # Once we have the actual points, we calculate the new centers.
576            centers[:, k, :] = 1.5 * std * selected_points / selected_points.norm(dim=-1, keepdim=True)
577        return centers
578
579
580    def _get_covs(self, X, means, r, nums):
581        B, N, D = X.shape
582        # C_k = (1/N_k) * sum(r_nk * (x - mu_k)(x - mu_k)^T)
583        diffs = X - means.unsqueeze(1)
584        summands = r.view(B, N, 1, 1) * torch.matmul(diffs.unsqueeze(3), diffs.unsqueeze(2))
585        covs = summands.nansum(1) / nums.view(B, 1, 1).add(torch.finfo(self._dtype).eps)
586        return covs
587
588
589    def _get_precisions_cholesky(self, covs_cholesky):
590        B, D, D = covs_cholesky.shape
591        precisions_cholesky = torch.linalg.solve_triangular(
592                covs_cholesky,
593                torch.eye(D, device=self._device).unsqueeze(0).repeat(B, 1, 1),
594                upper=False,
595                left=True).permute(0, 2, 1)
596        return precisions_cholesky.to(self._dtype)
597
598
599    def _nanstd(self, X):
600        valid = torch.sum(~X.isnan().any(2), 1)
601        return (((X - X.nanmean(1).unsqueeze(1)) ** 2).nansum(1) / valid.unsqueeze(1)) ** 0.5
GMM( n_components, max_iter=100, device='cuda', tol=0.001, reg_covar=1e-06, means_init=None, weights_init=None, precisions_init=None, dtype=torch.float32, random_seed=None)
    def __init__(self,
                 n_components,
                 max_iter=100,
                 device='cuda',
                 tol=0.001,
                 reg_covar=1e-6,
                 means_init=None,
                 weights_init=None,
                 precisions_init=None,
                 dtype=torch.float32,
                 random_seed=None):
        """
        Initialize a batch of Gaussian Mixture Models to fit.

        Parameters
        ----------
        n_components : int
            Number of components (gaussians) in the model.
        max_iter : int
            Maximum number of EM iterations to perform.
        device : torch.device
            Which device to use for the computations
            during the fitting (e.g. `'cpu'`, `'cuda'`, `'cuda:0'`).
        tol : float
            The convergence threshold.
        reg_covar : float
            Non-negative regularization added to the diagonal of the
            covariance matrices. Helps ensure that they all stay
            positive definite.
        means_init : torch.tensor
            User-provided initial means for all instances. The
            tensor should have shape (Batch, Components, Dimensions).
            If None (default), the means are initialized with a
            modified kmeans++ and then refined with kmeans.
        weights_init : torch.tensor
            User-provided initial weights. The tensor should have shape
            (Batch, Components). If None, the weights are derived from
            the kmeans++ & kmeans initialization.
        precisions_init : torch.tensor
            User-provided initial precisions (inverses of the covariance
            matrices). The tensor should have shape
            (Batch, Components, Dimensions, Dimensions). If None, the
            precisions are derived from the kmeans++ & kmeans
            initialization.
        dtype : torch.dtype
            Data type that will be used in the GMM instance.
        random_seed : int
            Controls the random seed that will be used
            when initializing the model parameters.
        """
        self._n_components = n_components
        self._max_iter = max_iter
        self._device = device
        self._tol = tol
        self._reg_covar = reg_covar
        self._means_init = means_init
        self._weights_init = weights_init
        self._precisions_init = precisions_init
        self._dtype = dtype
        self._rand_generator = torch.Generator(device=device)
        # Explicit None check so that a seed of 0 is not ignored.
        if random_seed is not None:
            self._rand_seed = random_seed
            self._rand_generator.manual_seed(random_seed)
        else:
            self._rand_seed = None

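As a usage sketch for the initialization parameters (the shapes follow the docstring above; the concrete values here are arbitrary):

>>> import torch
>>> from gmm_gpu.gmm import GMM
>>> B, K, D = 4, 2, 3  # batch, components, dimensions
>>> means_init = torch.zeros(B, K, D)
>>> means_init[:, 1, :] = 1.5                 # move the second component
>>> weights_init = torch.full((B, K), 1 / K)  # uniform mixing weights
>>> precisions_init = torch.eye(D).expand(B, K, D, D).clone()
>>> gmm = GMM(n_components=K, device='cpu', random_seed=42,
...           means_init=means_init, weights_init=weights_init,
...           precisions_init=precisions_init)
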
def fit(self, X):
    def fit(self, X):
        """
        Fit the GMM on the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        """
        X = X.to(device=self._device, dtype=self._dtype)

        B, N, D = X.shape

        self._init_parameters(X)
        component_mask = self._init_clusters(X)

        r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            r[:, :, k][component_mask == k] = 1

        # This gives us the number of points per component
        # for each instance in the batch. It is necessary
        # in order to handle missing points (with nan values).
        N_actual = r.nansum(1)
        N_actual_total = N_actual.sum(1)

        converged = torch.full((B,), False, device=self._device)

        # If at least one of the parameters is missing
        # we calculate all parameters with the M-step.
        if (self._means_init is None or
            self._weights_init is None or
            self._precisions_init is None):
            self._m_step(X, r, N_actual, N_actual_total, converged)

        # If any of the parameters have been provided by the
        # user, we overwrite them with the provided values.
        if self._means_init is not None:
            self.means = [self._means_init[:, c, :]
                          for c in range(self._n_components)]
        if self._weights_init is not None:
            self._pi = [self._weights_init[:, c]
                        for c in range(self._n_components)]
        if self._precisions_init is not None:
            self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :])
                                         for c in range(self._n_components)]

        self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device)
        mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device)

        iteration = 1
        while iteration <= self._max_iter and not converged.all():
            prev_mean_log_prob_norm = mean_log_prob_norm.clone()

            # === E-STEP ===

            for k in range(self._n_components):
                r[~converged, :, k] = torch.add(
                        _estimate_gaussian_prob(
                            X[~converged],
                            self.means[k][~converged],
                            self._precisions_cholesky[k][~converged],
                            self._dtype).log(),
                        self._pi[k][~converged].unsqueeze(1).log()
                    )
            log_prob_norm = r[~converged].logsumexp(2)
            r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp()
            mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1)
            N_actual = r.nansum(1)

            # If we have fewer than 2 points in a component it produces
            # bad covariance matrices. Hence, we stop the iterations
            # for the affected instances and continue with the rest.
            unprocessable_instances = (N_actual < 2).any(1)
            converged[unprocessable_instances] = True

            # === M-STEP ===

            self._m_step(X, r, N_actual, N_actual_total, converged)

            change = mean_log_prob_norm - prev_mean_log_prob_norm

            # If the change for some instances in the batch
            # is small enough, we mark those instances as
            # converged and do not process them anymore.
            small_change = change.abs() < self._tol
            newly_converged = small_change & ~converged
            converged[newly_converged] = True
            self.convergence_iters[newly_converged] = iteration

            iteration += 1

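A quick sketch of fitting on toy data; after `fit`, `convergence_iters` records, for each instance, the EM iteration at which it converged (-1 means the instance did not reach the tolerance within `max_iter`):

>>> import torch
>>> from gmm_gpu.gmm import GMM
>>> X = torch.randn(8, 64, 2)
>>> gmm = GMM(n_components=2, device='cpu', random_seed=0)
>>> gmm.fit(X)
>>> gmm.convergence_iters.shape
torch.Size([8])
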
def predict_proba(self, X, force_cpu_result=True):
    def predict_proba(self, X, force_cpu_result=True):
        """
        Estimate the components' density for all samples
        in all instances.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N, n_clusters) with probabilities.
            The values at positions [I, S, :] will be the probabilities
            of sample S in instance I to be assigned to each component.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        log_probs = torch.zeros(B, N, self._n_components,
                                device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            # Calculate weighted log probabilities
            log_probs[:, :, k] = torch.add(
                    self._pi[k].log().unsqueeze(1),
                    _estimate_gaussian_prob(X,
                                            self.means[k],
                                            self._precisions_cholesky[k],
                                            self._dtype).log())
        log_prob_norm = log_probs.logsumexp(2)
        log_resp = log_probs - log_prob_norm.unsqueeze(2)

        if force_cpu_result:
            return log_resp.exp().cpu()
        return log_resp.exp()

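The returned responsibilities sum to one across components for every point. Continuing the toy example from `fit` above:

>>> probs = gmm.predict_proba(X)
>>> probs.shape
torch.Size([8, 64, 2])
>>> torch.allclose(probs.sum(-1), torch.ones(8, 64), atol=1e-5)
True
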
def predict(self, X, force_cpu_result=True):
    def predict(self, X, force_cpu_result=True):
        """
        Predict the component assignment for the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with component ids as values
            (NaN for points that contain missing values).
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        probs = torch.zeros(B, N, self._n_components,
                            device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            probs[:, :, k] = _estimate_gaussian_prob(X,
                                                     self.means[k],
                                                     self._precisions_cholesky[k],
                                                     self._dtype)
        assignments = torch.where(probs.isnan().any(2), np.nan, probs.argmax(2))
        if force_cpu_result:
            return assignments.cpu()
        return assignments

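Points with missing (NaN) values are not assigned to a component; their prediction is NaN as well. Continuing the example:

>>> X[0, 0, 0] = float('nan')
>>> preds = gmm.predict(X)
>>> preds[0, 0]
tensor(nan)
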
def score_samples(self, X, force_cpu_result=True):
    def score_samples(self, X, force_cpu_result=True):
        """
        Compute the log-likelihood of each point across all instances in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with the score for each point in the batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        B, N, D = X.shape
        log_probs = torch.zeros(B, N, self._n_components,
                                device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            # Calculate weighted log probabilities
            log_probs[:, :, k] = torch.add(
                    self._pi[k].log().unsqueeze(1),
                    _estimate_gaussian_prob(X,
                                            self.means[k],
                                            self._precisions_cholesky[k],
                                            self._dtype).log())
        if force_cpu_result:
            return log_probs.logsumexp(2).cpu()
        return log_probs.logsumexp(2)

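The score of a point is the log of the mixture density, log p(x) = log Σ_k π_k N(x | μ_k, Σ_k), which is also the normalizer behind `predict_proba`. Continuing the example:

>>> scores = gmm.score_samples(X)
>>> scores.shape
torch.Size([8, 64])
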
def score(self, X, force_cpu_result=True):
    def score(self, X, force_cpu_result=True):
        """
        Compute the per-sample average log-likelihood of each instance in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the log-likelihood for each instance in the batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        # Keep the computation on the device and transfer only the result.
        scores = self.score_samples(X, force_cpu_result=False).nanmean(1)
        if force_cpu_result:
            return scores.cpu()
        return scores

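`score` is simply the NaN-aware per-instance mean of `score_samples`:

>>> gmm.score(X).shape
torch.Size([8])
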
def bic(self, X, force_cpu_result=True):
    def bic(self, X, force_cpu_result=True):
        """
        Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        force_cpu_result : bool
            Make sure that the resulting tensor is loaded on
            the CPU regardless of the device used for the
            computations (default: True).

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the BIC value for each instance in the Batch.
        """
        X = X.to(device=self._device, dtype=self._dtype)
        scores = self.score(X, force_cpu_result=False)
        valid_points = (~X.isnan()).all(2).sum(1)
        # BIC = -2 * log-likelihood + n_parameters * log(n_points)
        result = -2 * scores * valid_points + self.n_parameters() * valid_points.log()
        if force_cpu_result:
            return result.cpu()
        return result

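BIC is commonly used to pick the number of components per instance (lower is better). A sketch, continuing the toy example above:

>>> bics = []
>>> for k in (2, 3, 4):
...     m = GMM(n_components=k, device='cpu', random_seed=0)
...     m.fit(X)
...     bics.append(m.bic(X))
>>> best = torch.stack(bics).argmin(0)  # index into (2, 3, 4), one per instance
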
def n_parameters(self):
    def n_parameters(self):
        """
        Returns the number of free parameters in the model for a single instance of the batch.

        Returns
        ----------
        int
            number of parameters in the model
        """
        n_features = self.means[0].shape[1]
        cov_params = self._n_components * n_features * (n_features + 1) / 2.0
        mean_params = n_features * self._n_components
        return int(cov_params + mean_params + self._n_components - 1)

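For example, with 2 components in 2 dimensions there are 2·(2·3/2) = 6 covariance parameters, 2·2 = 4 mean parameters and 2 − 1 = 1 free mixing weight, i.e. 11 free parameters. For the toy `gmm` fitted above:

>>> gmm.n_parameters()
11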