Coverage for src/diffusionlab/distributions/gmm/iso_hom_gmm.py: 100%
48 statements
coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
from dataclasses import dataclass
from typing import Tuple, cast
from jax import numpy as jnp, Array
import jax
from diffusionlab.distributions.base import Distribution
from diffusionlab.distributions.gmm.gmm import GMM
from diffusionlab.distributions.gmm.utils import create_gmm_vector_field_fns
from diffusionlab.dynamics import DiffusionProcess


@dataclass(frozen=True)
class IsoHomGMM(Distribution):
13 """
14 Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.
16 The probability measure is given by:
18 ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)``
20 This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.
22 Attributes:
23 dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.
25 - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
26 - ``variance`` (``Array[]``): The variance of the GMM components.
27 - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.
29 dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
30 """

    def __init__(self, means: Array, variance: Array, priors: Array):
        """
        Initializes the isotropic homoscedastic GMM distribution.

        Args:
            means (``Array[num_components, data_dim]``): Means for each Gaussian component.
            variance (``Array[]``): The single isotropic variance shared by every Gaussian component.
            priors (``Array[num_components]``): Mixture weights for each component. Must sum to 1.
        """
        eps = cast(float, jnp.finfo(variance.dtype).eps)
        assert means.ndim == 2
        num_components, data_dim = means.shape
        assert variance.shape == ()
        assert priors.shape == (num_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=eps)
        assert variance >= -eps

        super().__init__(
            dist_params={
                "means": means,
                "variance": variance,
                "priors": priors,
            },
            dist_hparams={},
        )
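
    # Example construction (an illustrative sketch; the values below are not taken from the
    # library's own docs or tests):
    #
    #   means = jnp.array([[0.0, 0.0], [3.0, 3.0]])  # (num_components=2, data_dim=2)
    #   variance = jnp.array(0.25)                   # scalar variance shared by both components
    #   priors = jnp.array([0.5, 0.5])               # mixture weights summing to 1
    #   dist = IsoHomGMM(means, variance, priors)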

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the isotropic homoscedastic GMM distribution.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
        """
        num_components, data_dim = self.dist_params["means"].shape
        cov = self.dist_params["variance"] * jnp.eye(data_dim)
        covs = cov[None, :, :].repeat(num_components, axis=0)
        base_gmm = GMM(self.dist_params["means"], covs, self.dist_params["priors"])
        return base_gmm.sample(key, num_samples)

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score vector field ``∇_x log p_t(x_t)`` for the isotropic homoscedastic GMM distribution.

        This is calculated with respect to the perturbed distribution ``p_t`` induced by the
        ``diffusion_process`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_score(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the isotropic homoscedastic GMM distribution.

        This represents the expected original sample ``x_0`` given the noisy observation ``x_t``
        at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_x0(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` for the isotropic homoscedastic GMM distribution.

        This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t``
        at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_eps(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` for the isotropic homoscedastic GMM distribution.

        This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_v(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )


def iso_hom_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    variance: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic homoscedastic GMM.

    This implements the closed-form solution for the conditional expectation
    ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the GMM distribution
    defined by ``means``, ``variance``, and ``priors``.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        variance (``Array[]``): The isotropic variance shared by all GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    num_components, data_dim = means.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    means_t = jax.vmap(lambda mean: alpha_t * mean)(means)  # (num_components, data_dim)
    variance_t = alpha_t**2 * variance + sigma_t**2  # ()

    xbars_t = jax.vmap(lambda mean_t: x_t - mean_t)(
        means_t
    )  # (num_components, data_dim)
    variance_t_inv_xbars_t = jax.vmap(lambda xbar_t: xbar_t / variance_t)(
        xbars_t
    )  # (num_components, data_dim)

    log_likelihoods_unnormalized = jax.vmap(
        lambda xbar_t, variance_t_inv_xbar_t: -(1 / 2)
        * jnp.sum(xbar_t * variance_t_inv_xbar_t)
    )(xbars_t, variance_t_inv_xbars_t)  # (num_components,)

    log_posterior_unnormalized = (
        jnp.log(priors) + log_likelihoods_unnormalized
    )  # (num_components,)
    posterior_probs = jax.nn.softmax(
        log_posterior_unnormalized, axis=0
    )  # (num_components,), sums to 1

    posterior_means = jax.vmap(
        lambda mean, variance_t_inv_xbar_t: mean
        + alpha_t * variance * variance_t_inv_xbar_t
    )(means, variance_t_inv_xbars_t)  # (num_components, data_dim)

    x0_pred = jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)

    return x0_pred
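
# Note: for a process with ``x_t = α_t x_0 + σ_t ε``, the remaining fields follow from the
# denoiser via the standard identities ``ε = (x_t - α_t x0) / σ_t`` and
# ``score = (α_t x0 - x_t) / σ_t^2`` (with the velocity determined by the process'
# time derivatives); ``create_gmm_vector_field_fns`` is assumed to apply such conversions.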

# Generate eps, score, v functions from iso_hom_gmm_x0
iso_hom_gmm_eps, iso_hom_gmm_score, iso_hom_gmm_v = create_gmm_vector_field_fns(
    iso_hom_gmm_x0
)
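

# Minimal usage sketch (illustrative only, not part of the library): samples from an
# IsoHomGMM and evaluates the closed-form denoiser. Constructing a ``DiffusionProcess`` is not
# shown in this file, so a simple stand-in exposing ``alpha(t)`` and ``sigma(t)`` is used here,
# which is all that ``iso_hom_gmm_x0`` touches.
if __name__ == "__main__":
    from types import SimpleNamespace

    means = jnp.array([[0.0, 0.0], [3.0, 3.0]])  # 2 components in 2-D
    variance = jnp.array(0.25)
    priors = jnp.array([0.5, 0.5])
    dist = IsoHomGMM(means, variance, priors)

    # Draw a small batch of samples together with their component indices.
    samples, components = dist.sample(jax.random.PRNGKey(0), num_samples=8)
    print(samples.shape, components.shape)  # expected: (8, 2) (8,)

    # Hypothetical schedule, purely for illustration of the alpha/sigma interface.
    process = SimpleNamespace(
        alpha=lambda t: jnp.sqrt(1.0 - t),
        sigma=lambda t: jnp.sqrt(t),
    )
    x0_hat = iso_hom_gmm_x0(samples[0], jnp.array(0.5), process, means, variance, priors)
    print(x0_hat.shape)  # expected: (2,)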