Coverage for src/diffusionlab/distributions/gmm/gmm.py: 100%
53 statements
Report generated by coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass
2from typing import Tuple, cast
3from jax import numpy as jnp, Array
4import jax
5from diffusionlab.distributions.base import Distribution
6from diffusionlab.distributions.gmm.utils import (
7 _logdeth,
8 _lstsq,
9 create_gmm_vector_field_fns,
10)
11from diffusionlab.dynamics import DiffusionProcess
@dataclass(frozen=True)
class GMM(Distribution):
    """
    Gaussian Mixture Model (GMM) distribution.

    The probability measure is the prior-weighted mixture of Gaussians:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])``

    Besides sampling, the class exposes the vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) of the mixture under a given diffusion process. Each of those methods forwards to the matching module-level ``gmm_*`` function.

    Attributes:
        dist_params (``Dict[str, Array]``): The parameters that define the mixture.

            - ``means`` (``Array[num_components, data_dim]``): Mean of each component.
            - ``covs`` (``Array[num_components, data_dim, data_dim]``): Covariance matrix of each component.
            - ``priors`` (``Array[num_components]``): Mixture weight of each component.

        dist_hparams (``Dict[str, Any]``): Hyperparameters; the constructor always passes an empty dict.
    """

    def __init__(self, means: Array, covs: Array, priors: Array):
        """
        Validates the mixture parameters and hands them to the base class.

        Args:
            means (``Array[num_components, data_dim]``): Mean of each Gaussian component.
            covs (``Array[num_components, data_dim, data_dim]``): Covariance matrix of each Gaussian component.
            priors (``Array[num_components]``): Mixture weight of each component. Must sum to 1.

        Raises:
            AssertionError: If the shapes are inconsistent, the priors do not sum to 1, or a covariance has an eigenvalue below the (slightly negative) tolerance. These checks are plain ``assert`` statements, so they are skipped under ``python -O``.
        """
        # Machine epsilon of the covariance dtype drives every tolerance below.
        tol = cast(float, jnp.finfo(covs.dtype).eps)

        assert means.ndim == 2
        n_components, dim = means.shape
        assert covs.shape == (n_components, dim, dim)
        assert priors.shape == (n_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=tol)

        # Positive semi-definiteness, with slack that grows with the matrix size
        # so that round-off in eigvalsh does not reject a valid covariance.
        smallest_allowed_eigenvalue = -tol * dim * dim
        assert jnp.all(jnp.linalg.eigvalsh(covs) >= smallest_allowed_eigenvalue)

        params = {"means": means, "covs": covs, "priors": priors}
        super().__init__(dist_params=params, dist_hparams={})

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws ``num_samples`` points from the mixture.

        A component is first chosen for every sample according to ``priors``; the point is then drawn from that component's Gaussian.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` with the drawn samples and, for each one, the index of the component it came from.
        """
        params = self.dist_params

        # The key is split three ways and the first sub-key is discarded; the
        # split count is kept as-is so the random stream does not change.
        _, categorical_key, normal_key = jax.random.split(key, 3)

        # One component index per sample, drawn from the (log-)priors.
        component_indices = jax.random.categorical(
            categorical_key, jnp.log(params["priors"]), shape=(num_samples,)
        )  # (num_samples,)

        # Gather the parameters of each sample's component.
        picked_means = params["means"][component_indices]  # (num_samples, data_dim)
        picked_covs = params["covs"][component_indices]  # (num_samples, data_dim, data_dim)

        # Independent randomness for every sample.
        per_sample_keys = jax.random.split(normal_key, num_samples)

        def draw_single(draw_key: Array, mean: Array, cov: Array) -> Array:
            # shape=() requests exactly one (data_dim,) draw from N(mean, cov).
            return jax.random.multivariate_normal(
                draw_key, mean, cov, shape=(), method="eigh"
            )

        # Map the single-sample draw over the leading (sample) axis.
        samples = jax.vmap(draw_single)(
            per_sample_keys, picked_means, picked_covs
        )  # (num_samples, data_dim)

        return samples, component_indices

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score ``(∇_x log p_t(x_t))`` of the noised GMM.

        ``p_t`` is the distribution obtained by perturbing the GMM with ``diffusion_process`` up to time ``t``. The computation is delegated to ``gmm_score``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return gmm_score(
            x_t, t, diffusion_process, params["means"], params["covs"], params["priors"]
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` of the GMM.

        This is the expected clean sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under ``diffusion_process``. The computation is delegated to ``gmm_x0``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return gmm_x0(
            x_t, t, diffusion_process, params["means"], params["covs"], params["priors"]
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` of the GMM.

        This predicts the noise that was added to the clean sample ``x_0`` to obtain ``x_t`` at time ``t`` under ``diffusion_process``. The computation is delegated to ``gmm_eps``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction ``ε`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return gmm_eps(
            x_t, t, diffusion_process, params["means"], params["covs"], params["priors"]
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` of the GMM.

        This relates to the conditional velocity ``E[dx_t/dt | x_t]`` under ``diffusion_process``. The computation is delegated to ``gmm_v``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return gmm_v(
            x_t, t, diffusion_process, params["means"], params["covs"], params["priors"]
        )
def gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    covs: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a GMM.

    With ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` drawn from the mixture given by ``means``, ``covs`` and ``priors``, the conditional expectation has a closed form: every component contributes its own Gaussian posterior mean, and the contributions are averaged with the posterior probability of each component given ``x_t``.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        covs (``Array[num_components, data_dim, data_dim]``): Covariances of the GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    _, data_dim = means.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    def per_component(mean: Array, cov: Array) -> Tuple[Array, Array]:
        # Moments of this component after noising: N(α_t μ, α_t² Σ + σ_t² I).
        mean_t = alpha_t * mean
        cov_t = alpha_t**2 * cov + sigma_t**2 * jnp.eye(data_dim)

        # Residual of x_t w.r.t. the noised mean, and cov_t^{-1} applied to it.
        # NOTE(review): _lstsq is assumed to solve cov_t @ z = residual; confirm in utils.
        residual = x_t - mean_t
        solved = _lstsq(cov_t, residual)

        # Gaussian log-likelihood of x_t up to an additive constant that is the
        # same for every component and therefore cancels in the softmax below.
        # NOTE(review): _logdeth is assumed to return log det(cov_t); confirm in utils.
        log_likelihood = -(1 / 2) * (_logdeth(cov_t) + jnp.sum(residual * solved))

        # E[x_0 | x_t, component] = μ + α_t Σ cov_t^{-1} (x_t - α_t μ).
        conditional_mean = mean + alpha_t * cov @ solved
        return log_likelihood, conditional_mean

    # Evaluate every component at once.
    log_likelihoods, conditional_means = jax.vmap(per_component)(
        means, covs
    )  # (num_components,), (num_components, data_dim)

    # Posterior probability of each component given x_t; sums to 1.
    responsibilities = jax.nn.softmax(
        jnp.log(priors) + log_likelihoods, axis=0
    )  # (num_components,)

    # Mix the per-component posterior means.
    return jnp.sum(responsibilities[:, None] * conditional_means, axis=0)  # (data_dim,)
# Only the x0 predictor is written by hand; the eps, score, and v predictors
# used by GMM.eps, GMM.score, and GMM.v are generated from it.
# NOTE(review): the unpack order (eps, score, v) must match the tuple returned
# by create_gmm_vector_field_fns; confirm against utils.
(gmm_eps, gmm_score, gmm_v) = create_gmm_vector_field_fns(gmm_x0)