docs for muutils v0.8.10
View Source on GitHub

muutils.math.bins


 1from __future__ import annotations
 2
 3from dataclasses import dataclass
 4from functools import cached_property
 5from typing import Literal
 6
 7import numpy as np
 8from jaxtyping import Float
 9
10
@dataclass(frozen=True)
class Bins:
    """Immutable description of a set of bins over the range ``[start, stop]``.

    Bin edges are computed lazily and cached by ``edges``; bin midpoints by
    ``centers``.

    Attributes:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower bound of the binned range; must be ``>= 0`` for ``"log"``.
        stop: upper bound of the binned range.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for log-spaced edges.
        _log_min: for ``"log"`` scale with ``start == 0``, the first bin is
            ``[0, _log_min]`` and the remaining bins are log-spaced from
            ``_log_min`` to ``stop``.
        _zero_in_small_start_log: for ``"log"`` scale with
            ``0 < start < _log_min``, prepend an extra edge at 0 so the first
            bin is ``[0, start]``.
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges: an array of ``n_bins + 1`` increasing values.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, for
                ``"log"`` scale, if ``start`` is negative or ``stop`` does not
                exceed the lowest log-spaced edge (``_log_min`` when
                ``start == 0``, otherwise ``start``).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )

        # a literal 0 edge is prepended when start is exactly 0 (log-spacing then
        # begins at _log_min), or when start is tiny and the flag requests it
        prepend_zero: bool = self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        )
        log_lo: float = self._log_min if self.start == 0 else self.start

        # log10 needs a positive argument, and edges must strictly increase
        if self.stop <= log_lo:
            raise ValueError(
                f"stop must be greater than {log_lo} for log scale, got {self.stop}"
            )

        if not prepend_zero:
            return np.logspace(np.log10(log_lo), np.log10(self.stop), self.n_bins + 1)

        if self.n_bins == 1:
            # logspace(..., num=1) yields only `log_lo`, which would make the
            # single bin end at log_lo instead of stop; span the full range
            return np.array([0.0, float(self.stop)])

        return np.concatenate(
            [
                np.array([0]),
                np.logspace(np.log10(log_lo), np.log10(self.stop), self.n_bins),
            ]
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Bin midpoints: the arithmetic mean of each pair of consecutive edges.

        For ``"log"`` scale this is still the arithmetic (not geometric) midpoint.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with ``n_bins`` replaced.

        All other fields, including ``_log_min`` and
        ``_zero_in_small_start_log``, are carried over unchanged.
        """
        # dataclasses.replace copies every init field automatically, so this stays
        # correct if fields are added, and it preserves the subclass type
        from dataclasses import replace

        return replace(self, n_bins=n_bins)

@dataclass(frozen=True)
class Bins:
@dataclass(frozen=True)
class Bins:
    """Immutable description of a set of bins over the range ``[start, stop]``.

    Bin edges are computed lazily and cached by ``edges``; bin midpoints by
    ``centers``.

    Attributes:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower bound of the binned range; must be ``>= 0`` for ``"log"``.
        stop: upper bound of the binned range.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for log-spaced edges.
        _log_min: for ``"log"`` scale with ``start == 0``, the first bin is
            ``[0, _log_min]`` and the remaining bins are log-spaced from
            ``_log_min`` to ``stop``.
        _zero_in_small_start_log: for ``"log"`` scale with
            ``0 < start < _log_min``, prepend an extra edge at 0 so the first
            bin is ``[0, start]``.
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges: an array of ``n_bins + 1`` increasing values.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, for
                ``"log"`` scale, if ``start`` is negative or ``stop`` does not
                exceed the lowest log-spaced edge (``_log_min`` when
                ``start == 0``, otherwise ``start``).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )

        # a literal 0 edge is prepended when start is exactly 0 (log-spacing then
        # begins at _log_min), or when start is tiny and the flag requests it
        prepend_zero: bool = self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        )
        log_lo: float = self._log_min if self.start == 0 else self.start

        # log10 needs a positive argument, and edges must strictly increase
        if self.stop <= log_lo:
            raise ValueError(
                f"stop must be greater than {log_lo} for log scale, got {self.stop}"
            )

        if not prepend_zero:
            return np.logspace(np.log10(log_lo), np.log10(self.stop), self.n_bins + 1)

        if self.n_bins == 1:
            # logspace(..., num=1) yields only `log_lo`, which would make the
            # single bin end at log_lo instead of stop; span the full range
            return np.array([0.0, float(self.stop)])

        return np.concatenate(
            [
                np.array([0]),
                np.logspace(np.log10(log_lo), np.log10(self.stop), self.n_bins),
            ]
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Bin midpoints: the arithmetic mean of each pair of consecutive edges.

        For ``"log"`` scale this is still the arithmetic (not geometric) midpoint.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with ``n_bins`` replaced.

        All other fields, including ``_log_min`` and
        ``_zero_in_small_start_log``, are carried over unchanged.
        """
        # dataclasses.replace copies every init field automatically, so this stays
        # correct if fields are added, and it preserves the subclass type
        from dataclasses import replace

        return replace(self, n_bins=n_bins)
Bins( n_bins: int = 32, start: float = 0, stop: float = 1.0, scale: Literal['lin', 'log'] = 'log', _log_min: float = 0.001, _zero_in_small_start_log: bool = True)
n_bins: int = 32
start: float = 0
stop: float = 1.0
scale: Literal['lin', 'log'] = 'log'
edges: jaxtyping.Float[ndarray, 'n_bins+1']
22    @cached_property
23    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
24        if self.scale == "lin":
25            return np.linspace(self.start, self.stop, self.n_bins + 1)
26        elif self.scale == "log":
27            if self.start < 0:
28                raise ValueError(
29                    f"start must be positive for log scale, got {self.start}"
30                )
31            if self.start == 0:
32                return np.concatenate(
33                    [
34                        np.array([0]),
35                        np.logspace(
36                            np.log10(self._log_min), np.log10(self.stop), self.n_bins
37                        ),
38                    ]
39                )
40            elif self.start < self._log_min and self._zero_in_small_start_log:
41                return np.concatenate(
42                    [
43                        np.array([0]),
44                        np.logspace(
45                            np.log10(self.start), np.log10(self.stop), self.n_bins
46                        ),
47                    ]
48                )
49            else:
50                return np.logspace(
51                    np.log10(self.start), np.log10(self.stop), self.n_bins + 1
52                )
53        else:
54            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")
centers: jaxtyping.Float[ndarray, 'n_bins']
56    @cached_property
57    def centers(self) -> Float[np.ndarray, "n_bins"]:
58        return (self.edges[:-1] + self.edges[1:]) / 2
def changed_n_bins_copy(self, n_bins: int) -> Bins:
60    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
61        return Bins(
62            n_bins=n_bins,
63            start=self.start,
64            stop=self.stop,
65            scale=self.scale,
66            _log_min=self._log_min,
67            _zero_in_small_start_log=self._zero_in_small_start_log,
68        )