Coverage for src/diffusionlab/losses.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from typing import Callable 

2from dataclasses import dataclass, field 

3import jax 

4from jax import numpy as jnp, Array 

5 

6from diffusionlab.dynamics import DiffusionProcess 

7from diffusionlab.vector_fields import VectorFieldType 

8 

9 

10@dataclass 

11class DiffusionLoss: 

12 """ 

13 Loss function for training diffusion models. 

14 

15 This dataclass implements various loss functions for diffusion models based on the specified 

16 target type. The loss is computed as the mean squared error between the model's prediction 

17 and the target, which depends on the chosen vector field type. 

18 

19 The loss supports different target types: 

20 

21 - ``VectorFieldType.X0``: Learn to predict the original clean data x_0 

22 - ``VectorFieldType.EPS``: Learn to predict the noise component eps 

23 - ``VectorFieldType.V``: Learn to predict the velocity field v 

24 - ``VectorFieldType.SCORE``: Not directly supported (raises ValueError) 

25 

26 Attributes: 

27 diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics 

28 vector_field_type (``VectorFieldType``): The type of target to learn to estimate via minimizing the loss function. 

29 num_noise_draws_per_sample (``int``): The number of noise draws per sample to use for the batchwise loss. 

30 target (``Callable[[Array, Array, Array, Array, Array], Array]``): Function that computes the target based on the specified target_type. 

31 

32 Signature: ``(x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]`` 

33 """ 

34 

35 diffusion_process: DiffusionProcess 

36 vector_field_type: VectorFieldType 

37 num_noise_draws_per_sample: int 

38 target: Callable[[Array, Array, Array, Array, Array], Array] = field(init=False) 

39 

40 def __post_init__(self): 

41 match self.vector_field_type: 

42 case VectorFieldType.X0: 

43 

44 def target( 

45 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array 

46 ) -> Array: 

47 return x_0 

48 

49 case VectorFieldType.EPS: 

50 

51 def target( 

52 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array 

53 ) -> Array: 

54 return eps 

55 

56 case VectorFieldType.V: 

57 

58 def target( 

59 x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array 

60 ) -> Array: 

61 return ( 

62 self.diffusion_process.alpha_prime(t) * x_0 

63 + self.diffusion_process.sigma_prime(t) * eps 

64 ) 

65 

66 case VectorFieldType.SCORE: 

67 raise ValueError( 

68 "Direct score matching is not supported due to lack of a known target function, and other ways (like Hutchinson's trace estimator) are very high variance." 

69 ) 

70 

71 case _: 

72 raise ValueError(f"Invalid target type: {self.vector_field_type}") 

73 

74 self.target = target 

75 

76 def prediction_loss( 

77 self, x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array 

78 ) -> Array: 

79 """ 

80 Compute the loss given a prediction and inputs/targets. 

81 

82 This method calculates the mean squared error between the model's prediction (``f_x_t``) 

83 and the target value determined by the target_type (``self.target``). 

84 

85 Args: 

86 x_t (``Array[*data_dims]``): The noised data at time ``t``. 

87 f_x_t (``Array[*data_dims]``): The model's prediction at time ``t``. 

88 x_0 (``Array[*data_dims]``): The original clean data. 

89 eps (``Array[*data_dims]``): The noise used to generate ``x_t``. 

90 t (``Array[]``): The scalar time parameter. 

91 

92 Returns: 

93 ``Array[]``: The scalar loss value for the given sample. 

94 """ 

95 squared_residuals = (f_x_t - self.target(x_t, f_x_t, x_0, eps, t)) ** 2 

96 samplewise_loss = jnp.sum(squared_residuals) 

97 return samplewise_loss 

98 

99 def loss( 

100 self, 

101 key: Array, 

102 vector_field: Callable[[Array, Array], Array], 

103 x_0: Array, 

104 t: Array, 

105 ) -> Array: 

106 """ 

107 Compute the average loss over multiple noise draws for a single data point and time. 

108 

109 This method estimates the expected loss at a given time ``t`` for a clean data sample ``x_0``. 

110 It does this by drawing ``num_noise_draws_per_sample`` noise vectors (``eps``), generating 

111 the corresponding noisy samples ``x_t`` using the ``diffusion_process``, predicting the 

112 target quantity ``f_x_t`` using the provided ``vector_field`` (vmapped internally), and then calculating the 

113 ``prediction_loss`` for each noise sample. The final loss is the average over these samples. 

114 

115 Args: 

116 key (``Array``): The PRNG key for noise generation. 

117 vector_field (``Callable[[Array, Array], Array]``): The vector field function that takes 

118 a single noisy data sample ``x_t`` and its corresponding time ``t``, and returns the model's prediction ``f_x_t``. 

119 This function will be vmapped internally over the batch dimension created by ``num_noise_draws_per_sample``. 

120 

121 Signature: ``(x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims]``. 

122 

123 x_0 (``Array[*data_dims]``): The original clean data sample. 

124 t (``Array[]``): The scalar time parameter. 

125 

126 Returns: 

127 ``Array[]``: The scalar loss value, averaged over ``num_noise_draws_per_sample`` noise instances. 

128 """ 

129 x_0_batch = x_0[None, ...].repeat(self.num_noise_draws_per_sample, axis=0) 

130 t_batch = t[None].repeat(self.num_noise_draws_per_sample, axis=0) 

131 eps_batch = jax.random.normal(key, x_0_batch.shape) 

132 

133 batch_diffusion_forward = jax.vmap( 

134 self.diffusion_process.forward, in_axes=(0, 0, 0) 

135 ) 

136 x_t_batch = batch_diffusion_forward(x_0_batch, t_batch, eps_batch) 

137 

138 batch_vector_field = jax.vmap(vector_field, in_axes=(0, 0)) 

139 f_x_t_batch = batch_vector_field(x_t_batch, t_batch) 

140 

141 batch_prediction_loss = jax.vmap(self.prediction_loss, in_axes=(0, 0, 0, 0, 0)) 

142 losses = batch_prediction_loss( 

143 x_t_batch, f_x_t_batch, x_0_batch, eps_batch, t_batch 

144 ) 

145 

146 loss_value = jnp.mean(losses) 

147 return loss_value