Coverage for src/diffusionlab/vector_fields.py: 100%
42 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1import enum
3from jax import Array
class VectorFieldType(enum.Enum):
    """
    Enum representing the type of a vector field.

    A vector field is a function that takes in ``x_t`` (``Array[*data_dims]``) and ``t`` (``Array[]``)
    and returns a vector of the same shape as ``x_t`` (``Array[*data_dims]``).

    Members (the vector field types supported by DiffusionLab):

    - ``VectorFieldType.SCORE``: the score of the distribution.
    - ``VectorFieldType.X0``: the denoised state.
    - ``VectorFieldType.EPS``: the noise.
    - ``VectorFieldType.V``: the velocity field.
    """

    # Values are assigned by enum.auto() in declaration order (1, 2, 3, 4).
    SCORE = enum.auto()  # score of the distribution
    X0 = enum.auto()  # denoised state
    EPS = enum.auto()  # noise
    V = enum.auto()  # velocity field
def _schedule_ratios(
    alpha: Array, sigma: Array, alpha_prime: Array, sigma_prime: Array
) -> tuple[Array, Array, Array]:
    """Return ``(alpha_r, sigma_r, diff_r)`` = ``(alpha'/alpha, sigma'/sigma, sigma_r - alpha_r)``."""
    alpha_ratio = alpha_prime / alpha
    sigma_ratio = sigma_prime / sigma
    return alpha_ratio, sigma_ratio, sigma_ratio - alpha_ratio


def convert_vector_field_type(
    x: Array,
    f_x: Array,
    alpha: Array,
    sigma: Array,
    alpha_prime: Array,
    sigma_prime: Array,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> Array:
    """
    Converts the output of a vector field from one type to another.

    Arguments:
        x (``Array[*data_dims]``): The input tensor.
        f_x (``Array[*data_dims]``): The output of the vector field f evaluated at x.
        alpha (``Array[]``): A scalar representing the scale parameter.
        sigma (``Array[]``): A scalar representing the noise level parameter.
        alpha_prime (``Array[]``): A scalar representing the scale derivative parameter.
        sigma_prime (``Array[]``): A scalar representing the noise level derivative parameter.
        in_type (``VectorFieldType``): The type of the input vector field (e.g. ``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``).
        out_type (``VectorFieldType``): The type of the output vector field.

    Returns:
        ``Array[*data_dims]``: The converted output of the vector field. When
        ``in_type == out_type``, ``f_x`` is returned unchanged.

    Raises:
        ValueError: If ``in_type != out_type`` and the pair is not one of the supported
            conversions between ``VectorFieldType`` members.

    Note:
        Only conversions to or from ``VectorFieldType.V`` use ``alpha_prime / alpha`` and
        ``sigma_prime / sigma``. Conversions from ``V`` additionally divide by
        ``sigma_prime / sigma - alpha_prime / alpha``, which must therefore be nonzero.
    """
    # Derivation
    # ----------------------------
    # Define certain quantities:
    #     alpha_r = alpha' / alpha
    #     sigma_r = sigma' / sigma
    #     diff_r  = sigma_r - alpha_r
    # and note that diff_r >= 0 when alpha' <= 0, sigma' >= 0 and alpha, sigma > 0.
    # Under the data model
    #     (1) x := alpha * x0 + sigma * eps
    # it holds that
    #     (2) x = alpha * E[x0 | x] + sigma * E[eps | x]
    # Therefore
    #     (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
    #     (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma
    # Furthermore, from (1) it holds that
    #     (5) v := x' = alpha' * x0 + sigma' * eps,
    # or in particular
    #     (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]
    # Using (3), (4), (6) it holds
    #     (7) E[v | x] = alpha_r * (x - sigma * E[eps | x]) + sigma' * E[eps | x]
    #      => E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]
    #     (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)
    # and, similarly,
    #     (9) E[v | x] = alpha' * E[x0 | x] + sigma_r * (x - alpha * E[x0 | x])
    #      => E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]
    #     (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)
    # To connect the score function to the other types, we use Tweedie's formula:
    #     (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma)
    # Therefore:
    #     (12) E[x0 | x] = (x + sigma^2 * score) / alpha                  [from (11)]
    #     (13) score = (alpha * E[x0 | x] - x) / sigma^2                  [from (12)]
    #     (14) E[eps | x] = -sigma * score                                [from (11), (4)]
    #     (15) score = -E[eps | x] / sigma                                [from (14)]
    #     (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score          [from (14), (7)]
    #     (17) score = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)      [from (16)]

    if in_type == out_type:
        return f_x

    # The schedule ratios are only computed in branches that involve V, so the
    # SCORE/X0/EPS conversions never evaluate alpha'/alpha or sigma'/sigma
    # (which can be singular, e.g. at sigma == 0).
    if in_type == VectorFieldType.SCORE:
        if out_type == VectorFieldType.X0:
            return (x + sigma**2 * f_x) / alpha  # From equation (12)
        if out_type == VectorFieldType.EPS:
            return -sigma * f_x  # From equation (14)
        if out_type == VectorFieldType.V:
            alpha_ratio, _, ratio_diff = _schedule_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return alpha_ratio * x - sigma**2 * ratio_diff * f_x  # From equation (16)

    elif in_type == VectorFieldType.X0:
        if out_type == VectorFieldType.SCORE:
            return (alpha * f_x - x) / sigma**2  # From equation (13)
        if out_type == VectorFieldType.EPS:
            return (x - alpha * f_x) / sigma  # From equation (4)
        if out_type == VectorFieldType.V:
            _, sigma_ratio, ratio_diff = _schedule_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return sigma_ratio * x - alpha * ratio_diff * f_x  # From equation (9)

    elif in_type == VectorFieldType.EPS:
        if out_type == VectorFieldType.SCORE:
            return -f_x / sigma  # From equation (15)
        if out_type == VectorFieldType.X0:
            return (x - sigma * f_x) / alpha  # From equation (3)
        if out_type == VectorFieldType.V:
            alpha_ratio, _, ratio_diff = _schedule_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return alpha_ratio * x + sigma * ratio_diff * f_x  # From equation (7)

    elif in_type == VectorFieldType.V:
        alpha_ratio, sigma_ratio, ratio_diff = _schedule_ratios(
            alpha, sigma, alpha_prime, sigma_prime
        )
        if out_type == VectorFieldType.SCORE:
            return (alpha_ratio * x - f_x) / (
                sigma**2 * ratio_diff
            )  # From equation (17)
        if out_type == VectorFieldType.X0:
            return (sigma_ratio * x - f_x) / (
                alpha * ratio_diff
            )  # From equation (10)
        if out_type == VectorFieldType.EPS:
            return (f_x - alpha_ratio * x) / (
                sigma * ratio_diff
            )  # From equation (8)

    # Previously an unsupported pair silently returned f_x; surface it instead.
    raise ValueError(
        f"Unsupported vector field conversion: {in_type!r} -> {out_type!r}"
    )