Coverage for src/diffusionlab/distributions/gmm/gmm.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass 

2from typing import Tuple, cast 

3from jax import numpy as jnp, Array 

4import jax 

5from diffusionlab.distributions.base import Distribution 

6from diffusionlab.distributions.gmm.utils import ( 

7 _logdeth, 

8 _lstsq, 

9 create_gmm_vector_field_fns, 

10) 

11from diffusionlab.dynamics import DiffusionProcess 

12 

13 

@dataclass(frozen=True)
class GMM(Distribution):
    """
    Implements a Gaussian Mixture Model (GMM) distribution.

    The probability measure is given by:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])``

    This class provides methods for sampling from the GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.

            - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components.
            - ``covs`` (``Array[num_components, data_dim, data_dim]``): The covariance matrices of the GMM components.
            - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components.

        dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused).
    """

    def __init__(self, means: Array, covs: Array, priors: Array):
        """
        Initializes the GMM distribution.

        Args:
            means (``Array[num_components, data_dim]``): Means for each Gaussian component.
            covs (``Array[num_components, data_dim, data_dim]``): Covariance matrices for each Gaussian component.
            priors (``Array[num_components]``): Mixture weights for each component. Must be non-negative and sum to 1.

        Raises:
            AssertionError: If the shapes of ``means``, ``covs`` and ``priors`` are inconsistent,
                if any prior is negative, if the priors do not sum to 1 (within tolerance),
                or if any covariance matrix has an eigenvalue below ``-eps * data_dim**2``.
        """
        # Machine epsilon of the covariance dtype, used as the tolerance for the numerical checks.
        eps = cast(float, jnp.finfo(covs.dtype).eps)
        assert means.ndim == 2, (
            f"means must have shape (num_components, data_dim), got {means.shape}"
        )
        num_components, data_dim = means.shape
        assert covs.shape == (num_components, data_dim, data_dim), (
            f"covs must have shape {(num_components, data_dim, data_dim)}, got {covs.shape}"
        )
        assert priors.shape == (num_components,), (
            f"priors must have shape {(num_components,)}, got {priors.shape}"
        )
        # Negative weights can still sum to 1, but they make jnp.log(priors) NaN
        # in sampling and in the GMM vector-field computations.
        assert jnp.all(priors >= 0), "priors must be non-negative"
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=eps), (
            f"priors must sum to 1, got {jnp.sum(priors)}"
        )
        # Tolerate slightly negative eigenvalues (scaled with dimension) caused by
        # round-off in matrices that are mathematically positive semi-definite.
        assert jnp.all(jnp.linalg.eigvalsh(covs) >= -eps * data_dim * data_dim), (
            "covs must be positive semi-definite (up to round-off)"
        )

        super().__init__(
            dist_params={
                "means": means,
                "covs": covs,
                "priors": priors,
            },
            dist_hparams={},
        )

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the GMM distribution.

        Each sample first draws a component index from the categorical
        distribution given by ``priors``. It then draws a point from that
        component's Gaussian ``N(means[i], covs[i])``.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn.
        """
        means = self.dist_params["means"]
        covs = self.dist_params["covs"]
        priors = self.dist_params["priors"]

        # Only two subkeys are used. The 3-way split (discarding the first subkey)
        # is kept on purpose: changing it would change which samples a given
        # ``key`` produces.
        _, key_cat, key_norm = jax.random.split(key, 3)

        # Component index for every sample.
        component_indices = jax.random.categorical(
            key_cat, jnp.log(priors), shape=(num_samples,)
        )  # Shape: (num_samples,)

        # Parameters of the chosen component for every sample.
        chosen_means = means[component_indices]  # Shape: (num_samples, data_dim)
        chosen_covs = covs[component_indices]  # Shape: (num_samples, data_dim, data_dim)

        # One independent key per sample.
        sample_keys = jax.random.split(key_norm, num_samples)

        def sample_one(mean: Array, cov: Array, single_key: Array) -> Array:
            # shape=() draws a single vector of shape (data_dim,).
            return jax.random.multivariate_normal(
                single_key, mean, cov, shape=(), method="eigh"
            )

        # Vectorize over the leading (sample) axis of all three inputs.
        samples = jax.vmap(sample_one)(
            chosen_means, chosen_covs, sample_keys
        )  # Shape: (num_samples, data_dim)

        return samples, component_indices

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score vector field ``(∇_x log p_t(x_t))`` for the GMM distribution.

        This is calculated with respect to the perturbed distribution ``p_t`` induced by the
        ``diffusion_process`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        return gmm_score(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["covs"],
            self.dist_params["priors"],
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the GMM distribution.

        This represents the expected original sample ``x_0`` given the noisy observation ``x_t``
        at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``.
        """
        return gmm_x0(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["covs"],
            self.dist_params["priors"],
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` for the GMM distribution.

        This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t``
        at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``.
        """
        return gmm_eps(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["covs"],
            self.dist_params["priors"],
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` for the GMM distribution.

        This relates to the conditional velocity ``E[dx_t/dt | x_t]`` under the
        ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
        """
        return gmm_v(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["covs"],
            self.dist_params["priors"],
        )

206 

207 

def gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    covs: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a GMM.

    Under the forward model ``x_t ~ N(α_t x_0, σ_t^2 I)`` with ``x_0`` drawn from
    the GMM (``means``, ``covs``, ``priors``), each component ``i`` gives a Gaussian
    marginal ``x_t | i ~ N(α_t μ_i, α_t^2 Σ_i + σ_t^2 I)``. The prediction combines:

    * the per-component conditional mean
      ``μ_i + α_t Σ_i (α_t^2 Σ_i + σ_t^2 I)^{-1} (x_t - α_t μ_i)``, and
    * the posterior component weights ``softmax(log priors_i + log N(x_t; ...))``.

    Constant terms of the Gaussian log-density are dropped because they cancel
    in the softmax.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        covs (``Array[num_components, data_dim, data_dim]``): Covariances of the GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    _, data_dim = means.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)
    identity = jnp.eye(data_dim)

    def per_component(mean: Array, cov: Array) -> Tuple[Array, Array]:
        # Covariance of x_t under this component.
        marginal_cov = alpha_t**2 * cov + sigma_t**2 * identity
        # Offset of x_t from this component's noisy mean.
        residual = x_t - alpha_t * mean
        # Presumably solves marginal_cov @ y = residual (least squares); see _lstsq.
        solved = _lstsq(marginal_cov, residual)
        # Unnormalized Gaussian log-likelihood: log-det term plus quadratic form.
        log_lik = -(1 / 2) * (_logdeth(marginal_cov) + jnp.sum(residual * solved))
        # Conditional mean E[x_0 | x_t, component].
        cond_mean = mean + alpha_t * cov @ solved
        return log_lik, cond_mean

    # log_liks: (num_components,), cond_means: (num_components, data_dim)
    log_liks, cond_means = jax.vmap(per_component)(means, covs)

    # Posterior responsibilities; these sum to 1 across components.
    weights = jax.nn.softmax(jnp.log(priors) + log_liks, axis=0)

    return jnp.sum(weights[:, None] * cond_means, axis=0)

269 

270 

# Build the noise-prediction (eps), score, and velocity (v) vector-field functions
# from the closed-form gmm_x0 predictor, using create_gmm_vector_field_fns from
# diffusionlab.distributions.gmm.utils. The unpacking order below must match that
# helper's return order (eps, score, v) — NOTE(review): confirm against utils.
gmm_eps, gmm_score, gmm_v = create_gmm_vector_field_fns(gmm_x0)