Coverage for src/diffusionlab/losses.py: 100%
44 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from typing import Callable
2from dataclasses import dataclass, field
3import jax
4from jax import numpy as jnp, Array
6from diffusionlab.dynamics import DiffusionProcess
7from diffusionlab.vector_fields import VectorFieldType
10@dataclass
11class DiffusionLoss:
12 """
13 Loss function for training diffusion models.
15 This dataclass implements various loss functions for diffusion models based on the specified
16 target type. The loss is computed as the mean squared error between the model's prediction
17 and the target, which depends on the chosen vector field type.
19 The loss supports different target types:
21 - ``VectorFieldType.X0``: Learn to predict the original clean data x_0
22 - ``VectorFieldType.EPS``: Learn to predict the noise component eps
23 - ``VectorFieldType.V``: Learn to predict the velocity field v
24 - ``VectorFieldType.SCORE``: Not directly supported (raises ValueError)
26 Attributes:
27 diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics
28 vector_field_type (``VectorFieldType``): The type of target to learn to estimate via minimizing the loss function.
29 num_noise_draws_per_sample (``int``): The number of noise draws per sample to use for the batchwise loss.
30 target (``Callable[[Array, Array, Array, Array, Array], Array]``): Function that computes the target based on the specified target_type.
32 Signature: ``(x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]``
33 """
35 diffusion_process: DiffusionProcess
36 vector_field_type: VectorFieldType
37 num_noise_draws_per_sample: int
38 target: Callable[[Array, Array, Array, Array, Array], Array] = field(init=False)
40 def __post_init__(self):
41 match self.vector_field_type:
42 case VectorFieldType.X0:
44 def target(
45 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array
46 ) -> Array:
47 return x_0
49 case VectorFieldType.EPS:
51 def target(
52 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array
53 ) -> Array:
54 return eps
56 case VectorFieldType.V:
58 def target(
59 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array
60 ) -> Array:
61 return (
62 self.diffusion_process.alpha_prime(t) * x_0
63 + self.diffusion_process.sigma_prime(t) * eps
64 )
66 case VectorFieldType.SCORE:
67 raise ValueError(
68 "Direct score matching is not supported due to lack of a known target function, and other ways (like Hutchinson's trace estimator) are very high variance."
69 )
71 case _:
72 raise ValueError(f"Invalid target type: {self.vector_field_type}")
74 self.target = target
76 def prediction_loss(
77 self, x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array
78 ) -> Array:
79 """
80 Compute the loss given a prediction and inputs/targets.
82 This method calculates the mean squared error between the model's prediction (``f_x_t``)
83 and the target value determined by the target_type (``self.target``).
85 Args:
86 x_t (``Array[*data_dims]``): The noised data at time ``t``.
87 f_x_t (``Array[*data_dims]``): The model's prediction at time ``t``.
88 x_0 (``Array[*data_dims]``): The original clean data.
89 eps (``Array[*data_dims]``): The noise used to generate ``x_t``.
90 t (``Array[]``): The scalar time parameter.
92 Returns:
93 ``Array[]``: The scalar loss value for the given sample.
94 """
95 squared_residuals = (f_x_t - self.target(x_t, f_x_t, x_0, eps, t)) ** 2
96 samplewise_loss = jnp.sum(squared_residuals)
97 return samplewise_loss
99 def loss(
100 self,
101 key: Array,
102 vector_field: Callable[[Array, Array], Array],
103 x_0: Array,
104 t: Array,
105 ) -> Array:
106 """
107 Compute the average loss over multiple noise draws for a single data point and time.
109 This method estimates the expected loss at a given time ``t`` for a clean data sample ``x_0``.
110 It does this by drawing ``num_noise_draws_per_sample`` noise vectors (``eps``), generating
111 the corresponding noisy samples ``x_t`` using the ``diffusion_process``, predicting the
112 target quantity ``f_x_t`` using the provided ``vector_field`` (vmapped internally), and then calculating the
113 ``prediction_loss`` for each noise sample. The final loss is the average over these samples.
115 Args:
116 key (``Array``): The PRNG key for noise generation.
117 vector_field (``Callable[[Array, Array], Array]``): The vector field function that takes
118 a single noisy data sample ``x_t`` and its corresponding time ``t``, and returns the model's prediction ``f_x_t``.
119 This function will be vmapped internally over the batch dimension created by ``num_noise_draws_per_sample``.
121 Signature: ``(x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims]``.
123 x_0 (``Array[*data_dims]``): The original clean data sample.
124 t (``Array[]``): The scalar time parameter.
126 Returns:
127 ``Array[]``: The scalar loss value, averaged over ``num_noise_draws_per_sample`` noise instances.
128 """
129 x_0_batch = x_0[None, ...].repeat(self.num_noise_draws_per_sample, axis=0)
130 t_batch = t[None].repeat(self.num_noise_draws_per_sample, axis=0)
131 eps_batch = jax.random.normal(key, x_0_batch.shape)
133 batch_diffusion_forward = jax.vmap(
134 self.diffusion_process.forward, in_axes=(0, 0, 0)
135 )
136 x_t_batch = batch_diffusion_forward(x_0_batch, t_batch, eps_batch)
138 batch_vector_field = jax.vmap(vector_field, in_axes=(0, 0))
139 f_x_t_batch = batch_vector_field(x_t_batch, t_batch)
141 batch_prediction_loss = jax.vmap(self.prediction_loss, in_axes=(0, 0, 0, 0, 0))
142 losses = batch_prediction_loss(
143 x_t_batch, f_x_t_batch, x_0_batch, eps_batch, t_batch
144 )
146 loss_value = jnp.mean(losses)
147 return loss_value