# Copyright © 2025 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE
from typing import Union, Literal, Callable, Optional, Sequence, TYPE_CHECKING
import math
import gc
import torch
from torch import Tensor
from torchvision.transforms import GaussianBlur
import numpy as np
from numpy import ndarray
import ptychi.maths as pmath
import ptychi.propagate as propagate
from ptychi.timing.timer_utils import timer
from ptychi.device import AcceleratorModuleWrapper
if TYPE_CHECKING:
from ptychi.api.task import PtychographyTask
# Module-level default complex dtype. It is replaced by
# `set_default_complex_dtype` and read by `get_default_complex_dtype`.
_default_complex_dtype = torch.complex64
[docs]
def get_suggested_object_size(positions_px, probe_shape, extra=0):
    """
    Suggest an object size from the extent of the scan positions.

    For each of the first two columns of `positions_px`, the size is the
    rounded-up span (max - min) of that column plus the matching element of
    `probe_shape` plus `extra`.

    Parameters
    ----------
    positions_px : ndarray
        A 2D array of scan positions in pixels; columns 0 and 1 are used.
    probe_shape : Sequence[int]
        The probe size; elements 0 and 1 are added to the respective spans.
    extra : int, optional
        Extra pixels added to both dimensions.

    Returns
    -------
    tuple[int, int]
        The suggested size, ordered like the columns of `positions_px`.
    """
    suggested = []
    for axis in (0, 1):
        coords = positions_px[:, axis]
        extent = np.ceil(coords.max() - coords.min())
        suggested.append(int(extent + probe_shape[axis] + extra))
    return tuple(suggested)
[docs]
def rescale_probe(
    probe: Union[ndarray, Tensor],
    patterns: Union[ndarray, Tensor],
    weights: Optional[Union[ndarray, Tensor]] = None,
) -> Union[ndarray, Tensor]:
    """
    Scale probe so that the sum of intensity matches that of the diffraction patterns.

    The probe is propagated with `propagate.FourierPropagator`, and the summed
    intensity of the propagated probe is compared with the summed intensity of
    the mean diffraction pattern. The probe is then multiplied by the square
    root of that ratio.

    Parameters
    ----------
    probe : ndarray or Tensor
        A (n_modes, h, w) or (n_opr_modes, n_modes, h, w) array of the probe.
    patterns : ndarray or Tensor
        A (n, h, w) array of diffraction patterns.
    weights : ndarray or Tensor, optional
        A (n_points, n_opr_modes) array of weights for each OPR mode. It is
        only used when the probe is 4D with more than one OPR mode; the OPR
        modes are then combined using the weights averaged over scan points.
        Otherwise only the first OPR mode is used.

    Returns
    -------
    scaled_probe : ndarray or Tensor
        The input probe multiplied by the scaling factor.

    Raises
    ------
    ValueError
        If the probe is neither 3D nor 4D.
    """
    propagator = propagate.FourierPropagator()
    # `as_tensor` avoids the copy (and the copy-construct warning) that
    # `torch.tensor` triggers when `probe` is already a tensor.
    probe_tensor = torch.as_tensor(probe)
    if probe_tensor.ndim == 4 and (probe_tensor.shape[0] == 1 or weights is None):
        probe_tensor = probe_tensor[0]

    if probe_tensor.ndim == 4:
        # Collapse the OPR modes into one effective probe using mean weights.
        # The weights are placed on the probe's device so the product is valid.
        mean_weights = torch.as_tensor(weights, device=probe_tensor.device).mean(dim=0)
        probe_tensor = (probe_tensor * mean_weights[:, None, None, None]).sum(0)
    elif probe_tensor.ndim != 3:
        raise ValueError(
            "probe must be a (n_modes, h, w) or (n_opr_modes, n_modes, h, w) array, "
            f"but got {probe_tensor.ndim} dimensions."
        )

    i_probe = (
        (torch.abs(propagator.propagate_forward(probe_tensor)) ** 2)
        .sum()
        .detach()
        .cpu()
        .numpy()
    )
    patterns = to_numpy(patterns)
    # Total intensity of the average diffraction pattern.
    i_data = np.sum(np.mean(patterns, axis=0))
    factor = i_data / i_probe
    return probe * np.sqrt(factor)
[docs]
def orthogonalize_initial_probe(
    probe: Tensor,
    secondary_mode_energy: float = 0.02,
    method: Literal["hermite"] = "hermite"
) -> Tensor:
    """
    Orthogonalize initial probe.

    Parameters
    ----------
    probe : Tensor
        A (n_opr_modes, n_modes, h, w) tensor of the probe. This function only generates
        incoherent modes; OPR modes are kept as they are. Only the first incoherent mode
        of the input probe is used. As such, the rest of the incoherent modes can be
        arbitrarily initialized, but the shape of the input probe should indicate the
        number of incoherent modes intended to be generated. The tensor is modified
        in place.
    secondary_mode_energy : float, optional
        The fraction of the total energy (the energy of the first incoherent mode of
        the input) assigned to each secondary mode. The principal mode receives the
        remaining fraction.
    method: Literal["hermite"], optional
        The method to use for orthogonalization.

    Returns
    -------
    Tensor
        The orthogonalized probe (the same tensor object as the input).

    Raises
    ------
    ValueError
        If `method` is not a known orthogonalization method.
    """
    if method != "hermite":
        raise ValueError(f"Unknown orthogonalization method: {method}")

    n_modes = probe.shape[1]

    # Allocate on the probe's device so that the arithmetic below does not mix
    # devices when the probe does not live on the default device.
    mode_energies = torch.zeros(n_modes, device=probe.device)
    mode_energies[1:] = secondary_mode_energy
    mode_energies[0] = 1.0 - mode_energies.sum()
    e_total = torch.sum(torch.abs(probe[0, 0]) ** 2)
    mode_energies = mode_energies * e_total

    # Smallest (m + 1) x (n + 1) grid of polynomial orders that yields at
    # least n_modes modes.
    m = math.ceil(math.sqrt(n_modes)) - 1
    n = math.ceil(n_modes / (m + 1)) - 1
    h = generate_secondary_probe_modes_hermite(probe[0, 0], m, n)
    probe[0, 1:, :, :] = h[1:n_modes, :, :]

    # Normalize each incoherent mode to its target energy.
    for i_mode in range(n_modes):
        current_energy = torch.sum(torch.abs(probe[0, i_mode] ** 2))
        probe[0, i_mode] = probe[0, i_mode] * torch.sqrt(mode_energies[i_mode] / current_energy)
    return probe
[docs]
@timer()
def generate_secondary_probe_modes_hermite(probe: Tensor, m: int, n: int) -> Tensor:
    """
    Generate secondary probe modes using Hermite-like polynomials.

    Each mode is the primary probe multiplied by a monomial in the centered
    coordinates. All modes but the first are also multiplied by a Gaussian
    envelope whose widths are the intensity-weighted variances of the probe.
    Every mode is then orthogonalized against the previously generated modes
    and normalized to unit energy.

    Parameters
    ----------
    probe : Tensor
        A (h, w) tensor of the primary mode of the probe.
    m, n : int
        The highest polynomial orders along x (last axis) and y
        (second-to-last axis), respectively.

    Returns
    -------
    Tensor
        A ((m + 1) * (n + 1), h, w) tensor of the probe modes, on the same
        device as `probe`. The first entry is the normalized primary mode.
    """
    # Build everything on the probe's device so that the function also works
    # when the probe does not live on the default device.
    device = probe.device
    x = torch.arange(probe.shape[-1], device=device) - probe.shape[-1] / 2 + 1
    y = torch.arange(probe.shape[-2], device=device) - probe.shape[-2] / 2 + 1
    xx, yy = torch.meshgrid(x, y, indexing="xy")

    # Intensity-weighted centroid and variance of the probe.
    intensity = torch.abs(probe) ** 2
    total_intensity = torch.sum(intensity)
    cenx = torch.sum(xx * intensity) / total_intensity
    ceny = torch.sum(yy * intensity) / total_intensity
    dx = xx - cenx
    dy = yy - ceny
    varx = torch.sum(dx ** 2 * intensity) / total_intensity
    vary = torch.sum(dy ** 2 * intensity) / total_intensity

    # The Gaussian envelope does not depend on the polynomial order.
    envelope = torch.exp(-(dx ** 2 / (2 * varx)) - (dy ** 2 / (2 * vary)))

    h = torch.empty([(m + 1) * (n + 1), *probe.shape], dtype=probe.dtype, device=device)
    counter = 0
    for nii in range(n + 1):
        for mii in range(m + 1):
            auxfunc = (dx ** mii) * (dy ** nii) * probe
            if counter > 0:
                auxfunc = auxfunc * envelope
            auxfunc = auxfunc / torch.sqrt(torch.sum(torch.abs(auxfunc) ** 2))

            # Orthogonalize the current mode to the previous ones.
            # NOTE(review): the coefficient is sum(h * conj(aux)), i.e. the
            # conjugate of the usual <h, aux>. The two agree when the inner
            # product is real; confirm this is intended.
            for ii in range(counter):
                auxfunc = auxfunc - h[ii] * torch.sum(h[ii] * auxfunc.conj(), dim=(-1, -2))

            auxfunc = auxfunc / torch.sqrt(torch.sum(torch.abs(auxfunc) ** 2))
            h[counter] = auxfunc
            counter += 1
    return h
[docs]
@timer()
def get_probe_renormalization_factor(patterns: Tensor | ndarray) -> float:
    """
    Compute the probe renormalization factor from the diffraction patterns.

    The factor is `sqrt(1 / max_power)`, where `max_power` is the largest
    per-pattern sum over the last two axes divided by the number of pixels in
    one pattern.

    Parameters
    ----------
    patterns : Tensor | ndarray
        A (n, h, w) buffer of diffraction patterns. Tensors are detached and
        converted to NumPy arrays on the CPU first.

    Returns
    -------
    float
        The renormalization factor.
    """
    if isinstance(patterns, Tensor):
        patterns = patterns.detach().cpu().numpy()
    pixels_per_pattern = patterns[0].size
    pattern_sums = np.sum(patterns, axis=(1, 2))
    max_power = np.max(pattern_sums) / pixels_per_pattern
    return np.sqrt(1 / max_power)
[docs]
def generate_initial_object(shape: tuple[int, ...], method: Literal["random"] = "random") -> Tensor:
    """
    Generate an initial guess of the object.

    With the "random" method, the magnitude and the phase are each drawn with
    `generate_gaussian_random_image`: the magnitude with loc=0.98, sigma=0.02,
    clamped to [0, 1]; the phase with loc=0.0, sigma=0.02, clamped to
    [-pi, pi].

    Parameters
    ----------
    shape : tuple[int, ...]
        The shape of the object to generate.
    method : Literal["random"], optional
        The initialization method.

    Returns
    -------
    Tensor
        The complex object, cast to the default complex dtype.

    Raises
    ------
    ValueError
        If `method` is unknown.
    """
    if method != "random":
        raise ValueError(f"Unknown object initialization method: {method}")

    # The magnitude is drawn before the phase; keep this order so that results
    # stay reproducible under a fixed random seed.
    magnitude = generate_gaussian_random_image(shape, loc=0.98, sigma=0.02, smoothing=3.0)
    magnitude = magnitude.clamp(0.0, 1.0)
    phase = generate_gaussian_random_image(shape, loc=0.0, sigma=0.02, smoothing=3.0)
    phase = phase.clamp(-torch.pi, torch.pi)

    complex_object = magnitude * torch.exp(1j * phase)
    return complex_object.type(get_default_complex_dtype())
[docs]
@timer()
def add_additional_opr_probe_modes_to_probe(
    probe: Tensor, n_opr_modes_to_add: int, normalize: bool = True
) -> Tensor:
    """
    Append randomly initialized OPR modes to a probe.

    The new modes have independent standard-normal real and imaginary parts
    and are concatenated after the existing OPR modes along the first axis.

    Parameters
    ----------
    probe : Tensor
        A (n_opr_modes, n_modes, h, w) tensor of the probe.
    n_opr_modes_to_add : int
        The number of OPR modes to add.
    normalize : bool, optional
        If True, every OPR mode except the first one (existing and newly
        added alike) is divided by its `pmath.mnorm` taken over the last two
        axes.

    Returns
    -------
    Tensor
        A (n_opr_modes + n_opr_modes_to_add, n_modes, h, w) tensor of the probe
        with additional OPR modes.

    Raises
    ------
    ValueError
        If `probe` is not 4D.
    """
    if probe.ndim != 4:
        raise ValueError("probe must be a (n_opr_modes, n_modes, h, w) tensor.")

    n_modes = probe.shape[1]
    spatial_shape = probe.shape[-2:]
    new_modes = torch.empty(
        [n_opr_modes_to_add, n_modes, probe.shape[-2], probe.shape[-1]],
        dtype=get_default_complex_dtype(),
    )

    # Walk the (OPR mode, incoherent mode) pairs in row-major order, drawing
    # the real part before the imaginary part for each pair.
    for flat_index in range(n_opr_modes_to_add * n_modes):
        i_opr, i_mode = divmod(flat_index, n_modes)
        real_part = generate_gaussian_random_image(spatial_shape, loc=0, sigma=1, smoothing=0)
        imag_part = generate_gaussian_random_image(spatial_shape, loc=0, sigma=1, smoothing=0)
        new_modes[i_opr, i_mode, ...] = real_part + 1j * imag_part

    probe = torch.cat([probe, new_modes], dim=0)
    if not normalize:
        return probe

    pnorm = pmath.mnorm(probe, dim=(-2, -1), keepdims=True)
    probe[1:] = probe[1:] / pnorm[1:, :]
    return probe
[docs]
@timer()
def generate_initial_opr_mode_weights(
    n_points: int, n_opr_modes: int, eigenmode_weight: Optional[float] = None, probe: Optional[Tensor] = None
) -> Tensor:
    """
    Generate initial weights for OPR modes.

    The weights of the main OPR mode (column 0) are set to 1. The weights of
    the eigenmodes (remaining columns) are either filled with
    `eigenmode_weight` or, when it is None, drawn from a standard normal
    distribution and scaled by 1e-6.

    Parameters
    ----------
    n_points : int
        number of scan points.
    n_opr_modes : int
        number of OPR modes.
    eigenmode_weight : float, optional
        initial weight for eigenmodes. If None, small random values are used.
    probe : Tensor, optional
        The probe. If provided, the eigenmode weights are divided by
        `pmath.mnorm` of the corresponding OPR modes (taken over the last two
        axes and averaged along dimension 1).

    Returns
    -------
    weights : Tensor
        a (n_points, n_opr_modes) tensor of weights.
    """
    eigen_shape = [n_points, n_opr_modes - 1]
    if eigenmode_weight is not None:
        eigen_columns = torch.full(eigen_shape, eigenmode_weight)
    else:
        eigen_columns = torch.randn(eigen_shape) * 1e-6

    main_column = torch.ones([n_points, 1])
    weights = torch.cat([main_column, eigen_columns], dim=1)
    if probe is None:
        return weights

    pnorm = pmath.mnorm(probe, dim=(-2, -1), keepdims=False)
    weights[:, 1:] = weights[:, 1:] / torch.mean(pnorm[1:], dim=1)
    return weights
[docs]
@timer()
def generate_gaussian_random_image(
    shape: tuple[int, ...], loc: float = 0.9, sigma: float = 0.1, smoothing: float = 3.0
) -> Tensor:
    """
    Generate an image of Gaussian random values, optionally blurred.

    Parameters
    ----------
    shape : tuple[int, ...]
        The shape of the image. It must be 2D when `smoothing` is positive,
        because the image is indexed as `img[None, None, :, :]` before
        blurring.
    loc : float, optional
        The mean of the random values.
    sigma : float, optional
        The standard deviation of the random values.
    smoothing : float, optional
        The standard deviation, in pixels, of the Gaussian blur applied to the
        image. The blur uses a fixed 9x9 kernel. No blur is applied if the
        value is not positive.

    Returns
    -------
    Tensor
        The generated image in the default floating point dtype.
    """
    img = torch.randn(shape, dtype=torch.get_default_dtype()) * sigma + loc
    if smoothing > 0.0:
        # Honor the requested smoothing width instead of a hard-coded sigma.
        # Passing equal (min, max) values pins the blur sigma to that value.
        blur = GaussianBlur(kernel_size=(9, 9), sigma=(smoothing, smoothing))
        img = blur(img[None, None, :, :])
        img = img[0, 0, ...]
    return img
[docs]
def to_tensor(data: Union[ndarray, Tensor], device=None, dtype=None) -> Tensor:
if device is None:
device = torch.get_default_device()
if isinstance(data, (np.ndarray, list, tuple)):
data = torch.tensor(data, device=device)
if dtype is None:
if data.dtype.is_complex:
dtype = get_default_complex_dtype()
elif not data.dtype.is_complex:
dtype = torch.get_default_dtype()
if data.dtype != dtype:
data = data.type(dtype)
if str(data.device) != str(device):
data = data.to(device)
return data
[docs]
def move_nested_tensors_to_device(value, device):
    """
    Recursively move tensors contained in lists/tuples/dicts to a device.

    Lists, tuples and dicts are rebuilt as new plain `list`, `tuple` and
    `dict` objects with their elements processed recursively. Tensors are
    moved with `Tensor.to`. Any other value is returned unchanged.
    """

    def relocate(item):
        return move_nested_tensors_to_device(item, device)

    if torch.is_tensor(value):
        moved = value.to(device)
    elif isinstance(value, list):
        moved = list(map(relocate, value))
    elif isinstance(value, tuple):
        moved = tuple(map(relocate, value))
    elif isinstance(value, dict):
        moved = {key: relocate(item) for key, item in value.items()}
    else:
        moved = value
    return moved
[docs]
def to_numpy(data: Union[ndarray, Tensor]) -> ndarray:
    """
    Convert a tensor to a NumPy array.

    Tensors are detached, moved to the CPU and converted. Any other input is
    returned unchanged.
    """
    if not isinstance(data, Tensor):
        return data
    return data.detach().cpu().numpy()
[docs]
def set_default_complex_dtype(dtype):
    """Replace the module-wide default complex dtype.

    The value is stored in the module-level `_default_complex_dtype` variable
    without validation.

    Parameters
    ----------
    dtype : torch.dtype
        The complex dtype to use as the new default.
    """
    global _default_complex_dtype
    _default_complex_dtype = dtype
[docs]
def get_default_complex_dtype():
    """Return the module-wide default complex dtype.

    Returns
    -------
    torch.dtype
        The current value of the module-level `_default_complex_dtype`.
    """
    current_dtype = _default_complex_dtype
    return current_dtype
[docs]
def chunked_processing(
    func: Callable,
    common_kwargs: dict,
    chunkable_kwargs: dict,
    iterated_kwargs: dict,
    replicated_kwargs: Optional[dict] = None,
    chunk_size: int = 96,
):
    """
    Run `func` over a batch in chunks, carrying state between the chunks.

    The chunkable arguments are sliced along their first dimension. `func` is
    called once per chunk, and its return values replace the iterated
    arguments that are passed to the call for the next chunk.

    Parameters
    ----------
    func : callable
        The callable to be executed.
    common_kwargs : dict
        A dictionary of arguments that should stay constant across chunks.
    chunkable_kwargs : dict
        A dictionary of arguments that should be chunked. All values must have
        the same size along the first dimension.
    iterated_kwargs : dict
        A dictionary of arguments that should be returned by `func`, then passed to `func`
        for the next chunk. The order of arguments should be the same as the returns of
        `func`. This dictionary is updated in place.
    replicated_kwargs : dict, optional
        A dictionary of arguments that should be replicated for each chunk along the
        first dimension to match the chunk size. Tensors given here should have a first
        dimension of size 1 intended as the batch dimension.
    chunk_size : int, optional
        The size of each chunk. Default is 96.

    Returns
    -------
    The returns of `func` as if it is executed for the entire data: a single
    value if there is exactly one iterated argument, otherwise a tuple.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive, if `chunkable_kwargs` is empty, or if
        the chunkable arguments do not share the same batch size.
    """
    if chunk_size <= 0:
        # A non-positive chunk size would never advance through the batch.
        raise ValueError(f"chunk_size must be a positive integer, but got {chunk_size}.")
    if not chunkable_kwargs:
        raise ValueError("chunkable_kwargs must contain at least one argument.")

    full_batch_size = next(iter(chunkable_kwargs.values())).shape[0]
    for key, value in chunkable_kwargs.items():
        if value.shape[0] != full_batch_size:
            raise ValueError(
                "All chunkable arguments must have the same batch size, but "
                f"{key} has shape {value.shape}."
            )

    iterated_keys = tuple(iterated_kwargs)
    for ind_st in range(0, full_batch_size, chunk_size):
        ind_end = min(ind_st + chunk_size, full_batch_size)
        kwargs_chunk = {key: value[ind_st:ind_end] for key, value in chunkable_kwargs.items()}

        if replicated_kwargs is not None:
            current_chunk_size = ind_end - ind_st
            kwargs_chunk.update(
                {
                    key: value.expand(current_chunk_size, *value.shape[1:])
                    for key, value in replicated_kwargs.items()
                }
            )

        ret = func(**common_kwargs, **kwargs_chunk, **iterated_kwargs)

        # Feed the outputs back as the iterated inputs of the next chunk.
        if isinstance(ret, tuple):
            for i, key in enumerate(iterated_keys):
                iterated_kwargs[key] = ret[i]
        else:
            iterated_kwargs[iterated_keys[0]] = ret

    if len(iterated_kwargs) == 1:
        return tuple(iterated_kwargs.values())[0]
    return tuple(iterated_kwargs.values())
[docs]
def calculate_data_size_gb(
shape: Sequence[int],
dtype: torch.dtype | np.dtype
) -> float:
"""
Calculate the size of the data in GB.
"""
shape = list(shape)
return np.prod(shape) * dtype.itemsize / 1024 ** 3
[docs]
def get_max_batch_size(
    probe_shape: Sequence[int],
    object_shape: Sequence[int],
    double_precision: bool = False,
    data_saved_on_device: bool = False,
    all_data_shape: Optional[Sequence[int]] = None,
    reconstructor_type: Literal["lsqml"] = "lsqml",
    margin_factor: float = 0.2
) -> int:
    """
    Estimate the maximum batch size that fits in the available device memory.

    We estimate the memory usage using an empirical formula:
    ```
    mem = x0 * n_p * batch_size + x1 * n_p + x2 * object_numel
    ```
    where `n_p = n_modes * probe_size ** 2` and the coefficients `x0`, `x1`, `x2`
    were fit from experimental data.

    Parameters
    ----------
    probe_shape : Sequence[int]
        The shape of the 4D probe, expected to be (n_opr_modes, n_modes, h, w).
    object_shape : Sequence[int]
        The shape of the object, expected to be (n_slices, h, w).
    double_precision : bool, optional
        Whether to use double precision.
    data_saved_on_device : bool, optional
        Whether the raw data is kept on device. If True, the size of the data
        is subtracted from the memory available for computation, and
        `all_data_shape` must be given.
    all_data_shape : Sequence[int], optional
        The shape of the data, expected to be (n_points, h, w). Required when
        `data_saved_on_device` is True.
    reconstructor_type : Literal["lsqml"], optional
        The type of reconstructor. Currently only `lsqml` is supported.
    margin_factor : float, optional
        The fraction of the device memory to be left free.

    Returns
    -------
    int
        The suggested batch size, which is at least 1.

    Raises
    ------
    ValueError
        If `reconstructor_type` is unknown, or if `data_saved_on_device` is
        True but `all_data_shape` is not given.
    """
    if reconstructor_type == "lsqml":
        x0 = 6.26e-8
        x1 = 3.73e-7
        x2 = 1.39e-7
    else:
        raise ValueError(f"Unknown reconstructor type: {reconstructor_type}")

    # Fail early with a clear message instead of an opaque TypeError later.
    if data_saved_on_device and all_data_shape is None:
        raise ValueError(
            "`all_data_shape` must be provided when `data_saved_on_device` is True."
        )

    dtype = torch.float64 if double_precision else torch.float32
    n_p = np.prod(list(probe_shape[1:]))
    n_o = np.prod(list(object_shape))

    if data_saved_on_device:
        data_size_gb = calculate_data_size_gb(all_data_shape, dtype)
    else:
        data_size_gb = 0.0

    # Memory reported by the accelerator module, converted from bytes to GB
    # after reserving the requested margin.
    mem_avail = AcceleratorModuleWrapper.get_module().mem_get_info()[0] * (1 - margin_factor) / 1024 ** 3
    mem_compute = mem_avail - data_size_gb

    # Solve the empirical memory formula for the batch size.
    batch_size = (mem_compute - x1 * n_p - x2 * n_o) / (x0 * n_p)
    # Scale by the element size relative to 8 bytes.
    batch_size = batch_size * (8 / dtype.itemsize)
    return max(int(batch_size), 1)
[docs]
def auto_transfer_to_device(data: Tensor) -> Tensor:
    """Decide which device the data belongs on and transfer it there.

    The accelerator device string is obtained from `AcceleratorModuleWrapper`.
    The decision is made as follows:

    1. If the type of `torch.get_default_device()` equals the accelerator
       device string, the data is transferred to the accelerator.
    2. Otherwise, either the accelerator is unavailable or intentionally
       disabled, OR the current code is executed by DataParallel.
        1. If the accelerator module reports a device count of 0, we assume
           the former case and return the data as is.
        2. If the device count is not 0, we assume the latter case and
           transfer the data to the accelerator.
    """
    wrapper = AcceleratorModuleWrapper()
    default_is_accelerator = (
        torch.get_default_device().type == wrapper.get_to_device_string()
    )
    # The data stays where it is only if the default device is not the
    # accelerator and no accelerator device is present.
    if not default_is_accelerator and wrapper.get_module().device_count() == 0:
        return data
    return data.to(wrapper.get_to_device_string())
[docs]
def clear_memory(task: Optional["PtychographyTask"] = None):
    """Clear the memory of the device used.

    Runs the garbage collector, then calls `empty_cache()` and `ipc_collect()`
    on the accelerator module.

    Parameters
    ----------
    task : PtychographyTask, optional
        A `Task` object to release. Only this function's own reference is
        dropped; the object can be collected only once the caller holds no
        other references to it.
    """
    wrapper = AcceleratorModuleWrapper()
    if task is not None:
        # Unbinds the local name only.
        del task
    gc.collect()
    for release_method in ("empty_cache", "ipc_collect"):
        getattr(wrapper.get_module(), release_method)()
[docs]
def jsonize(val):
    """Convert a value to a JSON-serializable object.

    NumPy scalars are converted with `item()`; NumPy arrays and tensors with
    `tolist()`. Lists, tuples, dicts, strings, ints, floats, bools and None
    are returned unchanged, and their contents are not inspected.

    Raises
    ------
    TypeError
        If the value is of any other type.
    """
    if isinstance(val, np.generic):
        return val.item()
    if isinstance(val, (np.ndarray, torch.Tensor)):
        return val.tolist()

    passthrough_types = (list, tuple, dict, str, int, float, bool, type(None))
    if isinstance(val, passthrough_types):
        return val
    raise TypeError(f"Object of type {type(val).__name__} is not JSON serializable")