Coverage for src/diffusionlab/schedulers.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass 

2from typing import Any 

3from jax import Array, numpy as jnp 

4 

5 

@dataclass(frozen=True)
class Scheduler:
    """
    Common interface for time step schedulers used in diffusion, denoising,
    and sampling.

    The base class declares no fields of its own. Concrete schedulers
    subclass it and override ``get_ts``, reading whatever hyperparameters
    they need from the keyword arguments passed to that method.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Produce the sequence of time steps.

        The base implementation is a placeholder that always raises;
        subclasses override it and document which keys they expect to find
        in ``**ts_hparams``.

        Args:
            **ts_hparams (``Dict[str, Any]``): Hyperparameters controlling how the time steps are generated.

        Returns:
            ``Array``: The time steps in descending order (the contract for overriding implementations).

        Raises:
            NotImplementedError: Always, when called on a class that has not overridden this method.
            KeyError: In subclasses, when a required entry of ``**ts_hparams`` is missing.
        """
        raise NotImplementedError()

34 

35 

@dataclass(frozen=True)
class UniformScheduler(Scheduler):
    """
    A scheduler that generates uniformly spaced time steps.

    Requires ``t_min``, ``t_max``, and ``num_steps`` to be passed
    to the ``get_ts`` method via keyword arguments. The number of points
    generated will be ``num_steps + 1``.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Generate uniformly spaced time steps.

        Args:
            **ts_hparams (``Dict[str, Any]``): Keyword arguments must contain

                - ``t_min`` (``float``): The minimum time value, typically close to 0.
                - ``t_max`` (``float``): The maximum time value, typically close to 1.
                - ``num_steps`` (``int``): The number of diffusion steps. The function will generate ``num_steps + 1`` time points.

        Returns:
            ``Array[num_steps+1]``: A JAX array containing uniformly spaced time steps
            in descending order (from ``t_max`` to ``t_min``).

        Raises:
            KeyError: If ``t_min``, ``t_max``, or ``num_steps`` is not found in ``ts_hparams``.
            AssertionError: If ``0 <= t_min <= t_max <= 1`` does not hold or ``num_steps`` < 1.
        """
        try:
            t_min = ts_hparams["t_min"]
            t_max = ts_hparams["t_max"]
            num_steps = ts_hparams["num_steps"]
        except KeyError as e:
            raise KeyError(
                f"Missing required parameter for UniformScheduler.get_ts: {e}"
            ) from e

        # Validation is raised explicitly instead of using ``assert`` so it is
        # not stripped when Python runs with ``-O``. The exception type stays
        # ``AssertionError`` so existing callers that catch it keep working.
        if not (0 <= t_min <= t_max <= 1):
            raise AssertionError("t_min and t_max must be in the range [0, 1]")
        if not (num_steps >= 1):
            raise AssertionError("num_steps must be at least 1")

        # linspace runs ascending from t_min to t_max; reverse it to get the
        # documented descending order.
        return jnp.linspace(t_min, t_max, num_steps + 1)[::-1]