Coverage for src/diffusionlab/distributions/gmm/utils.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from jax import Array, numpy as jnp 

2from typing import cast, Callable, Tuple 

3from diffusionlab.dynamics import DiffusionProcess 

4from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type 

5 

6 

7def _logdeth(cov: Array) -> Array: 

8 """ 

9 Computes the log determinant of a positive semi-definite (PSD) matrix. 

10 

11 Uses ``eigh`` for numerical stability with symmetric matrices like covariance matrices. 

12 

13 Args: 

14 cov (``Array[dim, dim]``): The input PSD matrix (e.g., a covariance matrix). 

15 

16 Returns: 

17 ``Array[]``: The log determinant of the matrix (scalar). 

18 """ 

19 eigvals = jnp.linalg.eigvalsh(cov) 

20 return jnp.sum(jnp.log(eigvals)) 

21 

22 

23def _lstsq(A: Array, y: Array) -> Array: 

24 """ 

25 Solves the linear system Ax = y using least squares. 

26 

27 Handles potential conditioning issues by setting rcond based on machine epsilon. 

28 Equivalent to computing A^+ y where A^+ is the Moore-Penrose pseudoinverse. 

29 

30 Args: 

31 A (``Array[out_dim, in_dim]``): The coefficient matrix. 

32 y (``Array[out_dim]``): The dependent variable values. 

33 

34 Returns: 

35 ``Array[in_dim]``: The least-squares solution ``x``. 

36 """ 

37 eps = cast(float, jnp.finfo(A.dtype).eps) 

38 x = jnp.linalg.lstsq(A, y, rcond=eps)[0] 

39 return x 

40 

41 

def create_gmm_vector_field_fns(
    x0_fn: Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
) -> Tuple[
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
    Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array],
]:
    """
    Factory to create eps, score, and v functions from a given x0 function.

    Each generated function works the same way:
    - It calls ``x0_fn`` to get the x0 prediction at ``(x_t, t)``.
    - It reads ``alpha(t)``, ``sigma(t)``, ``alpha_prime(t)`` and ``sigma_prime(t)``
      from the diffusion process.
    - It passes these to ``convert_vector_field_type``, which converts the X0
      prediction into the requested vector field type.

    Args:
        x0_fn: The specific x0 calculation function (e.g., ``gmm_x0``, ``iso_gmm_x0``).
            It must accept ``(x_t, t, diffusion_process, means, specific_cov, priors)``.
            If it has no ``__name__`` (e.g. a ``functools.partial``), its ``repr`` is
            used in the generated docstrings instead.

    Returns:
        ``Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]``:
            A tuple containing the generated ``(eps_fn, score_fn, v_fn)``.
            These functions have the same signature as ``x0_fn``, accepting
            ``(x_t, t, diffusion_process, means, specific_cov, priors)``.
    """

    def common_wrapper(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
        target_type: VectorFieldType,
    ) -> Array:
        """Evaluate ``x0_fn`` and convert its X0 prediction into ``target_type``."""
        x0_x_t = x0_fn(x_t, t, diffusion_process, means, specific_cov, priors)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        return convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            target_type,
        )

    def eps_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the noise prediction field ε based on the provided x0 function."""
        return common_wrapper(
            x_t, t, diffusion_process, means, specific_cov, priors, VectorFieldType.EPS
        )

    def score_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the score field based on the provided x0 function."""
        return common_wrapper(
            x_t,
            t,
            diffusion_process,
            means,
            specific_cov,
            priors,
            VectorFieldType.SCORE,
        )

    def v_fn(
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        means: Array,
        specific_cov: Array,
        priors: Array,
    ) -> Array:
        """Computes the velocity field v based on the provided x0 function."""
        return common_wrapper(
            x_t, t, diffusion_process, means, specific_cov, priors, VectorFieldType.V
        )

    # Fall back to repr() so callables without __name__ (e.g. functools.partial)
    # do not make the factory raise AttributeError.
    x0_name = getattr(x0_fn, "__name__", repr(x0_fn))

    def _field_doc(field_name: str) -> str:
        """Build the shared docstring for one generated field function."""
        # Both names are interpolated in a single f-string pass, with no follow-up
        # str.format, so braces in either name cannot be misread as format fields.
        return (
            f"Computes the {field_name} field based on {x0_name} "
            "by converting the x0 prediction.\n\n"
            "Args:\n"
            "    x_t (Array[data_dim]): The noisy state tensor at time `t`.\n"
            "    t (Array[]): The time step (scalar).\n"
            "    diffusion_process (DiffusionProcess): Provides diffusion "
            "coefficients and derivatives.\n"
            "    means (Array[num_components, data_dim]): GMM component means.\n"
            "    specific_cov: GMM component specific covariance representation "
            "(covs, factors, variances, or variance).\n"
            "    priors (Array[num_components]): GMM component mixture weights.\n\n"
            "Returns:\n"
            "    Array[data_dim]: The corresponding vector field evaluated at "
            "`x_t` and `t`."
        )

    eps_fn.__doc__ = _field_doc("noise prediction ε")
    score_fn.__doc__ = _field_doc("score")
    v_fn.__doc__ = _field_doc("velocity v")

    return eps_fn, score_fn, v_fn