Coverage for src/diffusionlab/distributions/gmm/utils.py: 100%
30 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from jax import Array, numpy as jnp
2from typing import cast, Callable, Tuple
3from diffusionlab.dynamics import DiffusionProcess
4from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type
def _logdeth(cov: Array) -> Array:
    """
    Computes the log determinant of a symmetric positive-definite matrix.

    Uses ``jnp.linalg.eigvalsh`` (eigenvalues of a symmetric/Hermitian matrix)
    and returns ``sum_i log(lambda_i)``. This is better behaved numerically than
    taking ``log(det(cov))`` directly for matrices such as covariance matrices.

    Leading batch dimensions are supported. An input of shape ``[..., dim, dim]``
    yields an output of shape ``[...]``, and a single matrix yields a scalar.

    Note:
        A finite result requires a strictly positive-definite matrix. A zero
        eigenvalue gives ``-inf``, and a (numerically) negative eigenvalue
        gives ``nan``.

    Args:
        cov (``Array[..., dim, dim]``): The symmetric PSD matrix (e.g. a covariance
            matrix), or a batch of such matrices.

    Returns:
        ``Array[...]``: The log determinant(s). This is a scalar for a single matrix.
    """
    eigvals = jnp.linalg.eigvalsh(cov)
    # Reduce only over the eigenvalue axis so that batch dimensions are preserved;
    # for a single matrix this is identical to summing everything.
    return jnp.sum(jnp.log(eigvals), axis=-1)
def _lstsq(A: Array, y: Array) -> Array:
    """
    Returns the least-squares solution ``x`` of ``A x = y``.

    This is the Moore-Penrose pseudoinverse solution ``A^+ y``. The singular-value
    cutoff ``rcond`` is set to the machine epsilon of ``A``'s dtype. Singular
    values smaller than ``eps`` times the largest singular value are treated as
    zero when determining rank.

    Args:
        A (``Array[out_dim, in_dim]``): The coefficient matrix.
        y (``Array[out_dim]``): The right-hand side.

    Returns:
        ``Array[in_dim]``: The least-squares (minimum-norm) solution ``x``.
    """
    cutoff = cast(float, jnp.finfo(A.dtype).eps)
    solution, _residuals, _rank, _singular_values = jnp.linalg.lstsq(
        A, y, rcond=cutoff
    )
    return solution
def create_gmm_vector_field_fns(
    x0_fn: Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
) -> Tuple[
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
]:
    """
    Factory to create eps, score, and v functions from a given x0 function.

    Each generated function performs the same steps:
      1. Evaluates ``x0_fn``.
      2. Reads ``alpha(t)``, ``sigma(t)``, ``alpha_prime(t)`` and
         ``sigma_prime(t)`` from ``diffusion_process``.
      3. Converts the x0 prediction into the target vector-field type with
         ``convert_vector_field_type``.

    Args:
        x0_fn: The specific x0 calculation function (e.g., ``gmm_x0``, ``iso_gmm_x0``).
            It must accept ``(x_t, t, diffusion_process, means, specific_cov, priors)``.

    Returns:
        A tuple ``(eps_fn, score_fn, v_fn)`` of callables. Each has the same
        signature as ``x0_fn``, accepting
        ``(x_t, t, diffusion_process, means, specific_cov, priors)``.
    """

    def common_wrapper(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
        target_type: VectorFieldType,
    ) -> Array:
        """Evaluates ``x0_fn`` and converts its X0 prediction to ``target_type``."""
        x0_x_t = x0_fn(x_t, t, diffusion_process, means, specific_cov, priors)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        return convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            target_type,
        )

    def eps_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the noise prediction field ε based on the provided x0 function."""
        return common_wrapper(
            x_t, t, diffusion_process, means, specific_cov, priors, VectorFieldType.EPS
        )

    def score_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the score field based on the provided x0 function."""
        return common_wrapper(
            x_t,
            t,
            diffusion_process,
            means,
            specific_cov,
            priors,
            VectorFieldType.SCORE,
        )

    def v_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the velocity field v based on the provided x0 function."""
        return common_wrapper(
            x_t, t, diffusion_process, means, specific_cov, priors, VectorFieldType.V
        )

    # Callables such as functools.partial have no __name__, so fall back to repr.
    x0_name = getattr(x0_fn, "__name__", repr(x0_fn))

    def _field_doc(field_description: str) -> str:
        """Builds the docstring for one generated vector-field function."""
        # A single f-string per field avoids the fragile f-string + str.format
        # combination, which breaks if the interpolated name contains braces.
        return (
            f"Computes the {field_description} field based on {x0_name} "
            "by converting the x0 prediction.\n\n"
            "    Args:\n"
            "        x_t (Array[data_dim]): The noisy state tensor at time `t`.\n"
            "        t (Array[]): The time step (scalar).\n"
            "        diffusion_process (DiffusionProcess): Provides diffusion "
            "coefficients and derivatives.\n"
            "        means (Array[num_components, data_dim]): GMM component means.\n"
            "        specific_cov: GMM component specific covariance representation "
            "(covs, factors, variances, or variance).\n"
            "        priors (Array[num_components]): GMM component mixture weights.\n\n"
            "    Returns:\n"
            "        Array[data_dim]: The corresponding vector field evaluated at "
            "`x_t` and `t`."
        )

    # Replace the short docstrings with fuller ones naming the wrapped x0 function.
    eps_fn.__doc__ = _field_doc("noise prediction ε")
    score_fn.__doc__ = _field_doc("score")
    v_fn.__doc__ = _field_doc("velocity v")

    return eps_fn, score_fn, v_fn