gmm_gpu.gmm

Provides a GMM class for fitting multiple instances of Gaussian Mixture Models.

This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one. You can create a single large 3D tensor (a three-dimensional matrix) with the data for all your instances (i.e. a batch) and then send the tensor to the GPU and process the whole batch in parallel. This works best if all the instances have roughly the same number of points.
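
If the instances don't all have the same number of points, one way to build such a batch (a sketch, not part of the library API) is to pad the shorter instances with NaN rows, since points containing NaN values are treated as missing:

>>> import torch
>>> instances = [torch.randn(3, 2), torch.randn(2, 2)]  # two toy instances with 3 and 2 points
>>> n_max = max(len(inst) for inst in instances)
>>> batch = torch.full((len(instances), n_max, 2), float('nan'))
>>> for i, inst in enumerate(instances):
...     batch[i, :len(inst)] = inst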

If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, Pomegranate may be a better option.

Example usage:

Import PyTorch and the GMM class

>>> from gmm_gpu.gmm import GMM
>>> import torch

Generate some test data: We create a batch of 1000 instances, each with 200 random points. Half of the points are sampled from a distribution centered at the origin (0, 0) and the other half from a distribution centered at (1.5, 1.5).

>>> X1 = torch.randn(1000, 100, 2)
>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
>>> X = torch.cat([X1, X2], dim=1)

Fit the model

>>> gmm = GMM(n_components=2, device='cuda')
>>> gmm.fit(X)

Predict the components: This will return a matrix with shape (1000, 200) where each value is the predicted component for the point.

>>> gmm.predict(X)
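
The fitted model can then be queried in the same batched fashion. For example, `score_samples` returns the per-point log-likelihood with shape (1000, 200), `score` the per-instance average log-likelihood with shape (1000,), and `bic` the per-instance Bayesian Information Criterion with shape (1000,):

>>> log_likelihood = gmm.score_samples(X)
>>> avg_log_likelihood = gmm.score(X)
>>> bic = gmm.bic(X)

The full module source follows.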
  1"""
  2Provides a GMM class for fitting multiple instances of `Gaussian Mixture Models <https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model>`_.
  3
  4This may be useful if you have a large number of independent small problems and you want to fit a GMM on each one.
  5You can create a single large 3D tensor (three dimensional matrix) with the data for all your instances (i.e. a batch) and then
  6send the tensor to the GPU and process the whole batch in parallel. This would work best if all the instances have roughly the same number of points.
  7
  8If you have a single big problem (one GMM instance with many points) that you want to fit using the GPU, maybe `Pomegranate <https://github.com/jmschrei/pomegranate>`_ would be a better option.
  9
 10### Example usage:
 11Import pytorch and the GMM class
 12>>> from gmm_gpu.gmm import GMM
 13>>> import torch
 14
 15Generate some test data:
 16We create a batch of 1000 instances, each
 17with 200 random points. Half of the points
 18are sampled from distribution centered at
 19the origin (0, 0) and the other half from
 20a distribution centered at (1.5, 1.5).
 21>>> X1 = torch.randn(1000, 100, 2)
 22>>> X2 = torch.randn(1000, 100, 2) + torch.tensor([1.5, 1.5])
 23>>> X = torch.cat([X1, X2], dim=1)
 24
 25Fit the model
 26>>> gmm = GMM(n_components=2, device='cuda')
 27>>> gmm.fit(X)
 28
 29Predict the components:
 30This will return a matrix with shape (1000, 200) where
 31each value is the predicted component for the point.
 32>>> gmm.predict(X)
 33"""

import math

import torch
import numpy as np


class GMM:
    def __init__(self,
                 n_components,
                 max_iter=100,
                 device='cuda',
                 tol=0.001,
                 reg_covar=1e-6,
                 means_init=None,
                 weights_init=None,
                 precisions_init=None,
                 dtype=torch.float32,
                 random_seed=None):
        """
        Initialize a Gaussian Mixture Model instance to fit.

        Parameters
        ----------
        n_components : int
            Number of components (gaussians) in the model.
        max_iter : int
            Maximum number of EM iterations to perform.
        device : torch.device
            Which device should be used for the computations
            during fitting (e.g. `'cpu'`, `'cuda'`, `'cuda:0'`).
        tol : float
            The convergence threshold.
        reg_covar : float
            Non-negative regularization added to the diagonal of covariance.
            Helps ensure that the covariance matrices are all positive.
        means_init : torch.tensor
            User-provided initialization means for all instances. The
            tensor should have shape (Batch, Components, Dimensions).
            If None (default), the means are initialized
            with a modified kmeans++ and then refined with kmeans.
        weights_init : torch.tensor
            The user-provided initial weights. The tensor should have shape
            (Batch, Components). If it is None, weights are initialized
            depending on the kmeans++ & kmeans initialization.
        precisions_init : torch.tensor
            The user-provided initial precisions (inverses of the covariance matrices).
            The tensor should have shape (Batch, Components, Dimension, Dimension).
            If it is None, precisions are initialized depending on the kmeans++ & kmeans
            initialization.
        dtype : torch.dtype
            Data type that will be used in the GMM instance.
        random_seed : int
            Controls the random seed that will be used
            when initializing the model parameters.
        """
        self._n_components = n_components
        self._max_iter = max_iter
        self._device = device
        self._tol = tol
        self._reg_covar = reg_covar
        self._means_init = means_init
        self._weights_init = weights_init
        self._precisions_init = precisions_init
        self._dtype = dtype
        self._rand_generator = torch.Generator(device=device)
        if random_seed is not None:
            self._rand_seed = random_seed
            self._rand_generator.manual_seed(random_seed)
        else:
            self._rand_seed = None


    def fit(self, X):
        """
        Fit the GMM on the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        """
        X = X.to(self._dtype)
        if X.device.type != self._device:
            X = X.to(self._device)

        B, N, D = X.shape

        self._init_parameters(X)
        component_mask = self._init_clusters(X)

        r = torch.zeros(B, N, self._n_components, device=X.device, dtype=self._dtype)
        for k in range(self._n_components):
            r[:, :, k][component_mask == k] = 1

        # This gives us the number of points per component
        # for each instance in the batch. It's necessary
        # in order to handle missing points (with nan values).
        N_actual = r.nansum(1)
        N_actual_total = N_actual.sum(1)

        converged = torch.full((B,), False, device=self._device)

        # If we have less than 2 points in a component it produces
        # bad covariance matrices. Hence, we stop the iterations
        # for the affected instances and continue with the rest.
        single_component = (N_actual >= N_actual_total.unsqueeze(1) - 1).any(1)
        converged[single_component] = True

        # If at least one of the parameters is missing
        # we calculate all parameters with the M-step.
        if (self._means_init is None or
            self._weights_init is None or
            self._precisions_init is None):
            self._m_step(X, r, N_actual, N_actual_total, converged)

        # If any of the parameters have been provided by the
        # user, we overwrite them with the provided values.
        if self._means_init is not None:
            self.means = [self._means_init[:, c, :]
                          for c in range(self._n_components)]
        if self._weights_init is not None:
            self._pi = [self._weights_init[:, c]
                        for c in range(self._n_components)]
        if self._precisions_init is not None:
            self._precisions_cholesky = [torch.linalg.cholesky(self._precisions_init[:, c, :, :])
                                         for c in range(self._n_components)]

        self.convergence_iters = torch.full((B,), -1, dtype=int, device=self._device)
        mean_log_prob_norm = torch.full((B,), -np.inf, dtype=self._dtype, device=self._device)

        iteration = 1
        while iteration <= self._max_iter and not converged.all():
            prev_mean_log_prob_norm = mean_log_prob_norm.clone()

            # === E-STEP ===

            for k in range(self._n_components):
                r[~converged, :, k] = torch.add(
                        _estimate_gaussian_prob(
                            X[~converged],
                            self.means[k][~converged],
                            self._precisions_cholesky[k][~converged],
                            self._dtype).log(),
                        self._pi[k][~converged].unsqueeze(1).log()
                    )
            log_prob_norm = r[~converged].logsumexp(2)
            r[~converged] = (r[~converged] - log_prob_norm.unsqueeze(2)).exp()
            mean_log_prob_norm[~converged] = log_prob_norm.nanmean(1)
            N_actual = r.nansum(1)

            # If we have less than 2 points in a component it produces
            # bad covariance matrices. Hence, we stop the iterations
            # for the affected instances and continue with the rest.
            single_component = (N_actual >= N_actual_total.unsqueeze(1) - 1).any(1)
            converged[single_component] = True

            # === M-STEP ===

            self._m_step(X, r, N_actual, N_actual_total, converged)

            change = mean_log_prob_norm - prev_mean_log_prob_norm

            # If the change for some instances in the batch
            # is small enough, we mark those instances as
            # converged and do not process them anymore.
            small_change = change.abs() < self._tol
            newly_converged = small_change & ~converged
            converged[newly_converged] = True
            self.convergence_iters[newly_converged] = iteration

            iteration += 1


    def predict(self, X):
        """
        Predict the component assignment for the given tensor data.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with component ids as values.
        """
        if X.dtype != self._dtype:
            X = X.to(self._dtype)
        if X.device.type != self._device:
            X = X.to(self._device)
        B, N, D = X.shape
        probs = torch.zeros(B, N, self._n_components, device=X.device)
        for k in range(self._n_components):
            probs[:, :, k] = _estimate_gaussian_prob(X,
                                                     self.means[k],
                                                     self._precisions_cholesky[k],
                                                     self._dtype)
        return probs.argmax(2).cpu()


    def score_samples(self, X):
        """
        Compute the log-likelihood of each point across all instances in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)

        Returns
        ----------
        torch.tensor
            tensor of shape (B, N) with the score for each point in the batch.
        """
        if X.device.type != self._device:
            X = X.to(self._device)
        X = X.to(self._dtype)
        B, N, D = X.shape
        log_probs = torch.zeros(B, N, self._n_components, device=X.device)
        for k in range(self._n_components):
            # Calculate weighted log probabilities
            log_probs[:, :, k] = torch.add(
                    self._pi[k].log().unsqueeze(1),
                    _estimate_gaussian_prob(X,
                                            self.means[k],
                                            self._precisions_cholesky[k],
                                            self._dtype).log()
                )
        return log_probs.logsumexp(2).cpu()


    def score(self, X):
        """
        Compute the per-sample average log-likelihood of each instance in the batch.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the log-likelihood for each instance in the batch.
        """
        return self.score_samples(X).nanmean(1).cpu()


    def bic(self, X):
        """
        Calculates the BIC (Bayesian Information Criterion) for the model on the dataset X.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)

        Returns
        ----------
        torch.tensor
            tensor of shape (B,) with the BIC value for each instance in the Batch.
        """
        scores = self.score(X)
        # Move the per-instance point counts to the CPU so they live on the
        # same device as the scores returned by score().
        valid_points = (~X.isnan()).all(2).sum(1).cpu()
        return -2 * scores * valid_points + self.n_parameters() * np.log(valid_points)


    def n_parameters(self):
        """
        Returns the number of free parameters in the model for a single instance of the batch.

        Returns
        ----------
        int
            number of parameters in the model
        """
        n_features = self.means[0].shape[1]
        cov_params = self._n_components * n_features * (n_features + 1) / 2.0
        mean_params = n_features * self._n_components
        return int(cov_params + mean_params + self._n_components - 1)


    def _init_clusters(self, X):
        """
        Initialize the component (cluster) assignment for B sets of N D-dimensional points.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        """
        # If the assignment produced by kmeans has a component
        # with two or fewer points, we rerun it to get a different
        # assignment (up to 3 times). Having too few points leads
        # to bad covariance matrices that produce errors when trying
        # to decompose/invert them.
        retries = 0
        while retries < 3:
            seed = self._rand_seed + retries if self._rand_seed is not None else None
            _, assignment = self._kmeans(X,
                                         self._n_components,
                                         random_seed=seed)
            _, counts = assignment.unique(return_counts=True)
            if not torch.any(counts <= 2):
                return assignment
            retries += 1
        return assignment


    def _init_parameters(self, X):
        B, N, D = X.shape
        self.means = [torch.empty(B, D, dtype=self._dtype, device=self._device)
                      for _ in range(self._n_components)]
        self.covs = [torch.empty(B, D, D, dtype=self._dtype, device=self._device)
                     for _ in range(self._n_components)]
        self._precisions_cholesky = [torch.empty(B, D, D,
                                                 dtype=self._dtype,
                                                 device=self._device)
                                     for _ in range(self._n_components)]
        self._pi = [torch.empty(B, dtype=self._dtype, device=self._device)
                    for _ in range(self._n_components)]


    def _m_step(self, X, r, N_actual, N_actual_total, converged):
        B, N, D = X.shape
        # We update the means, covariances and weights
        # for all instances in the batch that still
        # have not converged.
        for k in range(self._n_components):
            self.means[k][~converged] = torch.div(
                    # the numerator is sum(r*X)
                    (r[~converged, :, k].unsqueeze(2) * X[~converged]).nansum(1),
                    # the denominator normalizes by the number of valid points
                    N_actual[~converged, k].unsqueeze(1))

            self.covs[k][~converged] = self._get_covs(X[~converged],
                                                      self.means[k][~converged],
                                                      r[~converged, :, k],
                                                      N_actual[~converged, k])

            # We need to calculate the Cholesky decompositions of
            # the precision matrices (the precision is the inverse
            # of the covariance). However, due to numerical errors
            # the covariance may lose its positive-definite property
            # (which mathematically it is guaranteed to have). Whenever
            # that happens, we can no longer calculate the Cholesky
            # decomposition. As a workaround, we substitute the cov
            # matrix with a near covariance matrix that is positive
            # definite.
            covs_k = self.covs[k][~converged]
            covs_cholesky, errors = torch.linalg.cholesky_ex(covs_k)
            bad_covs = errors > 0
            if bad_covs.any():
                eigvals, eigvecs = torch.linalg.eigh(covs_k[bad_covs])
                # Theoretically, we should be able to use a much smaller
                # min value here, but for some reason smaller ones sometimes
                # fail to force the covariance matrix to be positive-definite.
                new_eigvals = torch.clamp(eigvals, min=1e-5)
                new_covs = eigvecs @ torch.diag_embed(new_eigvals) @ eigvecs.transpose(-1, -2)
                covs_k[bad_covs] = new_covs
                # Write the repaired covariances back: boolean indexing returns
                # a copy, so a chained masked assignment would be lost.
                self.covs[k][~converged] = covs_k
                covs_cholesky[bad_covs] = torch.linalg.cholesky(new_covs)
            self._precisions_cholesky[k][~converged] = self._get_precisions_cholesky(covs_cholesky)

            self._pi[k][~converged] = N_actual[~converged, k]/N_actual_total[~converged]


    def _kmeans(self, X, n_clusters=2, max_iter=20, tol=0.001, random_seed=None):
        """
        Clusters the points in each instance of the batch using k-means.
        Points with nan values are assigned the value -1.

        Parameters
        ----------
        X : torch.tensor
            A tensor with shape (Batch, N-points, Dimensions)
        n_clusters : int
            Number of clusters to find.
        max_iter : int
            Maximum number of iterations to perform.
        tol : float
            The convergence threshold.
        random_seed : int
            Seed for the random selection of the initial centers.
        """
        B, N, D = X.shape
        C = n_clusters
        valid_points = ~X.isnan().any(2)
        centers = self._kmeans_pp(X, C, valid_points, random_seed=random_seed)
        distances = torch.empty(B, N, C, device=self._device)
        i = 0
        diff = np.inf
        while i < max_iter and diff > tol:
            # Calculate the distance between each point and cluster centers
            for c in range(C):
                distances[:, :, c] = ((X - centers[:, c, :].unsqueeze(1)) ** 2).sum(2) ** 0.5
            # Assign each point to the cluster with the closest center
            assignment = distances.argmin(2)
            # Recalculate cluster centers
            new_centers = torch.empty(B, C, D, device=self._device)
            for c in range(C):
                cluster_mask = (assignment == c).unsqueeze(2).repeat(1, 1, D)
                new_centers[:, c, :] = torch.where(cluster_mask, X, np.nan).nanmedian(1).values
            # Estimate how much change we get in the centers
            diff = (new_centers - centers).mean(1).max()
            centers = new_centers
            i += 1
        for c in range(C):
            distances[:, :, c] = ((X - centers[:, c, :].unsqueeze(1)) ** 2).sum(2) ** 0.5
        # Assign each point to the cluster with the closest center.
        # Invalid points are assigned -1.
        assignment = torch.where(valid_points, distances.argmin(2), -1)
        return centers, assignment


    def _kmeans_pp(self, X, C, valid_points, random_seed=None):
        valid_mask = valid_points.clone()
        B, N, D = X.shape
        centers = torch.empty(B, C, D, device=self._device)
        indices = torch.arange(0, N, device=self._device)
        rand = np.random.default_rng(random_seed)
        std = self._nanstd(X)
        for b in range(B):
            point_index = rand.choice(indices[valid_mask[b]].cpu().numpy())
            point = X[b, point_index, :]
            point_len = (X[b, point_index, :].pow(2).sum(-1) ** 0.5)
            centers[b, 0, :] = std[b]*point/point_len
            valid_mask[b, point_index] = False
            for k in range(1, C):
                prev_center = centers[b, k-1, :]
                distances = ((X[b, valid_mask[b], :] - prev_center) ** 2).sum(-1) ** 0.5
                # By default kmeans++ takes as the next center the
                # point that is furthest away. However, if there are
                # outliers, they're likely to be selected, so here we
                # ignore the top 10% of the most distant points.
                max_dist_index = torch.argsort(distances)[math.floor(0.9*distances.shape[0])]

                # The distances are calculated on a subset of the points
                # so we need to convert the index of the furthest point
                # to the index of the point in the whole dataset.
                point_index = indices[valid_mask[b]][max_dist_index]

                # The standard kmeans++ algorithm selects an initial
                # point at random for the first centroid and then for
                # each cluster selects the point that is furthest away
                # from the previous one. This is prone to selecting
                # outliers that are very far away from all other points,
                # leading to clusters with a single point. In the GMM
                # fitting these clusters are problematic, because variance
                # and covariance estimates do not make sense anymore.
                # To ameliorate this, we position the centroid at a point
                # that points in the direction of the furthest point,
                # but whose length is equal to the standard
                # deviation in the dataset.
                centers[b, k, :] = std[b] * X[b, point_index, :] / distances[max_dist_index]
        return centers


    def _get_covs(self, X, means, r, nums):
        B, N, D = X.shape
        # C_k = (1/N_k) * sum(r_nk * (x - mu_k)(x - mu_k)^T)
        diffs = X - means.unsqueeze(1)
        summands = r.view(B, N, 1, 1) * torch.matmul(diffs.unsqueeze(3), diffs.unsqueeze(2))
        covs = summands.nansum(1) / nums.view(B, 1, 1).add(torch.finfo(self._dtype).eps)
        return covs


    def _get_precisions_cholesky(self, covs_cholesky):
        B, D, _ = covs_cholesky.shape
        precisions_cholesky = torch.linalg.solve_triangular(
                covs_cholesky,
                torch.eye(D, device=self._device).unsqueeze(0).repeat(B, 1, 1),
                upper=False,
                left=True).permute(0, 2, 1)
        return precisions_cholesky.to(self._dtype)


    def _nanstd(self, X):
        valid = torch.sum(~X.isnan().any(2), 1)
        return (((X - X.nanmean(1).unsqueeze(1)) ** 2).nansum(1) / valid.unsqueeze(1)) ** 0.5


def _estimate_gaussian_prob(X, mean, precisions_chol, dtype):
    """
    Compute the probability density of a batch of points X under
    a batch of multivariate normal distributions.

    Parameters
    ----------
    X : torch.tensor
        A tensor with shape (Batch, N-points, Dimensions).
        Represents a batch of points.
    mean : torch.tensor
        The means of the distributions. Shape: (B, D)
    precisions_chol : torch.tensor
        Cholesky decompositions of the precision matrices. Shape: (B, D, D)
    dtype : torch.dtype
        Data type of the result

    Returns
    ----------
    torch.tensor
        tensor of shape (B, N) with probability densities
    """
    B, N, D = X.shape
    y = torch.bmm(X, precisions_chol) - torch.bmm(mean.unsqueeze(1), precisions_chol)
    log_prob = y.pow(2).sum(2)
    log_det = torch.diagonal(precisions_chol, dim1=1, dim2=2).log().sum(1)
    return torch.exp(
            -0.5 * (D * np.log(2 * np.pi) + log_prob) + log_det.unsqueeze(1))
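

The helper above evaluates the Gaussian density through a factor of the precision matrix: for a triangular `precisions_chol` with positive diagonal such that `P = precisions_chol @ precisions_chol.T` is the precision (inverse covariance), `((X - mean) @ precisions_chol).pow(2).sum(2)` equals the quadratic form `(x - mean)^T P (x - mean)` and `log_det` equals `0.5 * log(det P)`. A quick sanity check against `torch.distributions` (an illustrative sketch, not part of the module) could look like this:

>>> B, N, D = 4, 10, 3
>>> X = torch.randn(B, N, D)
>>> mean = torch.randn(B, D)
>>> A = torch.randn(B, D, D)
>>> covs = A @ A.transpose(-1, -2) + D * torch.eye(D)  # random SPD covariance matrices
>>> prec_chol = torch.linalg.cholesky(torch.linalg.inv(covs))
>>> p = _estimate_gaussian_prob(X, mean, prec_chol, torch.float32)
>>> mvn = torch.distributions.MultivariateNormal(mean.unsqueeze(1),
...                                              covariance_matrix=covs.unsqueeze(1))
>>> torch.allclose(p, mvn.log_prob(X).exp(), atol=1e-5)
True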