docs for muutils v0.8.10
View Source on GitHub

muutils.statcounter

StatCounter class for counting and calculating statistics on numbers

cleaner and more efficient than just using a Counter or array
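
As a quick orientation, here is a minimal usage sketch (the values and variable names are illustrative, not from the library's own docs): build a `StatCounter` from an iterable of numbers, then read statistics directly off the counts.

```python
from muutils.statcounter import StatCounter

# counts each observed value, exactly like collections.Counter
counts = StatCounter([1, 2, 2, 3, 3, 3])

print(counts.mean())     # weighted mean over the keys -> 2.333...
print(counts.median())   # same as counts.percentile(0.5) -> 2.5
print(counts.summary())  # small, human-readable dict of stats
```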


  1"""`StatCounter` class for counting and calculating statistics on numbers
  2
  3cleaner and more efficient than just using a `Counter` or array"""
  4
  5from __future__ import annotations
  6
  7import json
  8import math
  9from collections import Counter
 10from functools import cached_property
 11from itertools import chain
 12from typing import Callable, Optional, Sequence, Union
 13
 14
 15# _GeneralArray = Union[np.ndarray, "torch.Tensor"]
 16NumericSequence = Sequence[Union[float, int, "NumericSequence"]]
 17
 18# pylint: disable=abstract-method
 19
 20# misc
 21# ==================================================
 22
 23
 24def universal_flatten(
 25    arr: Union[NumericSequence, float, int], require_rectangular: bool = True
 26) -> NumericSequence:
 27    """flattens any iterable"""
 28
 29    # mypy complains that the sequence has no attribute "flatten"
 30    if hasattr(arr, "flatten") and callable(arr.flatten):  # type: ignore
 31        return arr.flatten()  # type: ignore
 32    elif isinstance(arr, Sequence):
 33        elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
 34        if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
 35            raise ValueError("arr contains mixed iterable and non-iterable elements")
 36        if any(elements_iterable):
 37            return list(chain.from_iterable(universal_flatten(x) for x in arr))  # type: ignore[misc]
 38        else:
 39            return arr
 40    else:
 41        return [arr]
 42
 43
 44# StatCounter
 45# ==================================================
 46
 47
 48class StatCounter(Counter):
 49    """`Counter`, but with some stat calculation methods which assume the keys are numerical
 50
 51    works best when the keys are `int`s
 52    """
 53
 54    def validate(self) -> bool:
 55        """validate the counter as being all floats or ints"""
 56        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())
 57
 58    def min(self):
 59        "minimum value"
 60        return min(x for x, v in self.items() if v > 0)
 61
 62    def max(self):
 63        "maximum value"
 64        return max(x for x, v in self.items() if v > 0)
 65
 66    def total(self):
 67        """Sum of the counts"""
 68        return sum(self.values())
 69
 70    @cached_property
 71    def keys_sorted(self) -> list:
 72        """return the keys, sorted"""
 73        return sorted(list(self.keys()))
 74
 75    def percentile(self, p: float):
 76        """return the value at the given percentile
 77
 78        this could be log time if we did binary search, but that would be a lot of added complexity
 79        """
 80
 81        if p < 0 or p > 1:
 82            raise ValueError(f"percentile must be between 0 and 1: {p}")
 83        # flip for speed
 84        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
 85        sort: int = 1
 86        if p > 0.51:
 87            sort = -1
 88            p = 1 - p
 89
 90        sorted_keys = sorted_keys[::sort]
 91        real_target: float = p * (self.total() - 1)
 92
 93        n_target_f: int = math.floor(real_target)
 94        n_target_c: int = math.ceil(real_target)
 95
 96        n_sofar: float = -1
 97
 98        # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }')
 99
100        for i, k in enumerate(sorted_keys):
101            n_sofar += self[k]
102
103            # print(f'{k = } {n_sofar = }')
104
105            if n_sofar > n_target_f:
106                return k
107
108            elif n_sofar == n_target_f:
109                if n_sofar == n_target_c:
110                    return k
111                else:
112                    # print(
113                    #     sorted_keys[i], (n_sofar + 1 - real_target),
114                    #     sorted_keys[i + 1], (real_target - n_sofar),
115                    # )
116                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
117                        i + 1
118                    ] * (real_target - n_sofar)
119            else:
120                continue
121
122        raise ValueError(f"percentile {p} not found???")
123
124    def median(self) -> float:
125        return self.percentile(0.5)
126
127    def mean(self) -> float:
128        """return the mean of the values"""
129        return float(sum(k * c for k, c in self.items()) / self.total())
130
131    def mode(self) -> float:
132        return self.most_common()[0][0]
133
134    def std(self) -> float:
135        """return the standard deviation of the values"""
136        mean: float = self.mean()
137        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())
138
139        return (deviations / self.total()) ** 0.5
140
141    def summary(
142        self,
143        typecast: Callable = lambda x: x,
144        *,
145        extra_percentiles: Optional[list[float]] = None,
146    ) -> dict[str, Union[float, int]]:
147        """return a summary of the stats, without the raw data. human readable and small"""
148        # common stats that always work
149        output: dict = dict(
150            total_items=self.total(),
151            n_keys=len(self.keys()),
152            mode=self.mode(),
153        )
154
155        if self.total() > 0:
156            if self.validate():
157                # if it's a numeric counter, we can do some stats
158                output = {
159                    **output,
160                    **dict(
161                        mean=float(self.mean()),
162                        std=float(self.std()),
163                        min=typecast(self.min()),
164                        q1=typecast(self.percentile(0.25)),
165                        median=typecast(self.median()),
166                        q3=typecast(self.percentile(0.75)),
167                        max=typecast(self.max()),
168                    ),
169                }
170
171                if extra_percentiles is not None:
172                    for p in extra_percentiles:
173                        output[f"percentile_{p}"] = typecast(self.percentile(p))
174            else:
175                # if it's not, we can only do the simpler things;
176                # mode, n_keys, and total_items are already set in the initial declaration of `output`
177                pass
178
179        return output
180
181    def serialize(
182        self,
183        typecast: Callable = lambda x: x,
184        *,
185        extra_percentiles: Optional[list[float]] = None,
186    ) -> dict:
187        """return a json-serializable version of the counter
188
189        includes both the output of `summary` and the raw data:
190
191        ```json
192        {
193            "StatCounter": { <keys, values from raw data> },
194            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
195        }
196        ```
197        """
198
199        return {
200            "StatCounter": {
201                typecast(k): v
202                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
203            },
204            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
205        }
206
207    def __str__(self) -> str:
208        "summary as json with 2 space indent, good for printing"
209        return json.dumps(self.summary(), indent=2)
210
211    def __repr__(self) -> str:
212        return json.dumps(self.serialize(), indent=2)
213
214    @classmethod
215    def load(cls, data: dict) -> "StatCounter":
216        "load from the output of `StatCounter.serialize`"
217        if "StatCounter" in data:
218            loadme = data["StatCounter"]
219        else:
220            loadme = data
221
222        return cls({float(k): v for k, v in loadme.items()})
223
224    @classmethod
225    def from_list_arrays(
226        cls,
227        arr,
228        map_func: Callable = float,
229    ) -> "StatCounter":
230        """calls `map_func` on each element of `universal_flatten(arr)`"""
231        return cls([map_func(x) for x in universal_flatten(arr)])

NumericSequence = typing.Sequence[typing.Union[float, int, ForwardRef('NumericSequence')]]
def universal_flatten( arr: Union[Sequence[Union[float, int, Sequence[Union[float, int, ForwardRef('NumericSequence')]]]], float, int], require_rectangular: bool = True) -> Sequence[Union[float, int, ForwardRef('NumericSequence')]]:
25def universal_flatten(
26    arr: Union[NumericSequence, float, int], require_rectangular: bool = True
27) -> NumericSequence:
28    """flattens any iterable"""
29
30    # mypy complains that the sequence has no attribute "flatten"
31    if hasattr(arr, "flatten") and callable(arr.flatten):  # type: ignore
32        return arr.flatten()  # type: ignore
33    elif isinstance(arr, Sequence):
34        elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
35        if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
36            raise ValueError("arr contains mixed iterable and non-iterable elements")
37        if any(elements_iterable):
38            return list(chain.from_iterable(universal_flatten(x) for x in arr))  # type: ignore[misc]
39        else:
40            return arr
41    else:
42        return [arr]

flattens any iterable
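
For illustration (made-up values), a rectangular nested list flattens cleanly, while a ragged mix of scalars and sequences raises unless `require_rectangular=False`:

```python
from muutils.statcounter import universal_flatten

print(universal_flatten([[1, 2], [3, 4]]))  # [1, 2, 3, 4]

# mixing scalars and sequences is rejected by default...
try:
    universal_flatten([1, [2, 3]])
except ValueError as e:
    print(e)  # arr contains mixed iterable and non-iterable elements

# ...but allowed when rectangularity is not required
print(universal_flatten([1, [2, 3]], require_rectangular=False))  # [1, 2, 3]
```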

class StatCounter(collections.Counter):
 49class StatCounter(Counter):
 50    """`Counter`, but with some stat calculation methods which assume the keys are numerical
 51
 52    works best when the keys are `int`s
 53    """
 54
 55    def validate(self) -> bool:
 56        """validate the counter as being all floats or ints"""
 57        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())
 58
 59    def min(self):
 60        "minimum value"
 61        return min(x for x, v in self.items() if v > 0)
 62
 63    def max(self):
 64        "maximum value"
 65        return max(x for x, v in self.items() if v > 0)
 66
 67    def total(self):
 68        """Sum of the counts"""
 69        return sum(self.values())
 70
 71    @cached_property
 72    def keys_sorted(self) -> list:
 73        """return the keys, sorted"""
 74        return sorted(list(self.keys()))
 75
 76    def percentile(self, p: float):
 77        """return the value at the given percentile
 78
 79        this could be log time if we did binary search, but that would be a lot of added complexity
 80        """
 81
 82        if p < 0 or p > 1:
 83            raise ValueError(f"percentile must be between 0 and 1: {p}")
 84        # flip for speed
 85        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
 86        sort: int = 1
 87        if p > 0.51:
 88            sort = -1
 89            p = 1 - p
 90
 91        sorted_keys = sorted_keys[::sort]
 92        real_target: float = p * (self.total() - 1)
 93
 94        n_target_f: int = math.floor(real_target)
 95        n_target_c: int = math.ceil(real_target)
 96
 97        n_sofar: float = -1
 98
 99        # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }')
100
101        for i, k in enumerate(sorted_keys):
102            n_sofar += self[k]
103
104            # print(f'{k = } {n_sofar = }')
105
106            if n_sofar > n_target_f:
107                return k
108
109            elif n_sofar == n_target_f:
110                if n_sofar == n_target_c:
111                    return k
112                else:
113                    # print(
114                    #     sorted_keys[i], (n_sofar + 1 - real_target),
115                    #     sorted_keys[i + 1], (real_target - n_sofar),
116                    # )
117                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
118                        i + 1
119                    ] * (real_target - n_sofar)
120            else:
121                continue
122
123        raise ValueError(f"percentile {p} not found???")
124
125    def median(self) -> float:
126        return self.percentile(0.5)
127
128    def mean(self) -> float:
129        """return the mean of the values"""
130        return float(sum(k * c for k, c in self.items()) / self.total())
131
132    def mode(self) -> float:
133        return self.most_common()[0][0]
134
135    def std(self) -> float:
136        """return the standard deviation of the values"""
137        mean: float = self.mean()
138        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())
139
140        return (deviations / self.total()) ** 0.5
141
142    def summary(
143        self,
144        typecast: Callable = lambda x: x,
145        *,
146        extra_percentiles: Optional[list[float]] = None,
147    ) -> dict[str, Union[float, int]]:
148        """return a summary of the stats, without the raw data. human readable and small"""
149        # common stats that always work
150        output: dict = dict(
151            total_items=self.total(),
152            n_keys=len(self.keys()),
153            mode=self.mode(),
154        )
155
156        if self.total() > 0:
157            if self.validate():
158                # if it's a numeric counter, we can do some stats
159                output = {
160                    **output,
161                    **dict(
162                        mean=float(self.mean()),
163                        std=float(self.std()),
164                        min=typecast(self.min()),
165                        q1=typecast(self.percentile(0.25)),
166                        median=typecast(self.median()),
167                        q3=typecast(self.percentile(0.75)),
168                        max=typecast(self.max()),
169                    ),
170                }
171
172                if extra_percentiles is not None:
173                    for p in extra_percentiles:
174                        output[f"percentile_{p}"] = typecast(self.percentile(p))
175            else:
176                # if it's not, we can only do the simpler things;
177                # mode, n_keys, and total_items are already set in the initial declaration of `output`
178                pass
179
180        return output
181
182    def serialize(
183        self,
184        typecast: Callable = lambda x: x,
185        *,
186        extra_percentiles: Optional[list[float]] = None,
187    ) -> dict:
188        """return a json-serializable version of the counter
189
190        includes both the output of `summary` and the raw data:
191
192        ```json
193        {
194            "StatCounter": { <keys, values from raw data> },
195            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
196        }
197        ```
198        """
199
200        return {
201            "StatCounter": {
202                typecast(k): v
203                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
204            },
205            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
206        }
207
208    def __str__(self) -> str:
209        "summary as json with 2 space indent, good for printing"
210        return json.dumps(self.summary(), indent=2)
211
212    def __repr__(self) -> str:
213        return json.dumps(self.serialize(), indent=2)
214
215    @classmethod
216    def load(cls, data: dict) -> "StatCounter":
217        "load from the output of `StatCounter.serialize`"
218        if "StatCounter" in data:
219            loadme = data["StatCounter"]
220        else:
221            loadme = data
222
223        return cls({float(k): v for k, v in loadme.items()})
224
225    @classmethod
226    def from_list_arrays(
227        cls,
228        arr,
229        map_func: Callable = float,
230    ) -> "StatCounter":
231        """calls `map_func` on each element of `universal_flatten(arr)`"""
232        return cls([map_func(x) for x in universal_flatten(arr)])

Counter, but with some stat calculation methods which assume the keys are numerical

works best when the keys are ints
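
Since `StatCounter` subclasses `collections.Counter`, the usual construction and update patterns carry over; keys are the observed values and counts are their multiplicities. A small sketch with illustrative values:

```python
from muutils.statcounter import StatCounter

sc = StatCounter({2: 3, 5: 1})   # the value 2 seen three times, 5 seen once
sc.update([2, 7])                # plain Counter.update still works
print(sc.total())                # 6 -- sum of the counts
print(sc.min(), sc.max())        # 2 7
```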

def validate(self) -> bool:
55    def validate(self) -> bool:
56        """validate the counter as being all floats or ints"""
57        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())

validate the counter as being all floats or ints

def min(self):
59    def min(self):
60        "minimum value"
61        return min(x for x, v in self.items() if v > 0)

minimum value

def max(self):
63    def max(self):
64        "maximum value"
65        return max(x for x, v in self.items() if v > 0)

maximum value

def total(self):
67    def total(self):
68        """Sum of the counts"""
69        return sum(self.values())

Sum of the counts

keys_sorted: list
71    @cached_property
72    def keys_sorted(self) -> list:
73        """return the keys, sorted"""
74        return sorted(list(self.keys()))

return the keys, sorted

def percentile(self, p: float):
 76    def percentile(self, p: float):
 77        """return the value at the given percentile
 78
 79        this could be log time if we did binary search, but that would be a lot of added complexity
 80        """
 81
 82        if p < 0 or p > 1:
 83            raise ValueError(f"percentile must be between 0 and 1: {p}")
 84        # flip for speed
 85        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
 86        sort: int = 1
 87        if p > 0.51:
 88            sort = -1
 89            p = 1 - p
 90
 91        sorted_keys = sorted_keys[::sort]
 92        real_target: float = p * (self.total() - 1)
 93
 94        n_target_f: int = math.floor(real_target)
 95        n_target_c: int = math.ceil(real_target)
 96
 97        n_sofar: float = -1
 98
 99        # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }')
100
101        for i, k in enumerate(sorted_keys):
102            n_sofar += self[k]
103
104            # print(f'{k = } {n_sofar = }')
105
106            if n_sofar > n_target_f:
107                return k
108
109            elif n_sofar == n_target_f:
110                if n_sofar == n_target_c:
111                    return k
112                else:
113                    # print(
114                    #     sorted_keys[i], (n_sofar + 1 - real_target),
115                    #     sorted_keys[i + 1], (real_target - n_sofar),
116                    # )
117                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
118                        i + 1
119                    ] * (real_target - n_sofar)
120            else:
121                continue
122
123        raise ValueError(f"percentile {p} not found???")

return the value at the given percentile

this could be log time if we did binary search, but that would be a lot of added complexity
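
To make the interpolation concrete (an illustrative sketch): with four equally-weighted keys, `percentile(0.5)` falls between the two middle keys and is linearly interpolated, while 0 and 1 return the extremes.

```python
from muutils.statcounter import StatCounter

sc = StatCounter([1, 2, 3, 4])
print(sc.percentile(0.5))  # 2.5 -- halfway between 2 and 3
print(sc.percentile(0.0))  # 1.0 -- smallest key
print(sc.percentile(1.0))  # 4.0 -- largest key
```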

def median(self) -> float:
125    def median(self) -> float:
126        return self.percentile(0.5)
def mean(self) -> float:
128    def mean(self) -> float:
129        """return the mean of the values"""
130        return float(sum(k * c for k, c in self.items()) / self.total())

return the mean of the values

def mode(self) -> float:
132    def mode(self) -> float:
133        return self.most_common()[0][0]
def std(self) -> float:
135    def std(self) -> float:
136        """return the standard deviation of the values"""
137        mean: float = self.mean()
138        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())
139
140        return (deviations / self.total()) ** 0.5

return the standard deviation of the values
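
Note that this is the population standard deviation (the sum of squared deviations is divided by the total count, with no Bessel correction). As a sanity-check sketch with made-up data, it agrees with the `statistics` module applied to the expanded values:

```python
import statistics
from muutils.statcounter import StatCounter

data = [1, 2, 2, 3]
sc = StatCounter(data)

assert sc.mean() == statistics.fmean(data)               # 2.0
assert abs(sc.std() - statistics.pstdev(data)) < 1e-12   # ~0.7071
```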

def summary( self, typecast: Callable = <function StatCounter.<lambda>>, *, extra_percentiles: Optional[list[float]] = None) -> dict[str, typing.Union[float, int]]:
142    def summary(
143        self,
144        typecast: Callable = lambda x: x,
145        *,
146        extra_percentiles: Optional[list[float]] = None,
147    ) -> dict[str, Union[float, int]]:
148        """return a summary of the stats, without the raw data. human readable and small"""
149        # common stats that always work
150        output: dict = dict(
151            total_items=self.total(),
152            n_keys=len(self.keys()),
153            mode=self.mode(),
154        )
155
156        if self.total() > 0:
157            if self.validate():
158                # if it's a numeric counter, we can do some stats
159                output = {
160                    **output,
161                    **dict(
162                        mean=float(self.mean()),
163                        std=float(self.std()),
164                        min=typecast(self.min()),
165                        q1=typecast(self.percentile(0.25)),
166                        median=typecast(self.median()),
167                        q3=typecast(self.percentile(0.75)),
168                        max=typecast(self.max()),
169                    ),
170                }
171
172                if extra_percentiles is not None:
173                    for p in extra_percentiles:
174                        output[f"percentile_{p}"] = typecast(self.percentile(p))
175            else:
176                # if it's not, we can only do the simpler things;
177                # mode, n_keys, and total_items are already set in the initial declaration of `output`
178                pass
179
180        return output

return a summary of the stats, without the raw data. human readable and small
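
A sketch of the returned dict for a small numeric counter (illustrative values; float formatting abridged):

```python
from muutils.statcounter import StatCounter

sc = StatCounter([1, 2, 2, 3])
print(sc.summary(typecast=int, extra_percentiles=[0.9]))
# {'total_items': 4, 'n_keys': 3, 'mode': 2,
#  'mean': 2.0, 'std': 0.707..., 'min': 1, 'q1': 1,
#  'median': 2, 'q3': 2, 'max': 3, 'percentile_0.9': 2}
```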

def serialize( self, typecast: Callable = <function StatCounter.<lambda>>, *, extra_percentiles: Optional[list[float]] = None) -> dict:
182    def serialize(
183        self,
184        typecast: Callable = lambda x: x,
185        *,
186        extra_percentiles: Optional[list[float]] = None,
187    ) -> dict:
188        """return a json-serializable version of the counter
189
190        includes both the output of `summary` and the raw data:
191
192        ```json
193        {
194            "StatCounter": { <keys, values from raw data> },
195            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
196        }
197        ```
198        """
199
200        return {
201            "StatCounter": {
202                typecast(k): v
203                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
204            },
205            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
206        }

return a json-serializable version of the counter

includes both the output of summary and the raw data:

```json
{
    "StatCounter": { <keys, values from raw data> },
    "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
}
```
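
For example (a sketch with made-up values), a small integer-keyed counter serializes to a plain dict that `json.dumps` accepts directly:

```python
import json
from muutils.statcounter import StatCounter

sc = StatCounter([1, 2, 2, 3])
print(json.dumps(sc.serialize(typecast=int), indent=2))
# abridged output:
# {
#   "StatCounter": {"1": 1, "2": 2, "3": 1},
#   "summary": {"total_items": 4, "n_keys": 3, ...}
# }
```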

@classmethod
def load(cls, data: dict) -> StatCounter:
215    @classmethod
216    def load(cls, data: dict) -> "StatCounter":
217        "load from the output of `StatCounter.serialize`"
218        if "StatCounter" in data:
219            loadme = data["StatCounter"]
220        else:
221            loadme = data
222
223        return cls({float(k): v for k, v in loadme.items()})

load from the output of StatCounter.serialize
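
A round-trip sketch (illustrative values): serialize, pass through `json`, then rebuild; note that `load` casts the keys back to `float`.

```python
import json
from muutils.statcounter import StatCounter

original = StatCounter([1, 2, 2, 3])
dumped: str = json.dumps(original.serialize())
restored = StatCounter.load(json.loads(dumped))

assert restored == StatCounter({1.0: 1, 2.0: 2, 3.0: 1})
```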

@classmethod
def from_list_arrays( cls, arr, map_func: Callable = <class 'float'>) -> StatCounter:
225    @classmethod
226    def from_list_arrays(
227        cls,
228        arr,
229        map_func: Callable = float,
230    ) -> "StatCounter":
231        """calls `map_func` on each element of `universal_flatten(arr)`"""
232        return cls([map_func(x) for x in universal_flatten(arr)])

calls map_func on each element of universal_flatten(arr)
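
An illustrative sketch: nested (rectangular) lists, or anything exposing a `.flatten()` method such as a numpy array, get flattened and then counted, with `map_func` (default `float`) applied to each element:

```python
from muutils.statcounter import StatCounter

sc = StatCounter.from_list_arrays([[1, 2], [2, 3]])
print(dict(sc))   # {1.0: 1, 2.0: 2, 3.0: 1}
print(sc.mean())  # 2.0
```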

Inherited Members
collections.Counter
Counter
most_common
elements
fromkeys
update
subtract
copy
builtins.dict
get
setdefault
pop
popitem
keys
items
values
clear