Coverage for src/diffusionlab/schedulers.py: 100%
20 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass
2from typing import Any
3from jax import Array, numpy as jnp
@dataclass(frozen=True)
class Scheduler:
    """
    Abstract-style base for the time step schedules consumed by diffusion,
    denoising, and sampling routines.

    The class declares no fields of its own. Concrete schedulers subclass it,
    override ``get_ts``, and decide which keyword arguments that method accepts.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Produce the time step sequence for this schedule.

        The base implementation always raises. Every concrete scheduler must
        override it and document the keyword arguments it reads from
        ``**ts_hparams``.

        Args:
            **ts_hparams (``Dict[str, Any]``): Schedule-specific parameters.

        Returns:
            ``Array``: The time steps, expected in descending order.

        Raises:
            NotImplementedError: Always, when called on the base class or on a
                subclass that does not override this method.
            KeyError: Subclasses may raise this when a required parameter is
                absent from ``**ts_hparams``.
        """
        # Nothing to compute at this level: the schedule is defined by subclasses.
        raise NotImplementedError
@dataclass(frozen=True)
class UniformScheduler(Scheduler):
    """
    A scheduler that generates uniformly spaced time steps.

    Requires ``t_min``, ``t_max``, and ``num_steps`` to be passed
    to the ``get_ts`` method via keyword arguments. The number of points generated
    will be ``num_steps + 1``.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Generate uniformly spaced time steps.

        Args:
            **ts_hparams (``Dict[str, Any]``): Keyword arguments must contain

                - ``t_min`` (``float``): The minimum time value, typically close to 0.
                - ``t_max`` (``float``): The maximum time value, typically close to 1.
                - ``num_steps`` (``int``): The number of diffusion steps. The function
                  will generate ``num_steps + 1`` time points.

        Returns:
            ``Array[num_steps+1]``: A JAX array containing uniformly spaced time steps
            in descending order (from ``t_max`` to ``t_min``).

        Raises:
            KeyError: If ``t_min``, ``t_max``, or ``num_steps`` is not found in ``ts_hparams``.
            AssertionError: If ``0 <= t_min <= t_max <= 1`` does not hold or
                ``num_steps`` < 1. These checks are raised explicitly, so they
                still run when Python is started with ``-O``.
        """
        try:
            t_min = ts_hparams["t_min"]
            t_max = ts_hparams["t_max"]
            num_steps = ts_hparams["num_steps"]
        except KeyError as e:
            raise KeyError(
                f"Missing required parameter for UniformScheduler.get_ts: {e}"
            ) from e

        # Explicit raises instead of ``assert`` statements: asserts are stripped
        # under ``python -O``, which would let invalid schedules through silently.
        # AssertionError is kept as the type because it is the documented contract.
        if not 0 <= t_min <= t_max <= 1:
            raise AssertionError(
                "t_min and t_max must be in the range [0, 1] with t_min <= t_max, "
                f"got t_min={t_min}, t_max={t_max}"
            )
        if num_steps < 1:
            raise AssertionError(
                f"num_steps must be at least 1, got num_steps={num_steps}"
            )

        # linspace includes both endpoints, yielding num_steps + 1 ascending points;
        # reverse them so the schedule runs from t_max down to t_min.
        ts = jnp.linspace(t_min, t_max, num_steps + 1)[::-1]
        return ts