Coverage for src/diffusionlab/distributions/base.py: 100%
32 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass
2from typing import Any, Callable, Dict, Tuple
4from jax import Array
6from diffusionlab.dynamics import DiffusionProcess
7from diffusionlab.vector_fields import VectorFieldType
@dataclass(frozen=True)
class Distribution:
    """
    Common base for distributions that expose ground-truth diffusion quantities.

    Derive from this class when exact scores, denoisers (``x0``), noise
    predictors (``eps``), or velocity estimators (``v``) are needed for a known
    distribution. A concrete subclass overrides :meth:`sample` together with
    whichever of :meth:`score`, :meth:`x0`, :meth:`eps`, and :meth:`v` it
    supports; the implementations defined here all raise ``NotImplementedError``.
    :meth:`get_vector_field` looks one of those four methods up by
    ``VectorFieldType``.

    The dataclass is frozen, so the two attributes below cannot be reassigned
    after construction.

    Attributes:
        dist_params (``Dict[str, Array]``): Named JAX arrays that parameterize the
            distribution. Their shapes are defined by each concrete subclass.
        dist_hparams (``Dict[str, Any]``): Named hyperparameters that are not
            arrays.
    """

    # Array-valued parameters; shapes are subclass-specific.
    dist_params: Dict[str, Array]
    # Non-array hyperparameters.
    dist_hparams: Dict[str, Any]

    def sample(
        self,
        key: Array,
        num_samples: int,
    ) -> Tuple[Array, Any]:
        """
        Draw ``num_samples`` samples from the distribution.

        Args:
            key (``Array``): JAX PRNG key that drives the randomness.
            num_samples (``int``): How many samples to produce.

        Returns:
            ``Tuple[Array[num_samples, *data_dims], Any]``: The batch of samples,
            paired with whatever auxiliary information the subclass returns.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def get_vector_field(
        self, vector_field_type: VectorFieldType
    ) -> Callable[[Array, Array, DiffusionProcess], Array]:
        """
        Look up the bound method that computes the requested kind of vector field.

        Args:
            vector_field_type (``VectorFieldType``): Which field to return, one of
                ``VectorFieldType.X0``, ``VectorFieldType.EPS``,
                ``VectorFieldType.V`` or ``VectorFieldType.SCORE``.

        Returns:
            ``Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]``:
            The bound method (:meth:`x0`, :meth:`eps`, :meth:`v` or :meth:`score`).
            It is called as ``field(x_t, t, diffusion_process)`` and returns an
            array shaped like ``x_t``.

        Raises:
            ValueError: If ``vector_field_type`` equals none of the four members.
        """
        # Ordered (member, method-name) pairs. They are tested with ``==`` in
        # this order, and the method is bound only for the first match.
        dispatch_order = (
            (VectorFieldType.X0, "x0"),
            (VectorFieldType.EPS, "eps"),
            (VectorFieldType.V, "v"),
            (VectorFieldType.SCORE, "score"),
        )
        for candidate, method_name in dispatch_order:
            if vector_field_type == candidate:
                return getattr(self, method_name)
        raise ValueError(
            f"Vector field type {vector_field_type} is not supported."
        )

    def score(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Evaluate the score ``∇_x log p_t(x)`` at the noisy point ``x_t`` and time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Forward process that defines ``p_t``.

        Returns:
            ``Array[*data_dims]``: The score evaluated at ``(x_t, t)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def x0(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Estimate the clean sample ``x0`` underlying the noisy state ``x_t`` at time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Forward process that produced ``x_t``.

        Returns:
            ``Array[*data_dims]``: The denoised prediction of ``x0``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def eps(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Estimate the noise ``ε`` contained in the noisy state ``x_t`` at time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Forward process that produced ``x_t``.

        Returns:
            ``Array[*data_dims]``: The predicted noise ``ε``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def v(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Evaluate the velocity field ``v(x_t, t)`` at the noisy state ``x_t`` and time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Forward process that produced ``x_t``.

        Returns:
            ``Array[*data_dims]``: The velocity ``v`` at ``(x_t, t)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()