Source code for glass.arraytools

"""Module for array utilities."""

from __future__ import annotations

import typing
from typing import TYPE_CHECKING

import array_api_compat
import array_api_extra as xpx

from glass._array_api_utils import xp_additions as uxpx

if TYPE_CHECKING:
    from types import ModuleType

    from glass._types import AnyArray, FloatArray, IntArray


def broadcast_first(
    *arrays: FloatArray,
) -> tuple[FloatArray, ...]:
    """
    Broadcast arrays against each other with their first axes aligned.

    Ordinary broadcasting aligns trailing axes. Here every input with at
    least one dimension has its first axis moved to the end before
    broadcasting and moved back to the front afterwards, so the first axis
    is the one shared by all inputs.

    Parameters
    ----------
    arrays
        The arrays to broadcast.

    Returns
    -------
        The broadcasted arrays, in the same order as the inputs.

    """
    xp = array_api_compat.array_namespace(*arrays, use_compat=False)

    def _shift_axis(arr: FloatArray, source: int, destination: int) -> FloatArray:
        # Zero-dimensional arrays have no axis to move; pass them through.
        return xp.moveaxis(arr, source, destination) if arr.ndim else arr

    # First axis -> last, so that it takes part in trailing-axis alignment.
    first_axis_last = [_shift_axis(arr, 0, -1) for arr in arrays]
    expanded = xp.broadcast_arrays(*first_axis_last)
    # Last axis -> first again, restoring the original layout.
    return tuple(_shift_axis(arr, -1, 0) for arr in expanded)
[docs] def broadcast_leading_axes( *args: tuple[ float | FloatArray, int, ], xp: ModuleType | None = None, ) -> tuple[ tuple[int, ...], *tuple[FloatArray, ...], ]: """ Broadcast all but the last N axes. Parameters ---------- args The arrays and the number of axes to keep. xp The array library backend to use for array operations. Returns ------- The shape of the broadcast dimensions, and all input arrays with leading axes matching that shape. Examples -------- Broadcast all dimensions of ``a``, all except the last dimension of ``b``, and all except the last two dimensions of ``c``. >>> import numpy as np >>> a = 0 >>> b = np.zeros((4, 10)) >>> c = np.zeros((3, 1, 5, 6)) >>> dims, a, b, c = broadcast_leading_axes((a, 0), (b, 1), (c, 2)) >>> dims (3, 4) >>> a.shape (3, 4) >>> b.shape (3, 4, 10) >>> c.shape (3, 4, 5, 6) """ if xp is None: xp = array_api_compat.array_namespace( *[arg[0] for arg in args], use_compat=False, ) shapes, trails = [], [] for a, n in args: a_arr = xp.asarray(a) s = a_arr.shape i = len(s) - n shapes.append(s[:i]) trails.append(s[i:]) dims = xpx.broadcast_shapes(*shapes) arrs = ( xp.broadcast_to(xp.asarray(a), dims + t) for (a, _), t in zip(args, trails, strict=False) ) return (dims, *arrs) # ty: ignore[invalid-return-type]
def ndinterp(  # noqa: PLR0913
    x: float | FloatArray,
    xq: FloatArray,
    fq: FloatArray,
    axis: int = -1,
    left: float | None = None,
    right: float | None = None,
    period: float | None = None,
) -> FloatArray:
    """
    Interpolate a multi-dimensional array along one axis.

    This applies the one-dimensional interpolation routine along *axis*
    of *fq*, using the same query points *x* and sample points *xq*
    throughout.

    Parameters
    ----------
    x
        The x-coordinates at which to evaluate the interpolant.
    xq
        The x-coordinates of the data points.
    fq
        The function values corresponding to the x-coordinates in *xq*.
    axis
        The axis of *fq* to interpolate over.
    left
        The value to return for x < xq[0].
    right
        The value to return for x > xq[-1].
    period
        The period of the function, used for interpolating periodic data.

    Returns
    -------
        The interpolated array.

    """
    # Options handed unchanged to the 1-D interpolation routine.
    interp_options = {"left": left, "right": right, "period": period}
    query_and_samples = (x, xq)
    return uxpx.apply_along_axis(
        uxpx.interp,
        query_and_samples,
        axis,
        fq,
        **interp_options,
    )
def trapezoid_product(
    f: tuple[FloatArray, FloatArray],
    *ff: tuple[FloatArray, FloatArray],
    axis: int = -1,
) -> float | FloatArray:
    """
    Trapezoidal rule for a product of functions.

    Each function is given as a pair ``(x, y)`` of grid points and
    function values. A common grid is built by merging the grid points
    of all functions, restricted to the range where the grids overlap.
    Every function is then interpolated onto that grid, the values are
    multiplied, and the product is integrated with the trapezoidal rule.

    Parameters
    ----------
    f
        The first function.
    ff
        The other functions.
    axis
        The axis along which to integrate.

    Returns
    -------
        The integral of the product of the functions.

    Notes
    -----
    If the grids of the functions do not overlap, the common grid is
    empty and the result is whatever the interpolation and trapezoidal
    rule produce for an empty grid.

    """
    x: FloatArray
    x, _ = f
    for x_, _ in ff:
        # An empty common grid cannot be narrowed further, and reading its
        # end points ``x[0]`` / ``x[-1]`` below would fail; stop merging
        # and let the integration step handle the empty grid, as it
        # already does when there is only one non-overlapping function.
        if x.shape[0] == 0:
            break
        # Keep only the points of each grid that lie inside the range of
        # the other grid, then merge them into one sorted, unique grid.
        x = typing.cast(
            "FloatArray",
            xpx.union1d(
                x[(x >= x_[0]) & (x <= x_[-1])],
                x_[(x_ >= x[0]) & (x_ <= x[-1])],
            ),
        )
    # Evaluate every function on the common grid and form the product.
    y = uxpx.interp(x, *f)
    for f_ in ff:
        y *= uxpx.interp(x, *f_)
    return uxpx.trapezoid(y, x, axis=axis)
def cumulative_trapezoid(
    f: IntArray | FloatArray,
    x: IntArray | FloatArray,
) -> AnyArray:
    """
    Cumulative trapezoidal rule along the last axis.

    Both inputs are converted to ``float64`` arrays. The area of each
    trapezoid between neighbouring grid points is computed, and the
    running sum of those areas is returned, including the initial value
    of the sum.

    Parameters
    ----------
    f
        The function values.
    x
        The x-coordinates.

    Returns
    -------
        The cumulative integral of the function.

    """
    xp = array_api_compat.array_namespace(f, x, use_compat=False)

    values = xp.asarray(f, dtype=xp.float64)
    grid = xp.asarray(x, dtype=xp.float64)

    # Mean of neighbouring function values, i.e. the height of each
    # trapezoid; built from slices so that no array is modified in place.
    mean_heights = (values[..., 1:] + values[..., :-1]) * 0.5
    panel_areas = mean_heights * xp.diff(grid)

    return xp.cumulative_sum(panel_areas, axis=-1, include_initial=True)