muutils.math.matrix_powers
from __future__ import annotations

from typing import List, Sequence, TYPE_CHECKING

import numpy as np
from jaxtyping import Float, Int

if TYPE_CHECKING:
    pass


def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation: the repeated squares `A^1, A^2, A^4, ...` are
    computed once (O(log max(powers)) matrix multiplications) and shared
    between all requested powers, each of which is then assembled from the
    squares selected by its binary digits.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        List of powers to compute (non-negative integers)

    # Returns:
     - `Float[np.ndarray, "n_powers n n"]`
        Array of matrix powers stacked along the first dimension, with the
        same dtype as `A`. Duplicate powers are dropped and the slices are in
        ascending order of power: `output[i]` is `A` raised to
        `sorted(set(powers))[i]`

    # Raises:
     - `AssertionError` : if `A` is not square
     - `ValueError` : if no powers are requested or if any power is negative
    """
    dim_n: int = A.shape[0]
    assert A.shape[0] == A.shape[1], f"Matrix must be square, but got {A.shape = }"
    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
        sorted(set(powers)), dtype=int
    )
    n_powers_unique: int = len(powers_np)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    # powers_np is sorted ascending, so checking the first entry is enough.
    # Without this check a negative power would silently yield the identity.
    if powers_np[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # Every slice is assigned below, so no fill value is needed. Filling with
    # NaN raised for integer dtypes.
    output: Float[np.ndarray, "n_powers_unique n n"] = np.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
    )

    max_power: int = int(powers_np[-1])

    # Basis for the binary decomposition, keyed by exponent:
    # 0 -> identity, 1 -> A, 2 -> A^2, 4 -> A^4, ...
    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {
        0: np.eye(dim_n, dtype=A.dtype),
        1: A.copy(),
    }

    # Square repeatedly, stopping at the largest power of two <= max_power
    p: int = 1
    while p * 2 <= max_power:
        A_power_p = powers_of_two[p]
        powers_of_two[p * 2] = A_power_p @ A_power_p
        p *= 2

    # Assemble each requested power from the squares picked by its binary digits
    for p_idx, power in enumerate(powers_np.tolist()):
        temp_result: Float[np.ndarray, "n n"] = powers_of_two[0].copy()
        remaining: int = power
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                temp_result = temp_result @ powers_of_two[bit_value]
            remaining //= 2
            bit_value *= 2

        output[p_idx] = temp_result

    return output


# NOTE: the output buffer is no longer NaN-filled, which is what raised for
#       integer tensors; integer matmul support may still depend on the torch
#       backend/device -- TODO confirm on non-CPU devices
# TYPING: jaxtyping hints not working here, separate file for torch implementation?
def matrix_powers_torch(
    A,  # : Float["torch.Tensor", "n n"],
    powers: Sequence[int],
):  # Float["torch.Tensor", "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation: the repeated squares `A^1, A^2, A^4, ...` are
    computed once (O(log max(powers)) matrix multiplications) and shared
    between all requested powers, each of which is then assembled from the
    squares selected by its binary digits.

    # Parameters:
     - `A : Float[torch.Tensor, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        List of powers to compute (non-negative integers)

    # Returns:
     - `Float[torch.Tensor, "n_powers n n"]`
        Tensor containing the requested matrix powers stacked along the first
        dimension, with the dtype and device of `A`. Duplicate powers are
        dropped and the slices are in ascending order of power

    # Raises:
     - `ValueError` : If no powers are requested, if any power is negative,
        or if A is not a square matrix
    """

    import torch

    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")

    dim_n: int = A.shape[0]
    # Get unique powers and sort them
    unique_powers: List[int] = sorted(set(powers))
    n_powers_unique: int = len(unique_powers)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    # Sorted ascending, so checking the first entry is enough. Without this
    # check a negative power would silently yield the identity.
    if unique_powers[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # Every slice is assigned below, so no fill value is needed. Filling with
    # NaN raised for integer dtypes.
    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
        device=A.device,
    )

    # The list is sorted, so the maximum is its last entry (no tensor needed)
    max_power: int = int(unique_powers[-1])

    # Basis for the binary decomposition, keyed by exponent:
    # 0 -> identity, 1 -> A, 2 -> A^2, 4 -> A^4, ...
    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {
        0: torch.eye(dim_n, dtype=A.dtype, device=A.device),
        1: A.clone(),
    }

    # Square repeatedly, stopping at the largest power of two <= max_power
    p: int = 1
    while p * 2 <= max_power:
        A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p]
        powers_of_two[p * 2] = A_power_p @ A_power_p
        p *= 2

    # Assemble each requested power from the squares picked by its binary digits
    for p_idx, power in enumerate(unique_powers):
        temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone()
        remaining: int = power
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                temp_result = temp_result @ powers_of_two[bit_value]
            remaining //= 2
            bit_value *= 2

        output[p_idx] = temp_result

    return output
def
matrix_powers( A: jaxtyping.Float[ndarray, 'n n'], powers: Sequence[int]) -> jaxtyping.Float[ndarray, 'n_powers n n']:
def matrix_powers(
    A: Float[np.ndarray, "n n"],
    powers: Sequence[int],
) -> Float[np.ndarray, "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation: the repeated squares `A^1, A^2, A^4, ...` are
    computed once (O(log max(powers)) matrix multiplications) and shared
    between all requested powers, each of which is then assembled from the
    squares selected by its binary digits.

    # Parameters:
     - `A : Float[np.ndarray, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        List of powers to compute (non-negative integers)

    # Returns:
     - `Float[np.ndarray, "n_powers n n"]`
        Array of matrix powers stacked along the first dimension, with the
        same dtype as `A`. Duplicate powers are dropped and the slices are in
        ascending order of power: `output[i]` is `A` raised to
        `sorted(set(powers))[i]`

    # Raises:
     - `AssertionError` : if `A` is not square
     - `ValueError` : if no powers are requested or if any power is negative
    """
    dim_n: int = A.shape[0]
    assert A.shape[0] == A.shape[1], f"Matrix must be square, but got {A.shape = }"
    powers_np: Int[np.ndarray, "n_powers_unique"] = np.array(
        sorted(set(powers)), dtype=int
    )
    n_powers_unique: int = len(powers_np)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    # powers_np is sorted ascending, so checking the first entry is enough.
    # Without this check a negative power would silently yield the identity.
    if powers_np[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # Every slice is assigned below, so no fill value is needed. Filling with
    # NaN raised for integer dtypes.
    output: Float[np.ndarray, "n_powers_unique n n"] = np.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
    )

    max_power: int = int(powers_np[-1])

    # Basis for the binary decomposition, keyed by exponent:
    # 0 -> identity, 1 -> A, 2 -> A^2, 4 -> A^4, ...
    powers_of_two: dict[int, Float[np.ndarray, "n n"]] = {
        0: np.eye(dim_n, dtype=A.dtype),
        1: A.copy(),
    }

    # Square repeatedly, stopping at the largest power of two <= max_power
    p: int = 1
    while p * 2 <= max_power:
        A_power_p = powers_of_two[p]
        powers_of_two[p * 2] = A_power_p @ A_power_p
        p *= 2

    # Assemble each requested power from the squares picked by its binary digits
    for p_idx, power in enumerate(powers_np.tolist()):
        temp_result: Float[np.ndarray, "n n"] = powers_of_two[0].copy()
        remaining: int = power
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                temp_result = temp_result @ powers_of_two[bit_value]
            remaining //= 2
            bit_value *= 2

        output[p_idx] = temp_result

    return output
Compute multiple powers of a matrix efficiently.
Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers.
Parameters:
A : Float[np.ndarray, "n n"]
Square matrix to exponentiate
powers : Sequence[int]
List of powers to compute (non-negative integers)
Returns:
Float[np.ndarray, "n_powers n n"]
Array of the requested matrix powers stacked along the first dimension, one slice per unique power in ascending order
def
matrix_powers_torch(A, powers: Sequence[int]):
def matrix_powers_torch(
    A,  # : Float["torch.Tensor", "n n"],
    powers: Sequence[int],
):  # Float["torch.Tensor", "n_powers n n"]:
    """Compute multiple powers of a matrix efficiently.

    Uses binary exponentiation: the repeated squares `A^1, A^2, A^4, ...` are
    computed once (O(log max(powers)) matrix multiplications) and shared
    between all requested powers, each of which is then assembled from the
    squares selected by its binary digits.

    # Parameters:
     - `A : Float[torch.Tensor, "n n"]`
        Square matrix to exponentiate
     - `powers : Sequence[int]`
        List of powers to compute (non-negative integers)

    # Returns:
     - `Float[torch.Tensor, "n_powers n n"]`
        Tensor containing the requested matrix powers stacked along the first
        dimension, with the dtype and device of `A`. Duplicate powers are
        dropped and the slices are in ascending order of power

    # Raises:
     - `ValueError` : If no powers are requested, if any power is negative,
        or if A is not a square matrix
    """

    import torch

    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"Matrix must be square, but got {A.shape = }")

    dim_n: int = A.shape[0]
    # Get unique powers and sort them
    unique_powers: List[int] = sorted(set(powers))
    n_powers_unique: int = len(unique_powers)

    if n_powers_unique < 1:
        raise ValueError(f"No powers requested: {powers = }")
    # Sorted ascending, so checking the first entry is enough. Without this
    # check a negative power would silently yield the identity.
    if unique_powers[0] < 0:
        raise ValueError(f"Powers must be non-negative, but got {powers = }")

    # Every slice is assigned below, so no fill value is needed. Filling with
    # NaN raised for integer dtypes.
    output: Float[torch.Tensor, "n_powers_unique n n"] = torch.empty(
        (n_powers_unique, dim_n, dim_n),
        dtype=A.dtype,
        device=A.device,
    )

    # The list is sorted, so the maximum is its last entry (no tensor needed)
    max_power: int = int(unique_powers[-1])

    # Basis for the binary decomposition, keyed by exponent:
    # 0 -> identity, 1 -> A, 2 -> A^2, 4 -> A^4, ...
    powers_of_two: dict[int, Float[torch.Tensor, "n n"]] = {
        0: torch.eye(dim_n, dtype=A.dtype, device=A.device),
        1: A.clone(),
    }

    # Square repeatedly, stopping at the largest power of two <= max_power
    p: int = 1
    while p * 2 <= max_power:
        A_power_p: Float[torch.Tensor, "n n"] = powers_of_two[p]
        powers_of_two[p * 2] = A_power_p @ A_power_p
        p *= 2

    # Assemble each requested power from the squares picked by its binary digits
    for p_idx, power in enumerate(unique_powers):
        temp_result: Float[torch.Tensor, "n n"] = powers_of_two[0].clone()
        remaining: int = power
        bit_value: int = 1

        while remaining > 0:
            if remaining % 2 == 1:
                temp_result = temp_result @ powers_of_two[bit_value]
            remaining //= 2
            bit_value *= 2

        output[p_idx] = temp_result

    return output
Compute multiple powers of a matrix efficiently.
Uses binary exponentiation to compute powers in O(log max(powers)) matrix multiplications, avoiding redundant calculations when computing multiple powers.
Parameters:
A : Float[torch.Tensor, "n n"]
Square matrix to exponentiate
powers : Sequence[int]
List of powers to compute (non-negative integers)
Returns:
Float[torch.Tensor, "n_powers n n"]
Tensor containing the requested matrix powers stacked along the first dimension
Raises:
ValueError
: If no powers are requested or if A is not a square matrix