muutils.math.bins
"""Histogram bin specifications on a linear or logarithmic scale."""

from __future__ import annotations

from dataclasses import dataclass, replace
from functools import cached_property
from typing import Literal

import numpy as np
from jaxtyping import Float


@dataclass(frozen=True)
class Bins:
    """Immutable description of histogram bins on a linear or logarithmic scale.

    ``edges`` and ``centers`` are computed lazily on first access and cached
    on the instance.
    """

    # number of bins; ``edges`` holds ``n_bins + 1`` values
    n_bins: int = 32
    # lower bound of the range (must be >= 0 when ``scale == "log"``)
    start: float = 0
    # upper bound of the range (must be > 0 when ``scale == "log"``)
    stop: float = 1.0
    # "lin" -> evenly spaced edges, "log" -> logarithmically spaced edges
    scale: Literal["lin", "log"] = "log"

    # lowest log-spaced edge used when ``start == 0`` (log10(0) is undefined)
    _log_min: float = 1e-3
    # on a log scale with 0 < start < _log_min, prepend a 0 edge so the
    # binned range reaches down to zero
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges as a 1-D array of ``n_bins + 1`` values.

        - ``"lin"``: evenly spaced from ``start`` to ``stop``.
        - ``"log"`` with ``start == 0``: a 0 edge followed by ``n_bins``
          log-spaced edges from ``_log_min`` to ``stop``.
        - ``"log"`` with ``0 < start < _log_min`` and
          ``_zero_in_small_start_log``: a 0 edge followed by ``n_bins``
          log-spaced edges from ``start`` to ``stop``.
        - ``"log"`` otherwise: ``n_bins + 1`` log-spaced edges from ``start``
          to ``stop``.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, on a
                log scale, if ``start < 0``, ``stop <= 0``, or ``start == 0``
                with ``stop <= _log_min`` (which would give non-increasing
                edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )
        # log10 of a non-positive stop is -inf/nan and would poison every edge
        if self.stop <= 0:
            raise ValueError(f"stop must be positive for log scale, got {self.stop}")

        if self.start == 0:
            # the log-spaced part silently starts at _log_min here, so make
            # sure that lower edge actually lies below stop
            if self.stop <= self._log_min:
                raise ValueError(
                    f"stop must be greater than _log_min={self._log_min} when "
                    f"start is 0 on a log scale, got stop={self.stop}"
                )
            log_low = self._log_min
        elif self.start < self._log_min and self._zero_in_small_start_log:
            log_low = self.start
        else:
            return np.logspace(
                np.log10(self.start), np.log10(self.stop), self.n_bins + 1
            )

        # Zero-prefixed layout: one [0, log_low] bin plus n_bins - 1 log bins.
        # NOTE(review): with n_bins == 1, np.logspace returns only its start
        # point, so the single bin is [0, log_low] and does not reach stop.
        return np.concatenate(
            [
                np.array([0]),
                np.logspace(np.log10(log_low), np.log10(self.stop), self.n_bins),
            ]
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Arithmetic midpoints of consecutive edges (``n_bins`` values).

        On a log scale these are still arithmetic, not geometric, means.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with only ``n_bins`` replaced.

        ``dataclasses.replace`` carries over every other field automatically,
        including any added later. Cached ``edges``/``centers`` are not
        copied; the new instance recomputes them.
        """
        return replace(self, n_bins=n_bins)
@dataclass(frozen=True)
class Bins:
@dataclass(frozen=True)
class Bins:
    """Immutable description of histogram bins on a linear or logarithmic scale.

    ``edges`` and ``centers`` are computed lazily on first access and cached
    on the instance.
    """

    # number of bins; ``edges`` holds ``n_bins + 1`` values
    n_bins: int = 32
    # lower bound of the range (must be >= 0 when ``scale == "log"``)
    start: float = 0
    # upper bound of the range (must be > 0 when ``scale == "log"``)
    stop: float = 1.0
    # "lin" -> evenly spaced edges, "log" -> logarithmically spaced edges
    scale: Literal["lin", "log"] = "log"

    # lowest log-spaced edge used when ``start == 0`` (log10(0) is undefined)
    _log_min: float = 1e-3
    # on a log scale with 0 < start < _log_min, prepend a 0 edge so the
    # binned range reaches down to zero
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges as a 1-D array of ``n_bins + 1`` values.

        - ``"lin"``: evenly spaced from ``start`` to ``stop``.
        - ``"log"`` with ``start == 0``: a 0 edge followed by ``n_bins``
          log-spaced edges from ``_log_min`` to ``stop``.
        - ``"log"`` with ``0 < start < _log_min`` and
          ``_zero_in_small_start_log``: a 0 edge followed by ``n_bins``
          log-spaced edges from ``start`` to ``stop``.
        - ``"log"`` otherwise: ``n_bins + 1`` log-spaced edges from ``start``
          to ``stop``.

        Raises:
            ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, on a
                log scale, if ``start < 0``, ``stop <= 0``, or ``start == 0``
                with ``stop <= _log_min`` (which would give non-increasing
                edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )
        # log10 of a non-positive stop is -inf/nan and would poison every edge
        if self.stop <= 0:
            raise ValueError(f"stop must be positive for log scale, got {self.stop}")

        if self.start == 0:
            # the log-spaced part silently starts at _log_min here, so make
            # sure that lower edge actually lies below stop
            if self.stop <= self._log_min:
                raise ValueError(
                    f"stop must be greater than _log_min={self._log_min} when "
                    f"start is 0 on a log scale, got stop={self.stop}"
                )
            log_low = self._log_min
        elif self.start < self._log_min and self._zero_in_small_start_log:
            log_low = self.start
        else:
            return np.logspace(
                np.log10(self.start), np.log10(self.stop), self.n_bins + 1
            )

        # Zero-prefixed layout: one [0, log_low] bin plus n_bins - 1 log bins.
        # NOTE(review): with n_bins == 1, np.logspace returns only its start
        # point, so the single bin is [0, log_low] and does not reach stop.
        return np.concatenate(
            [
                np.array([0]),
                np.logspace(np.log10(log_low), np.log10(self.stop), self.n_bins),
            ]
        )

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Arithmetic midpoints of consecutive edges (``n_bins`` values).

        On a log scale these are still arithmetic, not geometric, means.
        """
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a copy of these bins with only ``n_bins`` replaced.

        ``dataclasses.replace`` carries over every other field automatically,
        including any added later. Cached ``edges``/``centers`` are not
        copied; the new instance recomputes them.
        """
        from dataclasses import replace

        return replace(self, n_bins=n_bins)
Bins( n_bins: int = 32, start: float = 0, stop: float = 1.0, scale: Literal['lin', 'log'] = 'log', _log_min: float = 0.001, _zero_in_small_start_log: bool = True)
edges: jaxtyping.Float[ndarray, 'n_bins+1']
@cached_property
def edges(self) -> Float[np.ndarray, "n_bins+1"]:
    """Bin edges as a 1-D array of ``n_bins + 1`` values.

    - ``"lin"``: evenly spaced from ``start`` to ``stop``.
    - ``"log"`` with ``start == 0``: a 0 edge followed by ``n_bins``
      log-spaced edges from ``_log_min`` to ``stop``.
    - ``"log"`` with ``0 < start < _log_min`` and
      ``_zero_in_small_start_log``: a 0 edge followed by ``n_bins``
      log-spaced edges from ``start`` to ``stop``.
    - ``"log"`` otherwise: ``n_bins + 1`` log-spaced edges from ``start``
      to ``stop``.

    Raises:
        ValueError: if ``scale`` is not ``"lin"`` or ``"log"``; or, on a
            log scale, if ``start < 0``, ``stop <= 0``, or ``start == 0``
            with ``stop <= _log_min`` (which would give non-increasing
            edges).
    """
    if self.scale == "lin":
        return np.linspace(self.start, self.stop, self.n_bins + 1)
    if self.scale != "log":
        raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

    if self.start < 0:
        raise ValueError(
            f"start must be non-negative for log scale, got {self.start}"
        )
    # log10 of a non-positive stop is -inf/nan and would poison every edge
    if self.stop <= 0:
        raise ValueError(f"stop must be positive for log scale, got {self.stop}")

    if self.start == 0:
        # the log-spaced part silently starts at _log_min here, so make
        # sure that lower edge actually lies below stop
        if self.stop <= self._log_min:
            raise ValueError(
                f"stop must be greater than _log_min={self._log_min} when "
                f"start is 0 on a log scale, got stop={self.stop}"
            )
        log_low = self._log_min
    elif self.start < self._log_min and self._zero_in_small_start_log:
        log_low = self.start
    else:
        return np.logspace(
            np.log10(self.start), np.log10(self.stop), self.n_bins + 1
        )

    # Zero-prefixed layout: one [0, log_low] bin plus n_bins - 1 log bins.
    # NOTE(review): with n_bins == 1, np.logspace returns only its start
    # point, so the single bin is [0, log_low] and does not reach stop.
    return np.concatenate(
        [
            np.array([0]),
            np.logspace(np.log10(log_low), np.log10(self.stop), self.n_bins),
        ]
    )