
Torch Find Peaks Documentation

Welcome to the documentation for the torch-find-peaks library.

Overview

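The torch-find-peaks library provides utilities for detecting and refining peaks in 2D and 3D data using PyTorch. It includes functions for peak detection, for refining peak positions by Gaussian fitting, and differentiable Gaussian models (Gaussian2D, Gaussian3D, WarpedGaussian2D) built on torch.nn.Module.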

Installation

To install the library, use:

pip install torch-find-peaks

Usage

Here are some of the key functionalities provided by the library:

  • Peak Detection: Detect peaks in 2D images or 3D volumes (find_peaks_2d, find_peaks_3d).
  • Gaussian Fitting: Fit 2D or 3D Gaussian functions to refine peak positions (refine_peaks_2d, refine_peaks_3d).

API Reference

torch_find_peaks.find_peaks

find_peaks_2d(image, min_distance=1, threshold_abs=0.0, exclude_border=0, return_as='torch')

Find local peaks in a 2D image.

Accepts various input types (torch.Tensor, numpy.ndarray) and attempts to convert them to torch.Tensor before processing.

Parameters:

  • image (Any, required): A 2D tensor-like object (e.g., torch.Tensor, numpy.ndarray) representing the input image.
  • min_distance (int, default 1): Minimum distance between peaks.
  • threshold_abs (float, default 0.0): Minimum intensity value for a peak to be considered.
  • exclude_border (int, default 0): Width of the border to exclude from peak detection.
  • return_as (str, default 'torch'): The format of the output. Other options are "numpy" and "dataframe".

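Returns:

  • "torch" (default): a tuple (peaks, heights). peaks is a tensor of shape (N, 2), where N is the number of peaks and each row contains the (Y, X) coordinates of a peak; heights is a tensor of shape (N,) with the corresponding peak heights.
  • "numpy": the same (peaks, heights) pair as NumPy arrays.
  • "dataframe": a pandas.DataFrame with columns y, x, and height.

Raises:

  • TypeError: If the input image cannot be converted to a torch.Tensor.
  • ValueError: If the input image is not 2-dimensional after conversion, or if return_as is not one of the supported values.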
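The parameters above can be made concrete with a small NumPy sketch of local-maximum detection. It is not the library's implementation (the real work happens in the internal helper _find_peaks_2d_torch, whose code is not shown here). It assumes that min_distance defines a (2 * min_distance + 1)-wide neighbourhood, that a peak must be strictly above threshold_abs, and that exclude_border drops peaks within that many pixels of an edge. Samples tied on a flat plateau all survive in this sketch. The function name find_local_peaks is hypothetical, the same code also handles 3D volumes, and sliding_window_view requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def find_local_peaks(data, min_distance=1, threshold_abs=0.0, exclude_border=0):
    """Return (coords, heights) of local maxima in a 2D or 3D array.

    Assumed semantics (not taken from the library): a sample is a peak when it
    equals the maximum of the (2 * min_distance + 1)-wide window centred on it,
    is strictly greater than threshold_abs, and lies at least exclude_border
    samples away from every edge.
    """
    data = np.asarray(data, dtype=float)
    size = 2 * min_distance + 1
    # Pad with -inf so edge windows never treat padding as a larger neighbour.
    padded = np.pad(data, min_distance, mode="constant", constant_values=-np.inf)
    # One window per original sample: shape (*data.shape, size, ..., size).
    windows = sliding_window_view(padded, (size,) * data.ndim)
    # Taking the max over the window axes acts as a maximum filter.
    local_max = windows.max(axis=tuple(range(data.ndim, 2 * data.ndim)))
    mask = (data == local_max) & (data > threshold_abs)
    if exclude_border > 0:
        inner = np.zeros_like(mask)
        inner[(slice(exclude_border, -exclude_border),) * data.ndim] = True
        mask &= inner
    coords = np.argwhere(mask)  # (N, ndim): rows in (Y, X) or (Z, Y, X) order
    heights = data[mask]        # same row order as coords
    return coords, heights
```

With the library itself, the comparable call is find_peaks_2d(image, min_distance=1, threshold_abs=1.0), which by default returns a (peaks, heights) pair of tensors.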

Source code in src/torch_find_peaks/find_peaks.py
def find_peaks_2d(
        image: Union[torch.Tensor, np.ndarray],
        min_distance: int = 1,
        threshold_abs: float = 0.0,
        exclude_border: int = 0,
        return_as: Literal["torch", "numpy", "dataframe"] = "torch",
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[np.ndarray, np.ndarray], pd.DataFrame]:
    """
    Find local peaks in a 2D image.

    Accepts various input types (torch.Tensor, numpy.ndarray) and attempts
    to convert them to torch.Tensor before processing.

    Parameters
    ----------
    image : Any
        A 2D tensor-like object (e.g., torch.Tensor, numpy.ndarray)
        representing the input image.
    min_distance : int, optional
        Minimum distance between peaks. Default is 1.
    threshold_abs : float, optional
        Minimum intensity value for a peak to be considered. Default is 0.0.
    exclude_border : int, optional
        Width of the border to exclude from peak detection. Default is 0.
    return_as : str, optional
        The format of the output. Default is "torch".
        Other options are "numpy" and "dataframe".

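    Returns
    -------
    tuple of torch.Tensor, tuple of np.ndarray, or pd.DataFrame
        By default, a tuple (peaks, heights): peaks has shape (N, 2), where N
        is the number of peaks and each row contains the (Y, X) coordinates of
        a peak, and heights has shape (N,). With "numpy" the same pair is
        returned as NumPy arrays; with "dataframe", a DataFrame with columns
        "y", "x", and "height".

    Raises
    ------
    TypeError
        If the input image cannot be converted to a torch.Tensor.
    ValueError
        If the input image is not 2-dimensional after conversion, or if
        return_as is not a supported value.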
    """
    if isinstance(image, torch.Tensor):
        image_tensor = image
    elif isinstance(image, np.ndarray):
        image_tensor = torch.from_numpy(image)
    # Add checks for pandas/polars DataFrames/Series here if needed
    # elif pd and isinstance(image, pd.DataFrame):
    #     image_tensor = torch.from_numpy(image.values)
    # elif pl and isinstance(image, pl.DataFrame):
    #     image_tensor = torch.from_numpy(image.to_numpy())
    else:
        try:
            # Attempt a general conversion for other array-like objects
            image_tensor = torch.as_tensor(image)
        except Exception as e:
            raise TypeError(
                f"Input type {type(image)} not supported or conversion failed: {e}"
            ) from e

    if image_tensor.ndim != 2:
        raise ValueError(
            f"Input image must be 2-dimensional, but got {image_tensor.ndim} dimensions."
        )

    found_peaks, heights = _find_peaks_2d_torch(
        image=image_tensor,
        min_distance=min_distance,
        threshold_abs=threshold_abs,
        exclude_border=exclude_border,
    )

    if return_as == "torch":
        return found_peaks, heights
    elif return_as == "numpy":
        # Move to the CPU first: .numpy() fails on CUDA tensors.
        return found_peaks.cpu().numpy(), heights.cpu().numpy()
    elif return_as == "dataframe":
        # einops.pack with pattern 'n *' flattens each tensor's trailing axes and
        # concatenates them: [N, 2] peaks and [N] heights become one [N, 3] array.
        packed, _ = einops.pack([found_peaks, heights], 'n *')
        return pd.DataFrame(packed.cpu().numpy(), columns=["y", "x", "height"])
    else:
        raise ValueError(f"Invalid return_as value: {return_as}")

find_peaks_3d(volume, min_distance=1, threshold_abs=0.0, exclude_border=0, return_as='torch')

Find local peaks in a 3D volume.

Accepts various input types (torch.Tensor, numpy.ndarray) and attempts to convert them to torch.Tensor before processing.

Parameters:

  • volume (Any, required): A 3D tensor-like object (e.g., torch.Tensor, numpy.ndarray) representing the input volume.
  • min_distance (int, default 1): Minimum distance between peaks.
  • threshold_abs (float, default 0.0): Minimum intensity value for a peak to be considered.
  • exclude_border (int, default 0): Width of the border to exclude from peak detection.
  • return_as (str, default 'torch'): The format of the output. Other options are "numpy" and "dataframe".

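Returns:

  • "torch" (default): a tuple (peaks, heights). peaks is a tensor of shape (N, 3), where N is the number of peaks and each row contains the (Z, Y, X) coordinates of a peak; heights is a tensor of shape (N,) with the corresponding peak heights.
  • "numpy": the same (peaks, heights) pair as NumPy arrays.
  • "dataframe": a pandas.DataFrame with columns z, y, x, and height.

Raises:

  • TypeError: If the input volume cannot be converted to a torch.Tensor.
  • ValueError: If the input volume is not 3-dimensional after conversion, or if return_as is not one of the supported values.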
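In the dataframe branch, einops.pack with the pattern 'n *' flattens the trailing axes of each tensor and concatenates them. The (N, 3) coordinates and the (N,) heights therefore become a single (N, 4) array whose columns are z, y, x, height. The sketch below reproduces that layout in NumPy with np.column_stack to show what the returned table contains. pack_peaks is a hypothetical helper, not part of the library.

```python
import numpy as np


def pack_peaks(coords, heights):
    """Combine (N, k) peak coordinates and (N,) heights into one (N, k + 1) array.

    Mirrors the column layout that einops.pack(..., 'n *') produces in the
    dataframe branch: coordinate columns first, then the height column.
    """
    coords = np.asarray(coords)
    heights = np.asarray(heights)
    if coords.ndim != 2 or heights.shape != (coords.shape[0],):
        raise ValueError("expected coords of shape (N, k) and heights of shape (N,)")
    # column_stack turns the 1D heights into a trailing column; integer
    # coordinates are promoted to the floating dtype of the heights.
    return np.column_stack([coords, heights])
```

Passing the result to pd.DataFrame(packed, columns=["z", "y", "x", "height"]) gives the same column layout as find_peaks_3d(..., return_as="dataframe"). If the coordinates are integers, they are promoted to floating point in the packed array, so cast them back with astype(int) when you need integer indices.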

Source code in src/torch_find_peaks/find_peaks.py
def find_peaks_3d(
        volume: Union[torch.Tensor, np.ndarray],
        min_distance: int = 1,
        threshold_abs: float = 0.0,
        exclude_border: int = 0,
        return_as: Literal["torch", "numpy", "dataframe"] = "torch",
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[np.ndarray, np.ndarray], pd.DataFrame]:
    """
    Find local peaks in a 3D volume.

    Accepts various input types (torch.Tensor, numpy.ndarray) and attempts
    to convert them to torch.Tensor before processing.

    Parameters
    ----------
    volume : Any
        A 3D tensor-like object (e.g., torch.Tensor, numpy.ndarray)
        representing the input volume.
    min_distance : int, optional
        Minimum distance between peaks. Default is 1.
    threshold_abs : float, optional
        Minimum intensity value for a peak to be considered. Default is 0.0.
    exclude_border : int, optional
        Width of the border to exclude from peak detection. Default is 0.
    return_as : str, optional
        The format of the output. Default is "torch".
        Other options are "numpy" and "dataframe".

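    Returns
    -------
    tuple of torch.Tensor, tuple of np.ndarray, or pd.DataFrame
        By default, a tuple (peaks, heights): peaks has shape (N, 3), where N
        is the number of peaks and each row contains the (Z, Y, X) coordinates
        of a peak, and heights has shape (N,). With "numpy" the same pair is
        returned as NumPy arrays; with "dataframe", a DataFrame with columns
        "z", "y", "x", and "height".

    Raises
    ------
    TypeError
        If the input volume cannot be converted to a torch.Tensor.
    ValueError
        If the input volume is not 3-dimensional after conversion, or if
        return_as is not a supported value.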
    """
    if isinstance(volume, torch.Tensor):
        volume_tensor = volume
    elif isinstance(volume, np.ndarray):
        volume_tensor = torch.from_numpy(volume)
    # Add checks for pandas/polars DataFrames/Series here if needed
    else:
        try:
            # Attempt a general conversion for other array-like objects
            volume_tensor = torch.as_tensor(volume)
        except Exception as e:
            raise TypeError(
                f"Input type {type(volume)} not supported or conversion failed: {e}"
            ) from e

    if volume_tensor.ndim != 3:
        raise ValueError(
            f"Input volume must be 3-dimensional, but got {volume_tensor.ndim} dimensions."
        )

    found_peaks, heights = _find_peaks_3d_torch(
        volume=volume_tensor,
        min_distance=min_distance,
        threshold_abs=threshold_abs,
        exclude_border=exclude_border,
    )

    if return_as == "torch":
        return found_peaks, heights
    elif return_as == "numpy":
        return found_peaks.cpu().numpy(), heights.cpu().numpy()
    elif return_as == "dataframe":
        # einops.pack with pattern 'n *' flattens each tensor's trailing axes and
        # concatenates them: [N, 3] peaks and [N] heights become one [N, 4] array.
        packed, _ = einops.pack([found_peaks, heights], 'n *')
        return pd.DataFrame(packed.cpu().numpy(), columns=["z", "y", "x", "height"])
    else:
        raise ValueError(f"Invalid return_as value: {return_as}")

torch_find_peaks.refine_peaks

refine_peaks_2d(image, peak_coords, boxsize, max_iterations=1000, learning_rate=0.01, tolerance=1e-06, amplitude=1.0, sigma_x=1.0, sigma_y=1.0, return_as='torch')

Refine the positions of peaks in a 2D image by fitting 2D Gaussian functions.

Parameters:

  • image (Any, required): A 2D tensor-like object (e.g., torch.Tensor, numpy.ndarray) containing the image data.
  • peak_coords (torch.Tensor, np.ndarray, or pd.DataFrame, required): Initial peak coordinates of shape (n, 2) in (y, x) order. A DataFrame must provide "y" and "x" columns plus a "height" column, which is used as the initial amplitude in place of the amplitude argument (this matches the output of find_peaks_2d(..., return_as="dataframe")).
  • boxsize (int, required): Size of the region to crop around each peak (must be even).
  • max_iterations (int, default 1000): Maximum number of optimization iterations.
  • learning_rate (float, default 0.01): Learning rate for the optimizer.
  • tolerance (float, default 1e-6): Convergence tolerance for the optimization.
  • amplitude (Union[Tensor, float], default 1.0): Initial amplitude of the Gaussian.
  • sigma_x (Union[Tensor, float], default 1.0): Initial standard deviation in the x direction.
  • sigma_y (Union[Tensor, float], default 1.0): Initial standard deviation in the y direction.
  • return_as (str, default 'torch'): The format of the output: "torch", "numpy", or "dataframe".

Returns:

  • A tensor of shape (n, 5) with the fitted parameters of each peak, one row per peak in the order [amplitude, y, x, sigma_x, sigma_y]. With return_as="numpy" the values are returned as a NumPy array, and with return_as="dataframe" as a pandas.DataFrame with columns amplitude, y, x, sigma_x, sigma_y.

Raises:

  • ValueError: If return_as is not one of the supported values.
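refine_peaks_2d fits each Gaussian iteratively (max_iterations, learning_rate, tolerance) inside an internal helper whose code is not shown. For intuition, or as a sanity check, the sketch below swaps in a different, closed-form technique. Because the logarithm of an axis-aligned Gaussian is a quadratic in y and x, a linear least-squares fit to log intensities recovers all five parameters in the documented row order [amplitude, y, x, sigma_x, sigma_y]. It assumes a strictly positive, noise-free box. With real noise the low-intensity pixels dominate the log and bias the result, which is one reason to prefer a direct fit to the intensities. fit_gaussian_2d_logquad is a hypothetical name.

```python
import numpy as np


def fit_gaussian_2d_logquad(box):
    """Fit [amplitude, y, x, sigma_x, sigma_y] to a box of strictly positive values.

    log g = c0 + c1*y + c2*y**2 + c3*x + c4*x**2 for an axis-aligned Gaussian,
    so the coefficients follow from linear least squares on log(box).
    Coordinates are in the box's own pixel frame (row = y, column = x).
    """
    box = np.asarray(box, dtype=float)
    ys, xs = np.indices(box.shape)
    y, x = ys.ravel(), xs.ravel()
    design = np.column_stack([np.ones_like(y, dtype=float), y, y**2, x, x**2])
    coeffs, *_ = np.linalg.lstsq(design, np.log(box.ravel()), rcond=None)
    c0, c1, c2, c3, c4 = coeffs
    # Completing the square: c2 = -1 / (2 sigma_y^2) and c1 = y0 / sigma_y^2.
    y0 = -c1 / (2 * c2)
    x0 = -c3 / (2 * c4)
    sigma_y = np.sqrt(-1 / (2 * c2))
    sigma_x = np.sqrt(-1 / (2 * c4))
    # c0 = log(A) + c1^2 / (4 c2) + c3^2 / (4 c4), solved for A.
    amplitude = np.exp(c0 - c1**2 / (4 * c2) - c3**2 / (4 * c4))
    return np.array([amplitude, y0, x0, sigma_x, sigma_y])
```

The fitted y and x are in the box's own pixel frame; adding the image coordinates of the box's top-left corner maps them back to the image.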

Source code in src/torch_find_peaks/refine_peaks.py
def refine_peaks_2d(
    image: Any,
    peak_coords: Union[torch.Tensor, np.ndarray, pd.DataFrame],
    boxsize: int,
    max_iterations: int = 1000,
    learning_rate: float = 0.01,
    tolerance: float = 1e-6,
    amplitude: Union[torch.Tensor, float] = 1.,
    sigma_x: Union[torch.Tensor, float] = 1.,
    sigma_y: Union[torch.Tensor, float] = 1.,
    return_as: Literal["torch", "numpy", "dataframe"] = "torch",
) -> Union[torch.Tensor, np.ndarray, pd.DataFrame]:
    """
    Refine the positions of peaks in a 2D image by fitting 2D Gaussian functions.

    Parameters
    ----------
    image : Any
        A 2D tensor-like object (e.g., torch.Tensor, numpy.ndarray)
        containing the image data.
    peak_coords : torch.Tensor, np.ndarray, or pd.DataFrame
        A tensor-like object of shape (n, 2) containing the initial peak coordinates (y, x).
        A DataFrame must have "y", "x", and "height" columns; its "height"
        column replaces the amplitude argument as the initial amplitude.
    boxsize : int
        Size of the region to crop around each peak (must be even).
    max_iterations : int, optional
        Maximum number of optimization iterations. Default is 1000.
    learning_rate : float, optional
        Learning rate for the optimizer. Default is 0.01.
    tolerance : float, optional
        Convergence tolerance for the optimization. Default is 1e-6.
    amplitude : Union[torch.Tensor, float], optional
        Initial amplitude of the Gaussian. Default is 1.0.
    sigma_x : Union[torch.Tensor, float], optional
        Initial standard deviation in the x direction. Default is 1.0.
    sigma_y : Union[torch.Tensor, float], optional
        Initial standard deviation in the y direction. Default is 1.0.
    return_as : {"torch", "numpy", "dataframe"}, optional
        The format of the output. Default is "torch".

    Returns
    -------
    torch.Tensor, np.ndarray, or pd.DataFrame
        The fitted parameters for each peak, shape (n, 5).
        Each row contains [amplitude, y, x, sigma_x, sigma_y].
    """
    if not isinstance(image, torch.Tensor):
        image = torch.as_tensor(image)
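    if isinstance(peak_coords, pd.DataFrame):
        # The "height" column from find_peaks_2d seeds the initial amplitudes.
        amplitude = torch.as_tensor(peak_coords["height"].to_numpy(), device=image.device)
        peak_coords = torch.as_tensor(peak_coords[["y", "x"]].to_numpy(), device=image.device)
    if not isinstance(peak_coords, torch.Tensor):
        peak_coords = torch.as_tensor(peak_coords, device=image.device)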

    num_peaks = peak_coords.shape[0]
    if not isinstance(amplitude, torch.Tensor):
        amplitude = torch.tensor([amplitude] * num_peaks, device=image.device)
    if not isinstance(sigma_x, torch.Tensor):
        sigma_x = torch.tensor([sigma_x] * num_peaks, device=image.device)
    if not isinstance(sigma_y, torch.Tensor):
        sigma_y = torch.tensor([sigma_y] * num_peaks, device=image.device)

    initial_peak_data = torch.stack([
        amplitude,
        peak_coords[:, 0],  # y
        peak_coords[:, 1],  # x
        sigma_x,
        sigma_y,
    ], dim=-1)

    refined_peak_data = _refine_peaks_2d_torch(
        image=image,
        peak_data=initial_peak_data,
        boxsize=boxsize,
        max_iterations=max_iterations,
        learning_rate=learning_rate,
        tolerance=tolerance,
    )

    if return_as=="torch":
        return refined_peak_data
    elif return_as=="numpy":
        return refined_peak_data.detach().cpu().numpy()
    elif return_as=="dataframe":
        return pd.DataFrame(refined_peak_data.detach().cpu().numpy(), columns=["amplitude", "y", "x", "sigma_x", "sigma_y"])
    else:
        raise ValueError(f"Invalid return_as value: {return_as}")

refine_peaks_3d(volume, peak_coords, boxsize, max_iterations=1000, learning_rate=0.01, tolerance=1e-06, amplitude=1.0, sigma_x=1.0, sigma_y=1.0, sigma_z=1.0, return_as='torch')

Refine the positions of peaks in a 3D volume by fitting 3D Gaussian functions.

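Parameters:

  • volume (Any, required): A 3D tensor-like object (e.g., torch.Tensor, numpy.ndarray) containing the volume data.
  • peak_coords (torch.Tensor, np.ndarray, or pd.DataFrame, required): Initial peak coordinates of shape (n, 3) in (z, y, x) order. A DataFrame must provide "z", "y", and "x" columns plus a "height" column, which is used as the initial amplitude in place of the amplitude argument (this matches the output of find_peaks_3d(..., return_as="dataframe")).
  • boxsize (int, required): Size of the region to crop around each peak (must be even).
  • max_iterations (int, default 1000): Maximum number of optimization iterations.
  • learning_rate (float, default 0.01): Learning rate for the optimizer.
  • tolerance (float, default 1e-6): Convergence tolerance for the optimization.
  • amplitude (Union[Tensor, float], default 1.0): Initial amplitude of the Gaussian.
  • sigma_x (Union[Tensor, float], default 1.0): Initial standard deviation in the x direction.
  • sigma_y (Union[Tensor, float], default 1.0): Initial standard deviation in the y direction.
  • sigma_z (Union[Tensor, float], default 1.0): Initial standard deviation in the z direction.
  • return_as (str, default 'torch'): The format of the output: "torch", "numpy", "dataframe", or "diagnostic".

Returns:

  • A tensor of shape (n, 7) with the fitted parameters of each peak, one row per peak in the order [amplitude, z, y, x, sigma_x, sigma_y, sigma_z]. Note that the coordinates are in zyx order while the widths are in xyz order. With return_as="numpy" the values are returned as a NumPy array, and with return_as="dataframe" as a pandas.DataFrame with those column names. With return_as="diagnostic" the result is a dict with the keys "refined_peaks" (that DataFrame), "boxes", and "output", the last two being intermediate results of the internal fitting routine.

Raises:

  • ValueError: If return_as is not one of the supported values.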
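The rows returned by refine_peaks_3d store the center in z, y, x order but the widths in x, y, z order. That is easy to mix up when passing results on, for example to Gaussian3D, whose arguments are ordered center_z, center_y, center_x, sigma_z, sigma_y, sigma_x. The hypothetical helper below unpacks the (n, 7) rows by the documented column order and also returns both vectors in a consistent zyx order.

```python
import numpy as np

# Column order documented for refine_peaks_3d output rows.
COLUMNS_3D = ("amplitude", "z", "y", "x", "sigma_x", "sigma_y", "sigma_z")


def unpack_refined_3d(rows):
    """Split an (n, 7) array of fitted parameters into named 1D arrays."""
    rows = np.asarray(rows, dtype=float)
    if rows.ndim != 2 or rows.shape[1] != len(COLUMNS_3D):
        raise ValueError(f"expected shape (n, 7), got {rows.shape}")
    fields = {name: rows[:, i] for i, name in enumerate(COLUMNS_3D)}
    # Centers are stored z, y, x but sigmas x, y, z: reorder both to zyx.
    fields["center_zyx"] = rows[:, [1, 2, 3]]
    fields["sigma_zyx"] = rows[:, [6, 5, 4]]
    return fields
```

For a tensor result, call .detach().cpu().numpy() first, or request return_as="numpy" directly.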

Source code in src/torch_find_peaks/refine_peaks.py
def refine_peaks_3d(
    volume: Any,
    peak_coords: Union[torch.Tensor, np.ndarray, pd.DataFrame],
    boxsize: int,
    max_iterations: int = 1000,
    learning_rate: float = 0.01,
    tolerance: float = 1e-6,
    amplitude: Union[torch.Tensor, float] = 1.,
    sigma_x: Union[torch.Tensor, float] = 1.,
    sigma_y: Union[torch.Tensor, float] = 1.,
    sigma_z: Union[torch.Tensor, float] = 1.,
    return_as: Literal["torch", "numpy", "dataframe", "diagnostic"] = "torch",
) -> Union[torch.Tensor, np.ndarray, pd.DataFrame, dict]:
    """
    Refine the positions of peaks in a 3D volume by fitting 3D Gaussian functions.

    Parameters
    ----------
    volume : Any
        A 3D tensor-like object (e.g., torch.Tensor, numpy.ndarray)
        containing the volume data.
    peak_coords : torch.Tensor, np.ndarray, or pd.DataFrame
        A tensor-like object of shape (n, 3) containing the initial peak coordinates (z, y, x).
        A DataFrame must have "z", "y", "x", and "height" columns; its "height"
        column replaces the amplitude argument as the initial amplitude.
    boxsize : int
        Size of the region to crop around each peak (must be even).
    max_iterations : int, optional
        Maximum number of optimization iterations. Default is 1000.
    learning_rate : float, optional
        Learning rate for the optimizer. Default is 0.01.
    tolerance : float, optional
        Convergence tolerance for the optimization. Default is 1e-6.
    amplitude : Union[torch.Tensor, float], optional
        Initial amplitude of the Gaussian. Default is 1.0.
    sigma_x : Union[torch.Tensor, float], optional
        Initial standard deviation in the x direction. Default is 1.0.
    sigma_y : Union[torch.Tensor, float], optional
        Initial standard deviation in the y direction. Default is 1.0.
    sigma_z : Union[torch.Tensor, float], optional
        Initial standard deviation in the z direction. Default is 1.0.
    return_as : {"torch", "numpy", "dataframe", "diagnostic"}, optional
        The format of the output. Default is "torch".

    Returns
    -------
    torch.Tensor, np.ndarray, pd.DataFrame, or dict
        The fitted parameters for each peak, shape (n, 7). Each row contains
        [amplitude, z, y, x, sigma_x, sigma_y, sigma_z]. With "diagnostic",
        a dict with keys "refined_peaks", "boxes", and "output".
    """
    if not isinstance(volume, torch.Tensor):
        volume = torch.as_tensor(volume)
    if isinstance(peak_coords, pd.DataFrame):
        # The "height" column from find_peaks_3d seeds the initial amplitudes.
        amplitude = torch.as_tensor(peak_coords["height"].to_numpy(), device=volume.device)
        peak_coords = torch.as_tensor(peak_coords[["z", "y", "x"]].to_numpy(), device=volume.device)
    if not isinstance(peak_coords, torch.Tensor):
        peak_coords = torch.as_tensor(peak_coords, device=volume.device)

    num_peaks = peak_coords.shape[0]
    if not isinstance(amplitude, torch.Tensor):
        amplitude = torch.tensor([amplitude] * num_peaks, device=volume.device)
    if not isinstance(sigma_x, torch.Tensor):
        sigma_x = torch.tensor([sigma_x] * num_peaks, device=volume.device)
    if not isinstance(sigma_y, torch.Tensor):
        sigma_y = torch.tensor([sigma_y] * num_peaks, device=volume.device)
    if not isinstance(sigma_z, torch.Tensor):
        sigma_z = torch.tensor([sigma_z] * num_peaks, device=volume.device)

    initial_peak_data = torch.stack([
        amplitude,
        peak_coords[:, 0],  # z
        peak_coords[:, 1],  # y
        peak_coords[:, 2],  # x
        sigma_x,
        sigma_y,
        sigma_z,
    ], dim=-1)

    refined_peak_data, boxes, output = _refine_peaks_3d_torch(
        volume=volume,
        peak_data=initial_peak_data,
        boxsize=boxsize,
        max_iterations=max_iterations,
        learning_rate=learning_rate,
        tolerance=tolerance,
    )

    if return_as == "torch":
        return refined_peak_data
    elif return_as == "numpy":
        return refined_peak_data.detach().cpu().numpy()
    elif return_as == "dataframe":
        return pd.DataFrame(refined_peak_data.detach().cpu().numpy(), columns=["amplitude", "z", "y", "x", "sigma_x", "sigma_y", "sigma_z"])
    elif return_as == "diagnostic":
        # Return the boxes and output for diagnostic purposes
        return {
            "refined_peaks": pd.DataFrame(refined_peak_data.detach().cpu().numpy(), columns=["amplitude", "z", "y", "x", "sigma_x", "sigma_y", "sigma_z"]),
            "boxes": boxes,
            "output": output
        }
    else:
        raise ValueError(f"Invalid return_as value: {return_as}")

torch_find_peaks.gaussians

Gaussian2D

Bases: Module

A 2D Gaussian function.

Parameters:

  • amplitude (torch.Tensor or float, default 1.0): Amplitude of the Gaussian.
  • center_y (torch.Tensor or float, default 0.0): Y-coordinate of the center.
  • center_x (torch.Tensor or float, default 0.0): X-coordinate of the center.
  • sigma_y (torch.Tensor or float, default 1.0): Standard deviation along the y-axis.
  • sigma_x (torch.Tensor or float, default 1.0): Standard deviation along the x-axis.

All parameters must have the same shape. Floats are converted to 0-dimensional tensors, and each parameter is registered as an nn.Parameter.

Methods:

  • forward: Compute the Gaussian values for a given 2D grid. Expects grid in yx order (grid[..., 0] is y, grid[..., 1] is x).
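Gaussian2D.forward expects a coordinate grid in yx order and returns one image per parameter set. The NumPy sketch below mirrors that broadcasting so the shapes can be checked without PyTorch. It builds an (h, w, 2) grid with meshgrid(..., indexing="ij") and evaluates parameters of any common shape P into an output of shape (*P, h, w). make_grid_yx and gaussian2d_values are hypothetical names; the exponent is the one used in forward.

```python
import numpy as np


def make_grid_yx(h, w):
    """Build an (h, w, 2) coordinate grid with grid[..., 0] = y and grid[..., 1] = x."""
    ys, xs = np.meshgrid(np.arange(h, dtype=float), np.arange(w, dtype=float), indexing="ij")
    return np.stack([ys, xs], axis=-1)


def gaussian2d_values(grid, amplitude=1.0, center_y=0.0, center_x=0.0, sigma_y=1.0, sigma_x=1.0):
    """NumPy mirror of Gaussian2D.forward: parameters of shape P give output (*P, h, w)."""
    # Two trailing singleton axes let every parameter broadcast over the (h, w) grid.
    a, cy, cx, sy, sx = (
        np.asarray(p, dtype=float)[..., None, None]
        for p in (amplitude, center_y, center_x, sigma_y, sigma_x)
    )
    gy, gx = grid[..., 0], grid[..., 1]
    return a * np.exp(-((gx - cx) ** 2 / (2 * sx**2) + (gy - cy) ** 2 / (2 * sy**2)))
```

In PyTorch, an equivalent grid can be built with torch.stack(torch.meshgrid(torch.arange(h, dtype=torch.float32), torch.arange(w, dtype=torch.float32), indexing="ij"), dim=-1) and passed to a Gaussian2D instance.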

Source code in src/torch_find_peaks/gaussians.py
class Gaussian2D(nn.Module):
    """
    A 2D Gaussian function.

    Parameters
    ----------
    amplitude : torch.Tensor or float, optional
        Amplitude of the Gaussian. Default is 1.0.
    center_y : torch.Tensor or float, optional
        Y-coordinate of the center. Default is 0.0.
    center_x : torch.Tensor or float, optional
        X-coordinate of the center. Default is 0.0.
    sigma_y : torch.Tensor or float, optional
        Standard deviation along the y-axis. Default is 1.0.
    sigma_x : torch.Tensor or float, optional
        Standard deviation along the x-axis. Default is 1.0.

    Methods
    -------
    forward(grid)
        Compute the Gaussian values for a given 2D grid.
        Expects grid in yx order (grid[..., 0] is y, grid[..., 1] is x).
    """

    def __init__(self,
                 amplitude: Union[torch.Tensor, float] = 1.0,
                 center_y: Union[torch.Tensor, float] = 0.0,
                 center_x: Union[torch.Tensor, float] = 0.0,
                 sigma_y: Union[torch.Tensor, float] = 1.0,
                 sigma_x: Union[torch.Tensor, float] = 1.0
    ):
        super().__init__()
        # Ensure that the parameters are tensors
        if not isinstance(amplitude, torch.Tensor):
            amplitude = torch.tensor(amplitude)
        if not isinstance(center_y, torch.Tensor):
            center_y = torch.tensor(center_y)
        if not isinstance(center_x, torch.Tensor):
            center_x = torch.tensor(center_x)
        if not isinstance(sigma_y, torch.Tensor):
            sigma_y = torch.tensor(sigma_y)
        if not isinstance(sigma_x, torch.Tensor):
            sigma_x = torch.tensor(sigma_x)
        # Check if all parameters are of the same shape
        assert amplitude.shape == center_y.shape == center_x.shape == sigma_y.shape == sigma_x.shape, \
            "All parameters must have the same shape."

        self.amplitude = nn.Parameter(amplitude)
        self.center_y = nn.Parameter(center_y)
        self.center_x = nn.Parameter(center_x)
        self.sigma_y = nn.Parameter(sigma_y)
        self.sigma_x = nn.Parameter(sigma_x)

    def forward(self, grid):
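        """
        Evaluate the Gaussian(s) on a 2D coordinate grid.

        Parameters
        ----------
        grid : torch.Tensor
            Tensor of shape (h, w, 2) containing 2D coordinates in yx order.

        Returns
        -------
        torch.Tensor
            Gaussian values of shape (*P, h, w), where P is the parameter shape.
        """
        # Append one singleton axis per parameter dimension so the grid
        # broadcasts against parameters of shape P.
        grid_x = einops.rearrange(grid[..., 1], 'h w -> h w' + ' 1'*self.amplitude.dim())
        grid_y = einops.rearrange(grid[..., 0], 'h w -> h w' + ' 1'*self.amplitude.dim())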

        amplitude = einops.rearrange(self.amplitude, '... -> 1 1 ...')
        center_x = einops.rearrange(self.center_x, '... -> 1 1 ...')
        center_y = einops.rearrange(self.center_y, '... -> 1 1 ...')
        sigma_x = einops.rearrange(self.sigma_x, '... -> 1 1 ...')
        sigma_y = einops.rearrange(self.sigma_y, '... -> 1 1 ...')

        gaussian = amplitude * torch.exp(
            -((grid_x - center_x) ** 2 / (2 * sigma_x ** 2) +
              (grid_y - center_y) ** 2 / (2 * sigma_y ** 2))
        )

        return einops.rearrange(gaussian, 'h w ... -> ... h w')

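forward(grid)

Evaluate the Gaussian(s) on a 2D coordinate grid.

Parameters:

  • grid (Tensor): Tensor of shape (h, w, 2) containing 2D coordinates in yx order.

Returns:

  • Tensor of Gaussian values with shape (*P, h, w), where P is the shape of the parameters (one image per parameter set).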

Gaussian3D

Bases: Module

A 3D Gaussian function.

Parameters:

  • amplitude (torch.Tensor or float, default 1.0): Amplitude of the Gaussian.
  • center_z (torch.Tensor or float, default 0.0): Z-coordinate of the center.
  • center_y (torch.Tensor or float, default 0.0): Y-coordinate of the center.
  • center_x (torch.Tensor or float, default 0.0): X-coordinate of the center.
  • sigma_z (torch.Tensor or float, default 1.0): Standard deviation along the z-axis.
  • sigma_y (torch.Tensor or float, default 1.0): Standard deviation along the y-axis.
  • sigma_x (torch.Tensor or float, default 1.0): Standard deviation along the x-axis.

All parameters must have the same shape. Floats are converted to 0-dimensional tensors, and each parameter is registered as an nn.Parameter.

Methods:

  • forward: Compute the Gaussian values for a given 3D grid. Expects grid in zyx order (grid[..., 0] is z, grid[..., 1] is y, grid[..., 2] is x).
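For Gaussian3D the grid axes must follow zyx order. This NumPy sketch builds a (d, h, w, 3) grid and evaluates an axis-aligned 3D Gaussian with the same exponent as forward. make_grid_zyx and gaussian3d_values are hypothetical helpers, and unlike Gaussian3D this version handles a single parameter set only.

```python
import numpy as np


def make_grid_zyx(d, h, w):
    """Build a (d, h, w, 3) grid with grid[..., 0] = z, grid[..., 1] = y, grid[..., 2] = x."""
    axes = [np.arange(n, dtype=float) for n in (d, h, w)]
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


def gaussian3d_values(grid, amplitude, center_zyx, sigma_zyx):
    """Axis-aligned 3D Gaussian evaluated on a zyx grid (a NumPy sketch)."""
    # Broadcasting the length-3 center and sigma over the last grid axis keeps
    # the z, y, x components aligned with grid[..., 0], grid[..., 1], grid[..., 2].
    scaled = (grid - np.asarray(center_zyx, dtype=float)) / np.asarray(sigma_zyx, dtype=float)
    return amplitude * np.exp(-0.5 * np.sum(scaled**2, axis=-1))
```

Because the exponent is a sum of per-axis terms, the result equals the amplitude times the outer product of three 1D Gaussian profiles along z, y, and x. Comparing against that product is a quick correctness check for any implementation, including Gaussian3D.forward.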

Source code in src/torch_find_peaks/gaussians.py
class Gaussian3D(nn.Module):
    """
    A 3D Gaussian function.

    Parameters
    ----------
    amplitude : torch.Tensor or float, optional
        Amplitude of the Gaussian. Default is 1.0.
    center_z : torch.Tensor or float, optional
        Z-coordinate of the center. Default is 0.0.
    center_y : torch.Tensor or float, optional
        Y-coordinate of the center. Default is 0.0.
    center_x : torch.Tensor or float, optional
        X-coordinate of the center. Default is 0.0.
    sigma_z : torch.Tensor or float, optional
        Standard deviation along the z-axis. Default is 1.0.
    sigma_y : torch.Tensor or float, optional
        Standard deviation along the y-axis. Default is 1.0.
    sigma_x : torch.Tensor or float, optional
        Standard deviation along the x-axis. Default is 1.0.

    Methods
    -------
    forward(grid)
        Compute the Gaussian values for a given 3D grid.
        Expects grid in zyx order (grid[..., 0] is z, grid[..., 1] is y, grid[..., 2] is x).
    """

    def __init__(self,
                 amplitude: Union[torch.Tensor, float] = 1.0,
                 center_z: Union[torch.Tensor, float] = 0.0,
                 center_y: Union[torch.Tensor, float] = 0.0,
                 center_x: Union[torch.Tensor, float] = 0.0,
                 sigma_z: Union[torch.Tensor, float] = 1.0,
                 sigma_y: Union[torch.Tensor, float] = 1.0,
                 sigma_x: Union[torch.Tensor, float] = 1.0
    ):
        super().__init__()
        # Ensure that the parameters are tensors
        if not isinstance(amplitude, torch.Tensor):
            amplitude = torch.tensor(amplitude)
        if not isinstance(center_z, torch.Tensor):
            center_z = torch.tensor(center_z)
        if not isinstance(center_y, torch.Tensor):
            center_y = torch.tensor(center_y)
        if not isinstance(center_x, torch.Tensor):
            center_x = torch.tensor(center_x)
        if not isinstance(sigma_z, torch.Tensor):
            sigma_z = torch.tensor(sigma_z)
        if not isinstance(sigma_y, torch.Tensor):
            sigma_y = torch.tensor(sigma_y)
        if not isinstance(sigma_x, torch.Tensor):
            sigma_x = torch.tensor(sigma_x)
        # Check if all parameters are of the same shape
        assert amplitude.shape == center_z.shape == center_y.shape == center_x.shape == sigma_z.shape == sigma_y.shape == sigma_x.shape, \
            "All parameters must have the same shape."

        self.amplitude = nn.Parameter(amplitude)
        self.center_z = nn.Parameter(center_z)
        self.center_y = nn.Parameter(center_y)
        self.center_x = nn.Parameter(center_x)
        self.sigma_z = nn.Parameter(sigma_z)
        self.sigma_y = nn.Parameter(sigma_y)
        self.sigma_x = nn.Parameter(sigma_x)

    def forward(self, grid):
        """
        Evaluate the Gaussian(s) on a 3D coordinate grid.

        Parameters
        ----------
        grid : torch.Tensor
            Tensor of shape (d, h, w, 3) containing 3D coordinates in zyx order
            (grid[..., 0] is z, grid[..., 1] is y, grid[..., 2] is x).

        Returns
        -------
        torch.Tensor
            Gaussian values of shape (*P, d, h, w), where P is the parameter shape.
        """
        # Append one singleton axis per parameter dimension so the grid
        # broadcasts against parameters of shape P.
        grid_x = einops.rearrange(grid[..., 2], 'd h w -> d h w' + ' 1'*self.amplitude.dim())
        grid_y = einops.rearrange(grid[..., 1], 'd h w -> d h w' + ' 1'*self.amplitude.dim())
        grid_z = einops.rearrange(grid[..., 0], 'd h w -> d h w' + ' 1'*self.amplitude.dim())

        amplitude = einops.rearrange(self.amplitude, '... -> 1 1 1 ...')
        center_x = einops.rearrange(self.center_x, '... -> 1 1 1 ...')
        center_y = einops.rearrange(self.center_y, '... -> 1 1 1 ...')
        center_z = einops.rearrange(self.center_z, '... -> 1 1 1 ...')
        sigma_x = einops.rearrange(self.sigma_x, '... -> 1 1 1 ...')
        sigma_y = einops.rearrange(self.sigma_y, '... -> 1 1 1 ...')
        sigma_z = einops.rearrange(self.sigma_z, '... -> 1 1 1 ...')


        gaussian = amplitude * torch.exp(
            -((grid_x - center_x) ** 2 / (2 * sigma_x ** 2) +
              (grid_y - center_y) ** 2 / (2 * sigma_y ** 2) +
              (grid_z - center_z) ** 2 / (2 * sigma_z ** 2))
        )

        return einops.rearrange(gaussian, 'd h w ... -> ... d h w')

forward(grid)

Evaluate the batch of 3D Gaussians on a coordinate grid.

Args:
    grid: Tensor of shape (d, h, w, 3) containing 3D coordinates in zyx order (grid[..., 0] is z, grid[..., 1] is y, grid[..., 2] is x).

Returns:

Type    Description
Tensor  Gaussian values of shape (*batch, d, h, w), where batch is the shape shared by all parameters.
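To make the zyx grid convention and the output layout concrete, the sketch below reproduces the computation in Gaussian3D.forward with plain torch broadcasting instead of einops. The helper gaussian_3d and its (N, 3) layout for centers and sigmas are hypothetical conveniences for illustration, not part of the library's API.

```python
import torch


def gaussian_3d(grid, amplitude, center, sigma):
    """Evaluate N axis-aligned 3D Gaussians on a (d, h, w, 3) zyx grid.

    Hypothetical helper: amplitude has shape (N,), center and sigma have
    shape (N, 3) in zyx order. Returns a tensor of shape (N, d, h, w).
    """
    # (1, d, h, w, 3) - (N, 1, 1, 1, 3) -> per-Gaussian offsets (N, d, h, w, 3)
    offsets = grid.unsqueeze(0) - center[:, None, None, None, :]
    sigma = sigma[:, None, None, None, :]
    # Sum the scaled squared offsets over the z, y and x components
    exponent = (offsets ** 2 / (2 * sigma ** 2)).sum(dim=-1)
    return amplitude[:, None, None, None] * torch.exp(-exponent)


# Build a zyx coordinate grid: grid[..., 0] is z, grid[..., 1] is y, grid[..., 2] is x
d, h, w = 8, 10, 12
zz, yy, xx = torch.meshgrid(
    torch.arange(d, dtype=torch.float32),
    torch.arange(h, dtype=torch.float32),
    torch.arange(w, dtype=torch.float32),
    indexing="ij",
)
grid = torch.stack([zz, yy, xx], dim=-1)  # shape (d, h, w, 3)

amplitude = torch.tensor([1.0, 2.0])
center = torch.tensor([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]])  # (z, y, x) per Gaussian
sigma = torch.tensor([[1.0, 1.5, 2.0], [2.0, 2.0, 2.0]])

values = gaussian_3d(grid, amplitude, center, sigma)
print(values.shape)  # torch.Size([2, 8, 10, 12])
```

The batch axis comes first and the spatial axes last, matching the (*batch, d, h, w) layout that the einops call 'd h w ... -> ... d h w' produces in the class.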

WarpedGaussian2D

Bases: Module

A 2D warped Gaussian function.

Parameters:

Name        Type    Description                                                      Default
amplitude   Tensor  Amplitude of the Gaussian.                                       tensor([1.0])
center_y    Tensor  Y-coordinate of the center.                                      tensor([0.0])
center_x    Tensor  X-coordinate of the center.                                      tensor([0.0])
sigma_y     Tensor  Standard deviation along the y-axis.                             tensor([1.0])
sigma_x     Tensor  Standard deviation along the x-axis.                             tensor([1.0])
warp        Tensor  Curvature of the warp. With warp_angle = 0, the ridge follows    tensor([1.0])
                    x - center_x = warp * (y - center_y)**2.
warp_angle  Tensor  Angle of the warp, in radians.                                   tensor([0.0])

All parameters must have the same shape.

Methods:

Name     Description
forward  Compute the warped Gaussian values for a given 2D grid. Expects grid in yx order
         (grid[..., 0] is y, grid[..., 1] is x).
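Reading the forward pass in the class source, the warped Gaussian can be written as follows, with A = amplitude, (c_x, c_y) = (center_x, center_y), σ_x and σ_y = sigma_x and sigma_y, k = warp, and θ = warp_angle. Setting k = 0 gives an anisotropic Gaussian rotated by θ; for k ≠ 0 the ridge bends along the parabola u = k v².

```latex
\begin{aligned}
u &= (x - c_x)\cos\theta - (y - c_y)\sin\theta, \\
v &= (x - c_x)\sin\theta + (y - c_y)\cos\theta, \\
G(x, y) &= A \exp\!\left(-\left[\frac{(u - k\,v^2)^2}{2\sigma_x^2} + \frac{v^2}{2\sigma_y^2}\right]\right).
\end{aligned}
```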

Source code in src/torch_find_peaks/gaussians.py
class WarpedGaussian2D(nn.Module):
    """
    A 2D warped Gaussian function.

    Parameters
    ----------
    amplitude : torch.Tensor, optional
        Amplitude of the Gaussian. Default is torch.tensor([1.0]).
    center_y : torch.Tensor, optional
        Y-coordinate of the center. Default is torch.tensor([0.0]).
    center_x : torch.Tensor, optional
        X-coordinate of the center. Default is torch.tensor([0.0]).
    sigma_y : torch.Tensor, optional
        Standard deviation along the y-axis. Default is torch.tensor([1.0]).
    sigma_x : torch.Tensor, optional
        Standard deviation along the x-axis. Default is torch.tensor([1.0]).
    warp : torch.Tensor, optional
        Curvature of the warp. With warp_angle = 0, the ridge follows
        x - center_x = warp * (y - center_y)**2. Default is torch.tensor([1.0]).
    warp_angle : torch.Tensor, optional
        Angle of the warp in radians. Default is torch.tensor([0.0]).

    Methods
    -------
    forward(grid)
        Compute the warped Gaussian values for a given 2D grid.
        Expects grid in yx order (grid[..., 0] is y, grid[..., 1] is x).
    """

    def __init__(self,
                 amplitude: torch.Tensor = torch.tensor([1.0]),
                 center_y: torch.Tensor = torch.tensor([0.0]),
                 center_x: torch.Tensor = torch.tensor([0.0]),
                 sigma_y: torch.Tensor = torch.tensor([1.0]),
                 sigma_x: torch.Tensor = torch.tensor([1.0]),
                 warp: torch.Tensor = torch.tensor([1.0]),
                 warp_angle: torch.Tensor = torch.tensor([0.0])
    ):
        super().__init__()
        # Ensure that the parameters are tensors
        if not isinstance(amplitude, torch.Tensor):
            amplitude = torch.tensor(amplitude)
        if not isinstance(center_y, torch.Tensor):
            center_y = torch.tensor(center_y)
        if not isinstance(center_x, torch.Tensor):
            center_x = torch.tensor(center_x)
        if not isinstance(sigma_y, torch.Tensor):
            sigma_y = torch.tensor(sigma_y)
        if not isinstance(sigma_x, torch.Tensor):
            sigma_x = torch.tensor(sigma_x)
        if not isinstance(warp, torch.Tensor):
            warp = torch.tensor(warp)
        if not isinstance(warp_angle, torch.Tensor):
            warp_angle = torch.tensor(warp_angle)
        # Check if all parameters are of the same shape
        assert amplitude.shape == center_y.shape == center_x.shape == sigma_y.shape == sigma_x.shape == warp.shape == warp_angle.shape, \
            "All parameters must have the same shape."

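        # Clone so each instance owns its storage: the tensor defaults above are
        # created once and nn.Parameter does not copy the tensor it wraps
        self.amplitude = nn.Parameter(amplitude.clone())
        self.center_y = nn.Parameter(center_y.clone())
        self.center_x = nn.Parameter(center_x.clone())
        self.sigma_y = nn.Parameter(sigma_y.clone())
        self.sigma_x = nn.Parameter(sigma_x.clone())
        self.warp = nn.Parameter(warp.clone())
        self.warp_angle = nn.Parameter(warp_angle.clone())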

    def forward(self, grid):
        """
        Evaluate the batch of 2D warped Gaussians on a coordinate grid.

        Args:
            grid: Tensor of shape (h, w, 2) containing 2D coordinates in yx order
                 (grid[..., 0] is y, grid[..., 1] is x).

        Returns:
            Tensor of shape (*batch, h, w) holding the warped Gaussian values,
            where batch is the shape shared by all parameters.
        """
        amplitude = einops.rearrange(self.amplitude, '... -> 1 1 ...')
        center_x = einops.rearrange(self.center_x, '... -> 1 1 ...')
        center_y = einops.rearrange(self.center_y, '... -> 1 1 ...')
        sigma_x = einops.rearrange(self.sigma_x, '... -> 1 1 ...')
        sigma_y = einops.rearrange(self.sigma_y, '... -> 1 1 ...')
        warp = einops.rearrange(self.warp, '... -> 1 1 ...')
        warp_angle = einops.rearrange(self.warp_angle, '... -> 1 1 ...')

        # Append one singleton axis per parameter dimension (as Gaussian3D does)
        # so the grid broadcasts against parameter tensors of any rank
        grid_x = einops.rearrange(grid[..., 1], 'h w -> h w' + ' 1' * self.amplitude.dim())
        grid_y = einops.rearrange(grid[..., 0], 'h w -> h w' + ' 1' * self.amplitude.dim())

        # Rotate the offsets from the center by warp_angle into the (u, v) frame
        u = (grid_x - center_x) * torch.cos(warp_angle) - (grid_y - center_y) * torch.sin(warp_angle)
        v = (grid_x - center_x) * torch.sin(warp_angle) + (grid_y - center_y) * torch.cos(warp_angle)

        # Shifting u by warp * v**2 bends the ridge into a parabola
        warped_gaussian = amplitude * torch.exp(
            -((u - warp * v ** 2) ** 2 / (2 * sigma_x ** 2) +
              v ** 2 / (2 * sigma_y ** 2))
        )

        return einops.rearrange(warped_gaussian, 'h w ... -> ... h w')
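A PyTorch caveat applies to constructors such as WarpedGaussian2D.__init__, whose default arguments are tensor objects: Python evaluates a default once, when the function is defined, and nn.Parameter wraps the tensor it is given without copying it. The minimal sketch below uses hypothetical make_param helpers rather than the library's classes. It shows how an in-place update to one parameter built from a default (as an optimizer step would perform) leaks into the next one, and how .clone() avoids it.

```python
import torch
from torch import nn


def make_param(value: torch.Tensor = torch.tensor([1.0])) -> nn.Parameter:
    # Hypothetical constructor: the default tensor is created once and is
    # wrapped by nn.Parameter without a copy, so every call shares its storage
    return nn.Parameter(value)


def make_param_cloned(value: torch.Tensor = torch.tensor([1.0])) -> nn.Parameter:
    # Cloning first gives every returned parameter its own storage
    return nn.Parameter(value.clone())


p = make_param()
with torch.no_grad():
    p.add_(1.0)  # in-place update, like an optimizer step
print(make_param().item())  # 2.0: the shared default tensor was modified

q = make_param_cloned()
with torch.no_grad():
    q.add_(1.0)
print(make_param_cloned().item())  # 1.0: the default is untouched
```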

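forward(grid)

Evaluate the batch of 2D warped Gaussians on a coordinate grid.

Args:
    grid: Tensor of shape (h, w, 2) containing 2D coordinates in yx order (grid[..., 0] is y, grid[..., 1] is x).

Returns:

Type    Description
Tensor  Warped Gaussian values. For parameters of shape (N,), the result has shape (N, h, w).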
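As a sketch of the yx grid convention and of what the warp parameter does, the hypothetical helper below mirrors the forward pass with plain torch broadcasting (no einops) for parameters of shape (N,). It evaluates an unwarped and a warped Gaussian on the same grid and checks where the ridge lands: with warp_angle = 0, the brightest pixel in each row y sits at x = center_x + warp * (y - center_y)**2.

```python
import torch


def warped_gaussian_2d(grid, amplitude, center_y, center_x,
                       sigma_y, sigma_x, warp, warp_angle):
    """Hypothetical plain-torch mirror of WarpedGaussian2D.forward.

    grid has shape (h, w, 2) in yx order and every parameter has shape (N,).
    Returns a tensor of shape (N, h, w).
    """
    # (h, w, 1) coordinates broadcast against (N,) parameters -> (h, w, N)
    y = grid[..., 0].unsqueeze(-1)
    x = grid[..., 1].unsqueeze(-1)
    dx, dy = x - center_x, y - center_y
    # Rotate the offsets from the center into the (u, v) frame
    u = dx * torch.cos(warp_angle) - dy * torch.sin(warp_angle)
    v = dx * torch.sin(warp_angle) + dy * torch.cos(warp_angle)
    # Shifting u by warp * v**2 bends the ridge into a parabola
    g = amplitude * torch.exp(
        -((u - warp * v ** 2) ** 2 / (2 * sigma_x ** 2) + v ** 2 / (2 * sigma_y ** 2))
    )
    return g.permute(2, 0, 1)  # (h, w, N) -> (N, h, w)


h, w = 32, 32
yy, xx = torch.meshgrid(
    torch.arange(h, dtype=torch.float32),
    torch.arange(w, dtype=torch.float32),
    indexing="ij",
)
grid = torch.stack([yy, xx], dim=-1)  # grid[..., 0] is y, grid[..., 1] is x

# Two Gaussians sharing a center: one unwarped, one with warp = 0.25 (warp_angle = 0)
params = dict(
    amplitude=torch.tensor([1.0, 1.0]),
    center_y=torch.tensor([10.0, 10.0]),
    center_x=torch.tensor([10.0, 10.0]),
    sigma_y=torch.tensor([4.0, 4.0]),
    sigma_x=torch.tensor([1.0, 1.0]),
    warp=torch.tensor([0.0, 0.25]),
    warp_angle=torch.tensor([0.0, 0.0]),
)
values = warped_gaussian_2d(grid, **params)
print(values.shape)  # torch.Size([2, 32, 32])

print(values[0, 14].argmax().item())  # 10: no warp, so the ridge is straight
print(values[1, 14].argmax().item())  # 14 = 10 + 0.25 * (14 - 10)**2
```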