Coverage for src/diffusionlab/distributions/gmm/iso_hom_gmm.py: 100%

48 statements  

coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

from dataclasses import dataclass
from typing import Tuple, cast
from jax import numpy as jnp, Array
import jax
from diffusionlab.distributions.base import Distribution
from diffusionlab.distributions.gmm.gmm import GMM
from diffusionlab.distributions.gmm.utils import create_gmm_vector_field_fns
from diffusionlab.dynamics import DiffusionProcess


@dataclass(frozen=True)
class IsoHomGMM(Distribution):
13 """ 

14 Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution. 

15 

16 The probability measure is given by: 

17 

18 ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)`` 

19 

20 This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process. 

21 

22 Attributes: 

23 dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters. 

24 

25 - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components. 

26 - ``variance`` (``Array[]``): The variance of the GMM components. 

27 - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components. 

28 

29 dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused). 

30 """ 

31 

32 def __init__(self, means: Array, variance: Array, priors: Array): 

33 """ 

34 Initializes the isotropic homoscedastic GMM distribution. 

35 

36 Args: 

37 means (``Array[num_components, data_dim]``): Means for each Gaussian component. 

38 variance (``Array[]``): Variance for each Gaussian component. 

39 priors (``Array[num_components]``): Mixture weights for each component. Must sum to 1. 

40 """ 

41 eps = cast(float, jnp.finfo(variance.dtype).eps) 

42 assert means.ndim == 2 

43 num_components, data_dim = means.shape 

44 assert variance.shape == () 

45 assert priors.shape == (num_components,) 

46 assert jnp.isclose(jnp.sum(priors), 1.0, atol=eps) 

47 assert variance >= -eps 

48 

49 super().__init__( 

50 dist_params={ 

51 "means": means, 

52 "variance": variance, 

53 "priors": priors, 

54 }, 

55 dist_hparams={}, 

56 ) 

57 

58 def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]: 

59 """ 

60 Draws samples from the isotropic homoscedastic GMM distribution. 

61 

62 Args: 

63 key (``Array``): JAX PRNG key for random sampling. 

64 num_samples (``int``): The total number of samples to generate. 

65 

66 Returns: 

67 ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn. 

68 """ 

69 num_components, data_dim = self.dist_params["means"].shape 

70 cov = self.dist_params["variance"] * jnp.eye(data_dim) 

71 covs = cov[None, :, :].repeat(num_components, axis=0) 

72 base_gmm = GMM(self.dist_params["means"], covs, self.dist_params["priors"]) 

73 return base_gmm.sample(key, num_samples) 

74 

75 def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: 

76 """ 

77 Computes the score vector field ``∇_x log p_t(x_t)`` for the isotropic homoscedastic GMM distribution. 

78 

79 This is calculated with respect to the perturbed distribution p_t induced by the 

80 `diffusion_process` at time `t`. 

81 

82 Args: 

83 x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. 

84 t (``Array[]``): The time step (scalar). 

85 diffusion_process (``DiffusionProcess``): The diffusion process definition. 

86 

87 Returns: 

88 ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``. 

89 """ 

90 return iso_hom_gmm_score( 

91 x_t, 

92 t, 

93 diffusion_process, 

94 self.dist_params["means"], 

95 self.dist_params["variance"], 

96 self.dist_params["priors"], 

97 ) 

98 

99 def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: 

100 """ 

101 Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the isotropic homoscedastic GMM distribution. 

102 

103 This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` 

104 at time ``t`` under the ``diffusion_process``. 

105 

106 Args: 

107 x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. 

108 t (``Array[]``): The time step (scalar). 

109 diffusion_process (``DiffusionProcess``): The diffusion process definition. 

110 

111 Returns: 

112 ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``. 

113 """ 

114 return iso_hom_gmm_x0( 

115 x_t, 

116 t, 

117 diffusion_process, 

118 self.dist_params["means"], 

119 self.dist_params["variance"], 

120 self.dist_params["priors"], 

121 ) 

122 

123 def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: 

124 """ 

125 Computes the noise prediction ``ε`` for the isotropic homoscedastic GMM distribution. 

126 

127 This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` 

128 at time ``t`` under the ``diffusion_process``. 

129 

130 Args: 

131 x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. 

132 t (``Array[]``): The time step (scalar). 

133 diffusion_process (``DiffusionProcess``): The diffusion process definition. 

134 

135 Returns: 

136 ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``. 

137 """ 

138 return iso_hom_gmm_eps( 

139 x_t, 

140 t, 

141 diffusion_process, 

142 self.dist_params["means"], 

143 self.dist_params["variance"], 

144 self.dist_params["priors"], 

145 ) 

146 

147 def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: 

148 """ 

149 Computes the velocity vector field ``v`` for the isotropic homoscedastic GMM distribution. 

150 

151 This is conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``. 

152 

153 Args: 

154 x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. 

155 t (``Array[]``): The time step (scalar). 

156 diffusion_process (``DiffusionProcess``): The diffusion process definition. 

157 

158 Returns: 

159 ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``. 

160 """ 

161 return iso_hom_gmm_v( 

162 x_t, 

163 t, 

164 diffusion_process, 

165 self.dist_params["means"], 

166 self.dist_params["variance"], 

167 self.dist_params["priors"], 

168 ) 

169 

170 

171def iso_hom_gmm_x0( 

    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    variance: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic homoscedastic GMM.

    This implements the closed-form solution for the conditional expectation
    ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the GMM distribution
    defined by ``means``, ``variance``, and ``priors``.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        variance (``Array[]``): Shared variance of the GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """

    num_components, data_dim = means.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    means_t = jax.vmap(lambda mean: alpha_t * mean)(means)  # (num_components, data_dim)
    variance_t = alpha_t**2 * variance + sigma_t**2  # (,)

    xbars_t = jax.vmap(lambda mean_t: x_t - mean_t)(
        means_t
    )  # (num_components, data_dim)
    variance_t_inv_xbars_t = jax.vmap(lambda xbar_t: xbar_t / variance_t)(
        xbars_t
    )  # (num_components, data_dim)

    log_likelihoods_unnormalized = jax.vmap(
        lambda xbar_t, variance_t_inv_xbar_t: -(1 / 2)
        * jnp.sum(xbar_t * variance_t_inv_xbar_t)
    )(xbars_t, variance_t_inv_xbars_t)  # (num_components,)

    log_posterior_unnormalized = (
        jnp.log(priors) + log_likelihoods_unnormalized
    )  # (num_components,)
    posterior_probs = jax.nn.softmax(
        log_posterior_unnormalized, axis=0
    )  # (num_components,) sum to 1

    posterior_means = jax.vmap(
        lambda mean, variance_t_inv_xbar_t: mean
        + alpha_t * variance * variance_t_inv_xbar_t
    )(means, variance_t_inv_xbars_t)  # (num_components, data_dim)

    x0_pred = jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)

    return x0_pred


# Generate eps, score, v functions from iso_hom_gmm_x0
iso_hom_gmm_eps, iso_hom_gmm_score, iso_hom_gmm_v = create_gmm_vector_field_fns(
    iso_hom_gmm_x0
)
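
A minimal usage sketch (illustrative only, not part of the module above): it constructs a two-component isotropic homoscedastic GMM in two dimensions and draws samples. The import path mirrors the file path in the header; evaluating the vector fields additionally requires a ``DiffusionProcess`` instance, whose construction is not shown in this file, so that call is left as a comment.

import jax
from jax import numpy as jnp

from diffusionlab.distributions.gmm.iso_hom_gmm import IsoHomGMM

key = jax.random.PRNGKey(0)
means = jnp.array([[2.0, 0.0], [-2.0, 0.0]])  # (num_components=2, data_dim=2)
variance = jnp.array(0.5)                     # shared scalar variance, shape ()
priors = jnp.array([0.5, 0.5])                # mixture weights, sum to 1

dist = IsoHomGMM(means, variance, priors)
samples, components = dist.sample(key, num_samples=128)  # shapes (128, 2) and (128,)

# With some diffusion_process: DiffusionProcess in scope (construction not shown
# in this file), the denoiser at a noisy point and a scalar time could be queried as:
#   x0_pred = dist.x0(samples[0], jnp.array(0.5), diffusion_process)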