Coverage for src/diffusionlab/vector_fields.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1import enum 

2 

3from jax import Array 

4 

5 

class VectorFieldType(enum.Enum):
    """
    The quantity that a vector field predicts.

    A vector field maps ``x_t`` (``Array[*data_dims]``) and ``t`` (``Array[]``)
    to an array of the same shape as ``x_t`` (``Array[*data_dims]``). The members
    of this enum name what that output represents:

    - ``VectorFieldType.SCORE``: the score of the distribution.
    - ``VectorFieldType.X0``: the denoised state.
    - ``VectorFieldType.EPS``: the noise.
    - ``VectorFieldType.V``: the velocity field.

    Member values are the consecutive integers 1-4 in declaration order, i.e.
    exactly what ``enum.auto()`` would assign.
    """

    SCORE = 1
    X0 = 2
    EPS = 3
    V = 4

24 

25 

def _log_derivative_ratios(
    alpha: Array,
    sigma: Array,
    alpha_prime: Array,
    sigma_prime: Array,
) -> tuple[Array, Array, Array]:
    """
    Compute the ratios used by every conversion that involves the velocity ``V``.

    Returns:
        ``(alpha_r, sigma_r, diff_r)`` with ``alpha_r = alpha' / alpha``,
        ``sigma_r = sigma' / sigma`` and ``diff_r = sigma_r - alpha_r``.
    """
    alpha_ratio = alpha_prime / alpha
    sigma_ratio = sigma_prime / sigma
    return alpha_ratio, sigma_ratio, sigma_ratio - alpha_ratio


def convert_vector_field_type(
    x: Array,
    f_x: Array,
    alpha: Array,
    sigma: Array,
    alpha_prime: Array,
    sigma_prime: Array,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> Array:
    """
    Converts the output of a vector field from one type to another.

    Only the divisions a conversion actually needs are performed. Identical
    types return ``f_x`` unchanged. Conversions among ``SCORE``, ``X0`` and
    ``EPS`` never form the ratios ``alpha' / alpha`` or ``sigma' / sigma``;
    those are computed only when ``V`` is the input or output type. For example,
    ``SCORE -> EPS`` never divides by ``alpha`` and ``SCORE -> X0`` never divides
    by ``sigma``.

    Arguments:
        x (``Array[*data_dims]``): The input tensor.
        f_x (``Array[*data_dims]``): The output of the vector field f evaluated at x.
        alpha (``Array[]``): A scalar representing the scale parameter.
        sigma (``Array[]``): A scalar representing the noise level parameter.
        alpha_prime (``Array[]``): A scalar representing the scale derivative parameter.
        sigma_prime (``Array[]``): A scalar representing the noise level derivative parameter.
        in_type (``VectorFieldType``): The type of the input vector field (e.g. ``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``).
        out_type (``VectorFieldType``): The type of the output vector field.

    Returns:
        ``Array[*data_dims]``: The converted output of the vector field.

    Raises:
        ValueError: If ``in_type`` and ``out_type`` differ and do not form a
            supported conversion (e.g. one of them is not a ``VectorFieldType``).
    """
    """
    Derivation:
    ----------------------------
    Define certain quantities:
        alpha_r = alpha' / alpha
        sigma_r = sigma' / sigma
        diff_r = sigma_r - alpha_r
    For the usual schedules (alpha > 0, sigma > 0, alpha' <= 0, sigma' > 0) this
    gives diff_r > 0, so the divisions by diff_r below are well defined. These
    sign conditions are assumed by the formulas, not checked by the code.
    Under the data model
        (1) x := alpha * x0 + sigma * eps
    it holds that
        (2) x = alpha * E[x0 | x] + sigma * E[eps | x]
    Therefore
        (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
        (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma
    Furthermore, differentiating (1) in time gives
        (5) v := x' = alpha' * x0 + sigma' * eps,
    and in particular
        (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]
    Substituting (3) into (6):
        (7) E[v | x] = alpha_r * (x - sigma * E[eps | x]) + sigma' * E[eps | x]
            => E[v | x] = alpha_r * x + (sigma' - sigma * alpha_r) * E[eps | x]
            => E[v | x] = alpha_r * x + sigma * (sigma_r - alpha_r) * E[eps | x]
            => E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]
        (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)
    and, substituting (4) into (6):
        (9) E[v | x] = alpha' * E[x0 | x] + sigma_r * (x - alpha * E[x0 | x])
            => E[v | x] = sigma_r * x + (alpha' - alpha * sigma_r) * E[x0 | x]
            => E[v | x] = sigma_r * x + alpha * (alpha_r - sigma_r) * E[x0 | x]
            => E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]
        (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)
    To connect the score function to the other types, we use Tweedie's formula,
    which requires eps ~ N(0, I) independent of x0:
        (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma).
    Therefore, from (11):
        (12) E[x0 | x] = (x + sigma^2 * score(x, alpha, sigma)) / alpha
    From (12):
        (13) score(x, alpha, sigma) = (alpha * E[x0 | x] - x) / sigma^2
    From (12) and (4):
        (14) E[eps | x] = -sigma * score(x, alpha, sigma)
    From (14):
        (15) score(x, alpha, sigma) = -E[eps | x] / sigma
    From (14) and (7):
        (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score(x, alpha, sigma)
    From (16):
        (17) score(x, alpha, sigma) = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)
    """
    # Identity conversion: nothing to compute, and no division is attempted.
    if in_type == out_type:
        return f_x

    if in_type == VectorFieldType.SCORE:
        if out_type == VectorFieldType.X0:
            return (x + sigma**2 * f_x) / alpha  # From equation (12)
        if out_type == VectorFieldType.EPS:
            return -sigma * f_x  # From equation (14)
        if out_type == VectorFieldType.V:
            alpha_ratio, _, ratio_diff = _log_derivative_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return alpha_ratio * x - sigma**2 * ratio_diff * f_x  # From equation (16)

    elif in_type == VectorFieldType.X0:
        if out_type == VectorFieldType.SCORE:
            return (alpha * f_x - x) / sigma**2  # From equation (13)
        if out_type == VectorFieldType.EPS:
            return (x - alpha * f_x) / sigma  # From equation (4)
        if out_type == VectorFieldType.V:
            _, sigma_ratio, ratio_diff = _log_derivative_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return sigma_ratio * x - alpha * ratio_diff * f_x  # From equation (9)

    elif in_type == VectorFieldType.EPS:
        if out_type == VectorFieldType.SCORE:
            return -f_x / sigma  # From equation (15)
        if out_type == VectorFieldType.X0:
            return (x - sigma * f_x) / alpha  # From equation (3)
        if out_type == VectorFieldType.V:
            alpha_ratio, _, ratio_diff = _log_derivative_ratios(
                alpha, sigma, alpha_prime, sigma_prime
            )
            return alpha_ratio * x + sigma * ratio_diff * f_x  # From equation (7)

    elif in_type == VectorFieldType.V:
        # Every conversion out of the velocity needs the ratios.
        alpha_ratio, sigma_ratio, ratio_diff = _log_derivative_ratios(
            alpha, sigma, alpha_prime, sigma_prime
        )
        if out_type == VectorFieldType.SCORE:
            return (alpha_ratio * x - f_x) / (sigma**2 * ratio_diff)  # From equation (17)
        if out_type == VectorFieldType.X0:
            return (sigma_ratio * x - f_x) / (alpha * ratio_diff)  # From equation (10)
        if out_type == VectorFieldType.EPS:
            return (f_x - alpha_ratio * x) / (sigma * ratio_diff)  # From equation (8)

    # Previously an unrecognized type silently returned f_x unchanged, which
    # hides caller bugs; fail loudly instead.
    raise ValueError(
        f"Unsupported vector field conversion: {in_type!r} -> {out_type!r}."
    )