# Copyright (C) 2024 Matthew Jennings
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
from typing import Sequence
from pymedphys._imports import numba as nb, numpy as np, plt, scipy
[docs]
def plot_interp_comparison_heatmap(
    values, values_interp, slice_axis: int, slice_number: int, slice_number_interp: int
):
    """
    Plot a comparison heatmap of original and interpolated 3D data slices.

    This function creates a side-by-side heatmap comparison of a slice from the
    original data and a corresponding slice from the interpolated data.

    Parameters
    ----------
    values : array-like
        The original 3D data array.
    values_interp : array-like
        The interpolated 3D data array.
    slice_axis : int
        The axis along which to take the slice (0, 1, or 2).
    slice_number : int
        The index of the slice to display from the original data.
    slice_number_interp : int
        The index of the slice to display from the interpolated data. Note that auto
        matching between the original and interpolated slices is not implemented. The
        user must ensure the slices correspond.

    Returns
    -------
    None
        This function displays the plot directly and does not return any value.

    Notes
    -----
    - The function creates a figure with two subplots side by side.
    - The left subplot shows the slice from the original data.
    - The right subplot shows the slice from the interpolated data.
    - It is up to the user to ensure that `slice_number` and
      `slice_number_interp` refer to the same physical location; no automatic
      matching is performed.
    - Both heatmaps use the same color scale, determined by the minimum and
      maximum non-NaN values across both datasets. NaNs (e.g. the default
      out-of-bounds fill value of the interpolators) are ignored.
    - A shared colorbar is displayed on the right side of the figure.
    - The plot is automatically displayed using plt.show().
    """
    # Accept any array-like input; `.take` and the reductions below need ndarrays.
    values = np.asarray(values)
    values_interp = np.asarray(values_interp)

    _, (ax1, ax2) = plt.subplots(1, 2)

    # nanmin/nanmax: interpolated data commonly contains NaN fill values for
    # out-of-bounds points. Plain min/max would propagate NaN into vmin/vmax
    # and break the shared colour scale.
    plot_min = min(np.nanmin(values), np.nanmin(values_interp))
    plot_max = max(np.nanmax(values), np.nanmax(values_interp))

    ax1.imshow(
        values.take(axis=slice_axis, indices=slice_number), vmin=plot_min, vmax=plot_max
    )
    im2 = ax2.imshow(
        values_interp.take(axis=slice_axis, indices=slice_number_interp),
        vmin=plot_min,
        vmax=plot_max,
    )
    # Both images share vmin/vmax, so one colorbar describes both axes.
    plt.colorbar(im2, ax=(ax1, ax2), orientation="vertical")
    plt.show()
def __check_inputs(
    axes_known: Sequence["np.ndarray"],
    values: "np.ndarray",
    points_interp: "np.ndarray",
    bounds_error=False,
) -> "tuple[tuple[np.ndarray, ...], np.ndarray]":
    """
    Validate the inputs of :func:`interp` and coerce them to float64 arrays.

    Parameters
    ----------
    axes_known : Sequence[array-like]
        One 1D coordinate vector per dimension of `values` (1, 2 or 3 of them).
    values : array-like
        Known values; ``values.shape[i]`` must equal ``len(axes_known[i])``.
    points_interp : array-like
        Query points, shape ``(n_points, len(axes_known))``.
    bounds_error : bool, optional
        If True, raise when any query coordinate lies outside the range of the
        corresponding known axis. Default is False.

    Returns
    -------
    tuple
        ``(axes_known, values)``: a tuple of float64 copies of the axes and a
        float64 copy of `values`.

    Raises
    ------
    ValueError
        If the number of axes is not 1-3 or does not match the number of
        columns of `points_interp`, if `points_interp` is not 2D, if `values`
        does not have one dimension per axis, if an axis is not 1D or its
        length does not match `values`, or (with `bounds_error`) if a query
        point is out of range.
    """
    # Coerce first so list/tuple inputs are accepted and `.ndim`/`.size`/
    # `.shape` are always available for the checks below.
    axes_known = tuple(np.array(axis, dtype=np.float64) for axis in axes_known)
    values = np.array(values, dtype=np.float64)
    points_interp = np.asarray(points_interp)

    if points_interp.ndim != 2:
        raise ValueError(
            f"points_interp (shape {points_interp.shape}) must be a 2D array of shape (n_points, n_dims)"
        )
    if not 1 <= len(axes_known) == points_interp.shape[-1] <= 3:
        raise ValueError(
            f"axes_known (len {len(axes_known)}) and points_interp (len {points_interp.shape[-1]}) must have the same length; either 1, 2, or 3"
        )
    if values.ndim != len(axes_known):
        raise ValueError(
            f"values (ndim {values.ndim}) must have one dimension per axis in axes_known (len {len(axes_known)})"
        )

    # min()/max() raise on zero-size arrays, and an empty query set can never
    # be out of bounds, so the bounds check only runs when there are points.
    has_points = points_interp.shape[0] > 0
    for i, axis_known in enumerate(axes_known):
        if axis_known.ndim != 1:
            raise ValueError(
                f"axes_known[{i}] (shape {axis_known.shape}) must be a 1D array"
            )
        if axis_known.size != values.shape[i]:
            raise ValueError(
                f"axes_known[{i}] (size {axis_known.size}) must match the size of the corresponding dimension of values ({values.shape[i]})"
            )
        if bounds_error and has_points:
            points_min = points_interp[:, i].min()
            points_max = points_interp[:, i].max()
            # min/max rather than [0]/[-1] so descending axes are handled too.
            if points_min < axis_known.min() or points_max > axis_known.max():
                raise ValueError(
                    f"points_interp[:, {i}] must be within the range of axes_known[{i}]\n"
                    f"({points_min}, {points_max}) vs. ({axis_known.min()}, {axis_known.max()})"
                )
    return axes_known, values
@cache
def _get_interp_linear_1d():
    """Return the numba-compiled 1D linear interpolation kernel.

    The kernel is defined inside a ``functools.cache``-d factory, so the
    ``nb.njit`` decoration is deferred until first use. Every later call
    returns the same dispatcher. ``cache=True`` asks numba to persist the
    compiled code in its on-disk cache.
    """

    # NOTE(review): fastmath=True lets LLVM assume operands are never NaN/inf,
    # yet NaN is the default fill value and may appear in `values` or
    # `points_interp`; results for NaN inputs are therefore not guaranteed.
    # Confirm this is acceptable.
    @nb.njit(parallel=True, fastmath=True, cache=True)
    def _interp_linear_1d(axis_known, values, points_interp, extrap_fill_value=np.nan):
        # Assumes (not checked here): axis_known is 1D, ascending, evenly
        # spaced with at least two points; values[k] is the sample at
        # axis_known[k]; points_interp is 2D and only column 0 is read.
        values_interp = np.zeros(points_interp.shape[0], dtype=np.float64)
        # Uniform spacing is assumed, so the first interval's width is used
        # for every interval.
        diff = axis_known[1] - axis_known[0]
        # Query points are independent, so they are processed with prange.
        # pylint: disable=not-an-iterable
        for i in nb.prange(points_interp.shape[0]):
            xpi = points_interp[i, 0]
            # Outside the known range: no extrapolation, use the fill value.
            if not axis_known[0] <= xpi <= axis_known[-1]:
                values_interp[i] = extrap_fill_value
                continue
            # searchsorted (default side="left") returns the first index with
            # axis_known[idx] >= xpi, i.e. the right-hand neighbour.
            x1_idx = np.searchsorted(axis_known, xpi)
            x0_idx = x1_idx - 1
            # xpi == axis_known[0] gives x1_idx == 0, so x0_idx would be -1.
            if x0_idx < 0:
                x0_idx = 0
            # Defensive: in-range points never produce x1_idx == size.
            if x1_idx >= axis_known.size:
                x1_idx = axis_known.size - 1
            # Fractional position of xpi between its two neighbours (0..1).
            wx = (xpi - axis_known[x0_idx]) / diff
            values_interp[i] = values[x0_idx] * (1 - wx) + values[x1_idx] * wx
        return values_interp

    return _interp_linear_1d
[docs]
def interp_linear_1d(axis_known, values, points_interp, extrap_fill_value=None):
    """
    Linearly interpolate 1D data sampled on an evenly spaced, ascending axis.

    No input validation is performed here; use :func:`interp` for a checked
    entry point.

    Parameters
    ----------
    axis_known : array-like
        Ascending, evenly spaced coordinates of the known samples (at least
        two points).
    values : array-like
        Known values, ``values[k]`` sampled at ``axis_known[k]``.
    points_interp : array-like
        Query coordinates, either a flat vector of shape ``(n,)`` or a column
        of shape ``(n, 1)``. A scalar is treated as a single point.
    extrap_fill_value : float, optional
        Value returned for points outside ``axis_known``. Default is None,
        which results in using np.nan.

    Returns
    -------
    np.ndarray
        1D float64 array of length ``n`` with the interpolated values.
    """
    _interp_linear_1d = _get_interp_linear_1d()
    if extrap_fill_value is None:
        extrap_fill_value = np.nan

    points_interp = np.asarray(points_interp)
    if points_interp.ndim <= 1:
        # The kernel reads points as points_interp[i, 0]; view a flat vector
        # (or scalar) of coordinates as a single column.
        points_interp = points_interp.reshape(-1, 1)

    return _interp_linear_1d(
        axis_known=np.asarray(axis_known),
        values=np.asarray(values),
        points_interp=points_interp,
        extrap_fill_value=extrap_fill_value,
    )
@cache
def _get_interp_linear_2d():
    """Return the numba-compiled 2D (bilinear) interpolation kernel.

    The kernel is defined inside a ``functools.cache``-d factory, so the
    ``nb.njit`` decoration is deferred until first use. Every later call
    returns the same dispatcher. ``cache=True`` asks numba to persist the
    compiled code in its on-disk cache.
    """

    # NOTE(review): fastmath=True lets LLVM assume operands are never NaN/inf,
    # yet NaN is the default fill value and may appear in `values` or
    # `points_interp`; results for NaN inputs are therefore not guaranteed.
    # Confirm this is acceptable.
    @nb.njit(parallel=True, fastmath=True, cache=True)
    def _interp_linear_2d(axes_known, values, points_interp, extrap_fill_value=np.nan):
        # Assumes (not checked here): axes_known holds two ascending, evenly
        # spaced 1D arrays (x, y) with at least two points each; values[i, j]
        # is the sample at (x[i], y[j]); points_interp is 2D and only
        # columns 0 and 1 are read.
        values_interp = np.zeros((points_interp.shape[0]), dtype=np.float64)
        # Per-axis grid spacing, taken from the first interval of each axis.
        diffs = np.zeros(2)
        for i, axis in enumerate(axes_known):
            diffs[i] = axis[1] - axis[0]
        x, y = axes_known
        # Query points are independent, so they are processed with prange.
        # pylint: disable=not-an-iterable
        for i in nb.prange(points_interp.shape[0]):
            xpi, ypi = points_interp[i, 0], points_interp[i, 1]
            # Outside the grid on either axis: no extrapolation, use the fill value.
            if not x[0] <= xpi <= x[-1] or not y[0] <= ypi <= y[-1]:
                values_interp[i] = extrap_fill_value
                continue
            # Find the indices of the surrounding grid points. searchsorted
            # (default side="left") returns the right-hand neighbour.
            x1_idx = np.searchsorted(x, xpi)
            x0_idx = x1_idx - 1
            y1_idx = np.searchsorted(y, ypi)
            y0_idx = y1_idx - 1
            # A point exactly on the first grid line gives *1_idx == 0, so
            # *0_idx would be -1. The upper clamps are defensive: in-range
            # points never reach them.
            if x0_idx < 0:
                x0_idx = 0
            if y0_idx < 0:
                y0_idx = 0
            if x1_idx >= x.size:
                x1_idx = x.size - 1
            if y1_idx >= y.size:
                y1_idx = y.size - 1
            # Values at the four corners of the enclosing cell.
            c00 = values[x0_idx, y0_idx]
            c01 = values[x0_idx, y1_idx]
            c10 = values[x1_idx, y0_idx]
            c11 = values[x1_idx, y1_idx]
            # Fractional position inside the cell along each axis (0..1).
            wx = (xpi - x[x0_idx]) / diffs[0]
            wy = (ypi - y[y0_idx]) / diffs[1]
            # Interpolate along x on the y0 and y1 edges, then along y.
            c0 = c00 * (1 - wx) + c10 * wx
            c1 = c01 * (1 - wx) + c11 * wx
            values_interp[i] = c0 * (1 - wy) + c1 * wy
        return values_interp

    return _interp_linear_2d
[docs]
def interp_linear_2d(axes_known, values, points_interp, extrap_fill_value=None):
    """
    Bilinearly interpolate 2D gridded data at arbitrary points.

    No input validation is performed here; use :func:`interp` for a checked
    entry point.

    Parameters
    ----------
    axes_known : Sequence[array-like]
        Two 1D coordinate vectors ``(x, y)``, each ascending and evenly spaced
        with at least two points.
    values : array-like
        Known values with ``values[i, j]`` sampled at ``(x[i], y[j])``.
    points_interp : array-like
        Query points of shape ``(n_points, 2)``.
    extrap_fill_value : float, optional
        Value returned for points outside the known grid. Default is None,
        which results in using np.nan.

    Returns
    -------
    np.ndarray
        1D float64 array of length ``n_points`` with the interpolated values.
    """
    _interp_linear_2d = _get_interp_linear_2d()
    if extrap_fill_value is None:
        extrap_fill_value = np.nan

    # Hand the axes to numba as a homogeneous tuple of C-contiguous float64
    # arrays. A list would go through numba's deprecated reflected-list path.
    # Axes of differing dtype/layout would form a heterogeneous tuple, which
    # the kernel's `enumerate(axes_known)` loop cannot iterate.
    # Both conversions are no-ops for float64 C-contiguous arrays.
    axes_known = tuple(
        np.ascontiguousarray(axis, dtype=np.float64) for axis in axes_known
    )
    return _interp_linear_2d(
        axes_known=axes_known,
        values=np.asarray(values),
        points_interp=np.asarray(points_interp),
        extrap_fill_value=extrap_fill_value,
    )
@cache
def _get_interp_linear_3d():
    """Return the numba-compiled 3D (trilinear) interpolation kernel.

    The kernel is defined inside a ``functools.cache``-d factory, so the
    ``nb.njit`` decoration is deferred until first use. Every later call
    returns the same dispatcher. ``cache=True`` asks numba to persist the
    compiled code in its on-disk cache.
    """

    # NOTE(review): fastmath=True lets LLVM assume operands are never NaN/inf,
    # yet NaN is the default fill value and may appear in `values` or
    # `points_interp`; results for NaN inputs are therefore not guaranteed.
    # Confirm this is acceptable.
    @nb.njit(parallel=True, fastmath=True, cache=True)
    # pylint: disable=invalid-name
    def _interp_linear_3d(axes_known, values, points_interp, extrap_fill_value=np.nan):
        # Assumes (not checked here): axes_known holds three ascending,
        # evenly spaced 1D arrays (x, y, z) with at least two points each;
        # values[i, j, k] is the sample at (x[i], y[j], z[k]); points_interp
        # is 2D and only columns 0-2 are read.
        x, y, z = axes_known[0], axes_known[1], axes_known[2]
        values_interp = np.zeros(
            points_interp.shape[0],
            dtype=np.float64,
        )
        # Per-axis grid spacing, taken from the first interval of each axis.
        diffs = np.zeros(3)
        for i, axis in enumerate(axes_known):
            diffs[i] = axis[1] - axis[0]
        # Query points are independent, so they are processed with prange.
        # pylint: disable=not-an-iterable
        for i in nb.prange(points_interp.shape[0]):
            xpi, ypi, zpi = (
                points_interp[i, 0],
                points_interp[i, 1],
                points_interp[i, 2],
            )
            # Outside the grid on any axis: no extrapolation, use the fill value.
            if (
                not x[0] <= xpi <= x[-1]
                or not y[0] <= ypi <= y[-1]
                or not z[0] <= zpi <= z[-1]
            ):
                values_interp[i] = extrap_fill_value
                continue
            # Find the indices of the surrounding grid points. searchsorted
            # (default side="left") returns the right-hand neighbour.
            x1_idx = np.searchsorted(x, xpi)
            x0_idx = x1_idx - 1
            y1_idx = np.searchsorted(y, ypi)
            y0_idx = y1_idx - 1
            z1_idx = np.searchsorted(z, zpi)
            z0_idx = z1_idx - 1
            # A point exactly on the first grid plane gives *1_idx == 0, so
            # *0_idx would be -1. The upper clamps are defensive: in-range
            # points never reach them.
            if x0_idx < 0:
                x0_idx = 0
            if y0_idx < 0:
                y0_idx = 0
            if z0_idx < 0:
                z0_idx = 0
            if x1_idx >= x.size:
                x1_idx = x.size - 1
            if y1_idx >= y.size:
                y1_idx = y.size - 1
            if z1_idx >= z.size:
                z1_idx = z.size - 1
            # Compute interpolation weights: fractional position inside the
            # cell along each axis (0..1).
            wx = (xpi - x[x0_idx]) / diffs[0]
            wy = (ypi - y[y0_idx]) / diffs[1]
            wz = (zpi - z[z0_idx]) / diffs[2]
            # Extract the values at the eight corners of the enclosing cell
            # (suffix digits are the x, y, z corner offsets).
            c000 = values[x0_idx, y0_idx, z0_idx]
            c001 = values[x0_idx, y0_idx, z1_idx]
            c010 = values[x0_idx, y1_idx, z0_idx]
            c011 = values[x0_idx, y1_idx, z1_idx]
            c100 = values[x1_idx, y0_idx, z0_idx]
            c101 = values[x1_idx, y0_idx, z1_idx]
            c110 = values[x1_idx, y1_idx, z0_idx]
            c111 = values[x1_idx, y1_idx, z1_idx]
            # Perform trilinear interpolation: along x on the four x-edges,
            # then along y, then along z.
            c00 = c000 * (1 - wx) + c100 * wx
            c01 = c001 * (1 - wx) + c101 * wx
            c10 = c010 * (1 - wx) + c110 * wx
            c11 = c011 * (1 - wx) + c111 * wx
            c0 = c00 * (1 - wy) + c10 * wy
            c1 = c01 * (1 - wy) + c11 * wy
            values_interp[i] = c0 * (1 - wz) + c1 * wz
        return values_interp

    return _interp_linear_3d
[docs]
def interp_linear_3d(axes_known, values, points_interp, extrap_fill_value=None):
    """
    Trilinearly interpolate 3D gridded data at arbitrary points.

    No input validation is performed here; use :func:`interp` for a checked
    entry point.

    Parameters
    ----------
    axes_known : Sequence[array-like]
        Three 1D coordinate vectors ``(x, y, z)``, each ascending and evenly
        spaced with at least two points.
    values : array-like
        Known values with ``values[i, j, k]`` sampled at ``(x[i], y[j], z[k])``.
    points_interp : array-like
        Query points of shape ``(n_points, 3)``.
    extrap_fill_value : float, optional
        Value returned for points outside the known grid. Default is None,
        which results in using np.nan.

    Returns
    -------
    np.ndarray
        1D float64 array of length ``n_points`` with the interpolated values.
    """
    _interp_linear_3d = _get_interp_linear_3d()
    if extrap_fill_value is None:
        extrap_fill_value = np.nan

    # Hand the axes to numba as a homogeneous tuple of C-contiguous float64
    # arrays. A list would go through numba's deprecated reflected-list path.
    # Axes of differing dtype/layout would form a heterogeneous tuple, which
    # the kernel's `enumerate(axes_known)` loop cannot iterate.
    # Both conversions are no-ops for float64 C-contiguous arrays.
    axes_known = tuple(
        np.ascontiguousarray(axis, dtype=np.float64) for axis in axes_known
    )
    return _interp_linear_3d(
        axes_known=axes_known,
        values=np.asarray(values),
        points_interp=np.asarray(points_interp),
        extrap_fill_value=extrap_fill_value,
    )
def interp_linear_scipy(
    axes_known,
    values,
    axes_interp: Sequence["np.ndarray"] = None,
    points_interp: "np.ndarray" = None,
    keep_dims=False,
    bounds_error=True,
    extrap_fill_value=None,
):
    """
    Linear interpolation using ``scipy.interpolate.RegularGridInterpolator``.

    Takes the same arguments as :func:`interp` (minus ``skip_checks``) and
    serves as a reference or fallback implementation.

    Parameters
    ----------
    axes_known : Sequence[np.ndarray]
        Coordinate vectors of the known grid, one per dimension of `values`.
    values : np.ndarray
        Known values on the grid defined by `axes_known`.
    axes_interp : Sequence[np.ndarray], optional
        Axes whose "ij" meshgrid defines the query points. Exactly one of
        `axes_interp` or `points_interp` must be given.
    points_interp : np.ndarray, optional
        Query points of shape ``(n_points, n_dims)``.
    keep_dims : bool, optional
        If True, reshape the result to the lengths of `axes_interp`. Requires
        `axes_interp`. Default is False.
    bounds_error : bool, optional
        Passed to ``RegularGridInterpolator``. Default is True.
    extrap_fill_value : float, optional
        Fill value for out-of-bounds points when `bounds_error` is False.
        Default is None, which results in using np.nan.

    Returns
    -------
    np.ndarray
        Interpolated values, flat or shaped by `axes_interp` when `keep_dims`.

    Raises
    ------
    ValueError
        If neither or both of `axes_interp` and `points_interp` are given, or
        if `keep_dims` is True without `axes_interp`.
    """
    if axes_interp is not None and points_interp is None:
        mgrids = np.meshgrid(*axes_interp, indexing="ij")
        points_interp = np.column_stack([mgrid.ravel() for mgrid in mgrids])
    elif axes_interp is None and points_interp is not None:
        # Validate before building the interpolator; the output shape can
        # only be recovered from `axes_interp`.
        if keep_dims:
            raise ValueError(
                "If `keep_dims` is True, `axes_interp` must be specified to determine the shape of the output"
            )
    else:
        raise ValueError(
            "Exactly one of either `axes_interp` or `points_interp` must be specified"
        )

    if extrap_fill_value is None:
        extrap_fill_value = np.nan

    f = scipy.interpolate.RegularGridInterpolator(
        axes_known, values, bounds_error=bounds_error, fill_value=extrap_fill_value
    )
    values_interp = f(points_interp)

    if keep_dims:
        # A concrete list (not a generator) is required by ndarray.reshape;
        # np.size also accepts plain-list axes.
        values_interp = values_interp.reshape([np.size(axis) for axis in axes_interp])
    return values_interp
# pylint: disable=invalid-name
[docs]
def interp(
    axes_known: Sequence["np.ndarray"],
    values: "np.ndarray",
    axes_interp: Sequence["np.ndarray"] = None,
    points_interp: "np.ndarray" = None,
    keep_dims=False,
    bounds_error=True,
    extrap_fill_value=None,
    skip_checks=False,
) -> "np.ndarray":
    """
    Perform fast linear interpolation on 1D, 2D, or 3D data.

    Parameters
    ----------
    axes_known : Sequence[np.ndarray]
        The coordinate vectors or axis coordinates of the known data points.
        Each axis must be evenly spaced, strictly monotonic (ascending or
        descending) and contain at least two points.
    values : np.ndarray
        The known values at the points defined by `axes_known`. Its shape should match
        a tuple of the lengths of the axes in `axes_known` in the same order.
    axes_interp : Sequence[np.ndarray], optional
        The coordinate vectors or axis coordinates for which to interpolate values.
        These axes will be expanded to flattened meshgrids and interpolation will occur
        for each point in these grids. Either `axes_interp` or `points_interp` must be
        provided, but not both.
    points_interp : np.ndarray, optional
        The exact coordinates of the points where interpolation is desired.
        Shape should be (n, d) where n is the number of points and d is the number of
        dimensions. Either `axes_interp` or `points_interp` must be provided, but not
        both.
    keep_dims : bool, optional
        If True, return the interpolated values with the same shape as defined by
        `axes_interp`. Only applicable when `axes_interp` is provided. Default is False.
    bounds_error : bool, optional
        If True, raise an error when interpolation is attempted outside the bounds of
        the input data. Default is True. This check is part of input
        validation and is not performed when `skip_checks` is True; out-of-bounds
        points then receive `extrap_fill_value`.
    extrap_fill_value : float, optional
        The value to use for points outside the bounds of the input data when
        `bounds_error` is False. Default is None, which results in using np.nan.
    skip_checks : bool, optional
        If True, skip input validation checks, including the `bounds_error`
        check and the handling of descending axes. Skipping these checks can
        produce a significant improvement in performance for some applications.
        The caller is then responsible for supplying ascending, evenly spaced
        axes. Default is False.

    Returns
    -------
    np.ndarray
        The interpolated values. If `keep_dims` is True and `axes_interp` is provided,
        the output will have the same shape as defined by `axes_interp`.
        Otherwise, it will be a 1D array.

    Raises
    ------
    ValueError
        If neither or both of `axes_interp` and `points_interp` are provided.
        If `keep_dims` is True but `axes_interp` is not provided.
        When `skip_checks` is False: if the shapes of `axes_known`, `values`
        and the query points are inconsistent, if an axis has fewer than two
        points, is not strictly monotonic or is not evenly spaced, or (with
        `bounds_error`) if a query point lies outside the known grid.

    Notes
    -----
    This function performs linear interpolation for 1D, 2D, or 3D data.
    It supports both grid-based interpolation (using `axes_interp`) and
    point-based interpolation (using `points_interp`).
    Descending axes are reversed (together with the matching axis of
    `values`) before interpolation, which does not change the result.
    """
    if axes_interp is not None and points_interp is None:
        mgrids = np.meshgrid(*axes_interp, indexing="ij")
        points_interp = np.column_stack([mgrid.ravel() for mgrid in mgrids])
    elif axes_interp is None and points_interp is not None:
        if keep_dims:
            raise ValueError(
                "If `keep_dims` is True, `axes_interp` must be specified to determine the shape of the output"
            )
    else:
        raise ValueError(
            "Exactly one of either `axes_interp` or `points_interp` must be specified"
        )

    if not skip_checks:
        axes_known, values = __check_inputs(
            axes_known, values, points_interp, bounds_error
        )
        axes_known = list(axes_known)
        # Handle ascending vs. descending vs. bad order.
        for i, axis in enumerate(axes_known):
            # The kernels read axis[1] - axis[0] as the grid spacing.
            if axis.size < 2:
                raise ValueError(f"axes_known[{i}] must contain at least two points")
            diff = np.diff(axis)
            if np.all(diff < 0):
                # Strictly descending: reverse the axis and the matching axis
                # of `values`. The kernels assume ascending axes and then see
                # an equivalent grid. ascontiguousarray keeps every axis the
                # same (C-contiguous float64) type so numba gets a homogeneous
                # tuple.
                axes_known[i] = np.ascontiguousarray(axis[::-1])
                values = np.flip(values, axis=i)
                diff = np.diff(axes_known[i])
            elif not np.all(diff > 0):
                raise ValueError(
                    f"axes_known[{i}] is not monotonically ascending or descending"
                )
            if not np.allclose(diff, diff[0]):
                raise ValueError(f"axes_known[{i}] must be evenly spaced")
        axes_known = tuple(axes_known)
        # np.flip returns negatively-strided views; restore a C-contiguous
        # array (a no-op when nothing was flipped).
        values = np.ascontiguousarray(values)

    if extrap_fill_value is None:
        extrap_fill_value = np.nan

    if len(axes_known) == 1:
        # keep_dims has no effect for 1D interpolation: the result already has
        # one entry per point of the single interpolation axis.
        return interp_linear_1d(
            axes_known[0],
            values,
            points_interp,
            extrap_fill_value,
        )
    if len(axes_known) == 2:
        values_interp = interp_linear_2d(
            axes_known,
            values,
            points_interp,
            extrap_fill_value,
        )
    else:
        values_interp = interp_linear_3d(
            axes_known,
            values,
            points_interp,
            extrap_fill_value,
        )

    if keep_dims:
        # np.size (rather than .size) also supports plain-list axes.
        values_interp = values_interp.reshape([np.size(axis) for axis in axes_interp])
    return values_interp