Source code for ptychi.api.task

# Copyright © 2025 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE

from typing import Literal, Union, overload
from types import TracebackType
import random
import logging
import os

import torch
import numpy as np
from torch import Tensor
from numpy import ndarray

import ptychi.api as api
import ptychi.data_structures.object as object
import ptychi.data_structures.opr_mode_weights as oprweights
import ptychi.data_structures.probe as probe
import ptychi.data_structures.probe_positions as probepos
import ptychi.data_structures.parameter_group as paramgrp
import ptychi.maps as maps
from ptychi.io_handles import PtychographyDataset
from ptychi.reconstructors.base import Reconstructor
from ptychi.utils import to_tensor
import ptychi.utils as utils
import ptychi.maths as pmath
from ptychi.timing import timer_utils
import ptychi.movies as movies
from ptychi.device import AcceleratorModuleWrapper
from ptychi.parallel import MultiprocessMixin

logger = logging.getLogger(__name__)


class Task(MultiprocessMixin):
    def __init__(self, options: api.options.base.TaskOptions, *args, **kwargs) -> None:
        pass

    def __enter__(self) -> "Task":
        return self

    @overload
    def __exit__(self, exception_type: None, exception_value: None, traceback: None) -> None: ...

    @overload
    def __exit__(
        self,
        exception_type: type[BaseException],
        exception_value: BaseException,
        traceback: TracebackType,
    ) -> None: ...

    def __exit__(
        self,
        exception_type: type[BaseException] | None,
        exception_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        AcceleratorModuleWrapper.get_module().empty_cache()


[docs] class PtychographyTask(Task): def __init__(self, options: api.options.task.PtychographyTaskOptions, *args, **kwargs) -> None: super().__init__(options, *args, **kwargs) self.options = options self.data_options = options.data_options self.object_options = options.object_options self.probe_options = options.probe_options self.position_options = options.probe_position_options self.opr_mode_weight_options = options.opr_mode_weight_options self.reconstructor_options = options.reconstructor_options self.dataset = None self.object = None self.probe = None self.probe_positions = None self.opr_mode_weights = None self.reconstructor: Reconstructor | None = None self.check_options() self.build()
[docs] def check_options(self): self.options.check()
[docs] def build(self): self.build_random_seed() self.build_default_device() self.build_default_dtype() self.build_logger() self.build_data() self.build_object() self.build_probe() self.build_probe_positions() self.build_opr_mode_weights() self.build_reconstructor()
[docs] def build_random_seed(self): if self.reconstructor_options.random_seed is not None: torch.manual_seed(self.reconstructor_options.random_seed) np.random.seed(self.reconstructor_options.random_seed) random.seed(self.reconstructor_options.random_seed) pmath.set_allow_nondeterministic_algorithms(self.reconstructor_options.allow_nondeterministic_algorithms)
[docs] def build_default_device(self): accelerator_module = AcceleratorModuleWrapper.get_module() if self.detect_launcher() is None: torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device)) else: self.init_process_group() if self.backend == "nccl" and self.n_ranks > accelerator_module.device_count(): raise ValueError( f"Number of ranks ({self.n_ranks}) is greater than the number of devices " f"({accelerator_module.device_count()}). This is not allowed with NCCL backend." ) if self.n_ranks == 1: torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device)) else: logging.info(f"Multi-processing mode detected with {self.n_ranks} ranks.") torch.set_default_device( f"{AcceleratorModuleWrapper.get_to_device_string()}:{self.rank % accelerator_module.device_count()}" ) if accelerator_module.device_count() > 0: cuda_visible_devices_str = "(unset)" if "CUDA_VISIBLE_DEVICES" in os.environ.keys(): cuda_visible_devices_str = os.environ["CUDA_VISIBLE_DEVICES"] logger.info( "Using device: {} (CUDA_VISIBLE_DEVICES=\"{}\")".format( [accelerator_module.get_device_name(i) for i in range(accelerator_module.device_count())], cuda_visible_devices_str, ) ) else: logger.info("Using device: {}".format(torch.get_default_device()))
[docs] def build_logger(self): if self.rank != 0: logger.setLevel(level=logging.ERROR)
[docs] def build_default_dtype(self): torch.set_default_dtype(maps.get_dtype_by_enum(self.reconstructor_options.default_dtype)) utils.set_default_complex_dtype( maps.get_complex_dtype_by_enum(self.reconstructor_options.default_dtype) ) pmath.set_use_double_precision_for_fft( self.reconstructor_options.use_double_precision_for_fft )
[docs] def build_data(self): if self.data_options.free_space_propagation_distance_m < np.inf and self.data_options.fft_shift: logger.warning( "It seems that you are reconstructing near-field data with FFT-shifted diffraction data. " "Is this intended? If not, set `data_options.fft_shift=False`." ) save_on_device = self.data_options.save_data_on_device if self.n_ranks > 1: if save_on_device: logging.warning( "Data must be saved on CPU in multi-processing mode " "but `data_options.save_data_on_device` is set to `True`. " "The current setting will be ignored." ) save_on_device = False self.dataset = PtychographyDataset( self.data_options.data, wavelength_m=self.data_options.wavelength_m, free_space_propagation_distance_m=self.data_options.free_space_propagation_distance_m, fft_shift=self.data_options.fft_shift, save_data_on_device=save_on_device, valid_pixel_mask=self.data_options.valid_pixel_mask, )
[docs] def build_object(self): data = to_tensor(self.object_options.initial_guess) kwargs = { "data": data, "options": self.object_options, } if ( isinstance(self.object_options, api.options.AutodiffPtychographyObjectOptions) ) and ( self.object_options.experimental.deep_image_prior_options.enabled ): self.object = object.DIPPlanarObject(**kwargs) else: self.object = object.PlanarObject(**kwargs)
[docs] def build_probe(self): data = to_tensor(self.probe_options.initial_guess) kwargs = { "data": data, "options": self.probe_options, } if ( isinstance(self.probe_options, api.options.AutodiffPtychographyProbeOptions) ) and ( self.probe_options.experimental.deep_image_prior_options.enabled ): self.probe = probe.DIPProbe(**kwargs) elif ( isinstance(self.probe_options, api.options.PIEProbeOptions) ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): self.probe = probe.SynthesisDictLearnProbe(**kwargs) else: self.probe = probe.Probe(**kwargs)
[docs] def build_probe_positions(self): pos_y = to_tensor(self.position_options.position_y_px) pos_x = to_tensor(self.position_options.position_x_px) data = torch.stack([pos_y, pos_x], dim=1) self.probe_positions = probepos.ProbePositions(data=data, options=self.position_options)
[docs] def build_opr_mode_weights(self): if self.opr_mode_weight_options.initial_weights is None: initial_weights = torch.ones([self.data_options.data.shape[0], 1]) else: initial_weights = to_tensor(self.opr_mode_weight_options.initial_weights) if initial_weights.ndim == 1: # If a 1D array is given, expand it to all scan points. initial_weights = initial_weights.unsqueeze(0).repeat( len(self.position_options.position_x_px), 1 ) self.opr_mode_weights = oprweights.OPRModeWeights( data=initial_weights, options=self.opr_mode_weight_options )
[docs] def build_reconstructor(self): par_group = paramgrp.PlanarPtychographyParameterGroup( object=self.object, probe=self.probe, probe_positions=self.probe_positions, opr_mode_weights=self.opr_mode_weights, ) if self.n_ranks == 1: reconstructor_class = maps.get_reconstructor_by_enum( self.reconstructor_options.get_reconstructor_type() ) else: reconstructor_class = maps.get_multiprocess_reconstructor_by_enum( self.reconstructor_options.get_reconstructor_type() ) reconstructor_kwargs = { "parameter_group": par_group, "dataset": self.dataset, "options": self.reconstructor_options, } self.reconstructor = reconstructor_class(**reconstructor_kwargs) self.reconstructor.build()
[docs] def run(self, n_epochs: int = None, reset_timer_globals: bool = True): """ Run reconstruction either for `n_epochs` (if given), or for the number of epochs given in the options. The internal states of the Task object persists when this function finishes. To run more epochs continuing from the last run, call this function again. Parameters ---------- n_epochs : int, optional The number of epochs to run. If None, use the number of epochs specified in the option object. reset_timer_globals : bool, optional When True (default) the global timing accumulators are cleared before the run. Set to False to continue accumulating timing data across successive calls. """ if movies.MOVIES_INSTALLED and self.reconstructor.current_epoch == 0: movies.api.reset_movie_builders() if reset_timer_globals: timer_utils.clear_timer_globals() self.reconstructor.run(n_epochs=n_epochs)
[docs] def get_data( self, name: Literal["object", "probe", "probe_positions", "opr_mode_weights"] ) -> Tensor: """Get a detached copy of the data of the given name. Parameters ---------- name : Literal["object", "probe", "probe_positions", "opr_mode_weights"] The name of the data to get. Returns ------- Tensor The data of the given name. """ # Deep image prior objects and probes need to be generated # before fetching to avoid issues with multi-GPU. if name == "object" and isinstance(self.object, object.DIPPlanarObject): self.object.generate() elif name == "probe" and isinstance(self.probe, probe.DIPProbe): self.probe.generate() return getattr(self, name).data.detach()
[docs] def get_data_to_cpu( self, name: Literal["object", "probe", "probe_positions", "opr_mode_weights"], as_numpy: bool = False, ) -> Union[Tensor, ndarray]: data = self.get_data(name).cpu() if as_numpy: data = data.numpy() return data
[docs] def get_probe_positions_y(self, as_numpy: bool = False) -> Union[Tensor, ndarray]: data = self.probe_positions.data[:, 0].detach() if as_numpy: data = data.cpu().numpy() return data
[docs] def get_probe_positions_x(self, as_numpy: bool = False) -> Union[Tensor, ndarray]: data = self.probe_positions.data[:, 1].detach() if as_numpy: data = data.cpu().numpy() return data
[docs] def copy_data_from_task( self, task: "PtychographyTask", params_to_copy: tuple[str, ...] = ("object", "probe", "probe_positions", "opr_mode_weights") ) -> None: """Copy data of reconstruction parameters from another task object. Parameters ---------- task : PtychographyTask The task object to copy from. params_to_copy : tuple[str, ...], optional The parameters to copy. By default, copy all parameters. """ with torch.no_grad(): for param in params_to_copy: if param == "object": self.reconstructor.parameter_group.object.set_data( task.get_data("object") ) elif param == "probe": self.reconstructor.parameter_group.probe.set_data( task.get_data("probe") ) elif param == "probe_positions": self.reconstructor.parameter_group.probe_positions.set_data( task.get_data("probe_positions") ) elif param == "opr_mode_weights": self.reconstructor.parameter_group.opr_mode_weights.set_data( task.get_data("opr_mode_weights") ) else: raise ValueError(f"Invalid parameter name: {param}")
[docs] def set_large_tensor_device( self, device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: """Move large task buffers between CPU and a target device. This helper is aimed at multi-task workflows where only one task is active on the accelerator at a time. Call it with ``device="cpu"`` to offload the heavy object/probe/diffraction buffers to system memory, and call it again (without arguments, or with an explicit device string) before resuming the task to bring the tensors back to the accelerator. Parameters ---------- device : str | torch.device | None, optional Target device for the large buffers. If None, tensors are moved back to the current default device. If a string is given, it must be either "cpu" or "cuda". """ if device is None: device = torch.get_default_device() device = torch.device(device) if self.reconstructor is None: raise RuntimeError("Reconstructor is not built yet.") parameter_group = self.reconstructor.parameter_group with torch.no_grad(): # Move object and probe buffers. parameter_group.object.to(device) parameter_group.probe.to(device) # Move diffraction patterns. self.dataset.patterns = self.dataset.patterns.to(device) # Keep dataset bookkeeping in sync with where patterns live. self.dataset.save_data_on_device = device.type != "cpu" # Move intermediate variables in forward model. self.reconstructor.forward_model.move_intermediate_variables_to_device(device) if device.type == "cpu": AcceleratorModuleWrapper.get_module().empty_cache()
[docs] def get_options_as_dict(self) -> dict: return self.options.get_dict()
[docs] def load_options_from_dict(self, d: dict) -> None: self.options.load_from_dict(d) self.data_options = self.options.data_options self.object_options = self.options.object_options self.probe_options = self.options.probe_options self.position_options = self.options.probe_position_options self.opr_mode_weight_options = self.options.opr_mode_weight_options self.reconstructor_options = self.options.reconstructor_options
def __exit__(self, exc_type, exc_value, exc_tb): del self.object del self.probe del self.probe_positions del self.opr_mode_weights del self.reconstructor del self.dataset super().__exit__(exc_type, exc_value, exc_tb)