Learning to Sense for Coded Diffraction Imaging

In this paper, we present a framework to learn illumination patterns that improve the quality of signal recovery for coded diffraction imaging. We use an alternating minimization-based phase retrieval method with a fixed number of iterations as the recovery method. We represent this iterative phase retrieval method as an unrolled network with a fixed number of layers, where each layer of the network corresponds to a single iteration, and we minimize the recovery error by optimizing over the illumination patterns. Since the number of iterations/layers is fixed, the recovery has a fixed computational cost. Extensive experimental results on a variety of datasets demonstrate that our proposed method significantly improves the quality of image reconstruction at a fixed computational cost, with illumination patterns learned using only a small number of training images.


Introduction
Coded diffraction imaging is a specific instance of the Fourier phase retrieval problem. Phase retrieval refers to a broad class of nonlinear inverse problems in which we seek to recover a complex- (or real-) valued signal from its phase-less (or sign-less) measurements [1–4]. In practice, these problems often arise in coherent optical imaging, where an image sensor records the intensity of the Fourier measurements of the object of interest. In coded diffraction imaging, the signal of interest is modulated by a sequence of known illumination patterns/masks before the Fourier intensity is observed at the sensor [2,4]. Applications include X-ray crystallography [5,6], astronomy [7,8], microscopy [9–12], speech processing and acoustics [13,14], and quantum mechanics [15,16]. As in other signal recovery problems in various imaging and signal processing tasks [4,5,11,17,18], iterative methods are widely used in coded diffraction imaging. In this paper, we present a framework to design the illumination patterns for better signal recovery in coded diffraction imaging using a fixed-cost iterative method in a data-driven manner.
Let us denote the signal of interest as x ∈ R^n or C^n, which is modulated by T illumination patterns D = {d_1, . . . , d_T}, where d_t ∈ R^n or C^n. The amplitude of the sensor measurements for the t-th illumination pattern can be written as

y_t = |F(d_t ⊙ x)|,  (1)

where F denotes the Fourier transform operator and ⊙ denotes an element-wise product. We note that real sensor measurements are proportional to the intensity of the incoming signal (i.e., the square of the Fourier transform magnitude). In practice, however, solving the inverse problem with (non-squared) amplitude measurements provides better results [19,20]; therefore, we use amplitude measurements throughout the paper.
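The measurement model above can be simulated in a few lines of NumPy (a toy sketch; the 32 × 32 image size, T = 4, and the uniform random patterns are illustrative choices, not the paper's):

```python
import numpy as np

def measure(x, patterns):
    """Amplitude measurements y_t = |F(d_t * x)| for each illumination pattern.

    x: (H, W) real image; patterns: (T, H, W) illumination patterns.
    Returns a (T, H, W) array of nonnegative amplitude measurements.
    """
    return np.abs(np.fft.fft2(patterns * x[None, :, :], axes=(-2, -1)))

# Toy example: T = 4 random patterns on a 32x32 image
rng = np.random.default_rng(0)
x = rng.random((32, 32))
D = rng.random((4, 32, 32))
Y = measure(x, D)
```

Note that with NumPy's default (unnormalized) FFT convention, the DC bin of each measurement equals the sum of the modulated image.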
To recover the signal x from the observed measurements, we can solve the following optimization problem:

minimize_x Σ_{t=1}^{T} || y_t − |F(d_t ⊙ x)| ||²_2.  (2)

In recent years, a number of iterative algorithms have been proposed for solving the problem in (2), including lifting-based convex methods, alternating minimization-based non-convex methods, and greedy methods [2,21–24].
Our goal is to learn a set of illumination patterns to optimize the recovery of an alternating minimization (AltMin) algorithm for solving the problem in (2). The AltMin method can be viewed as an unrolled gradient descent network, as shown in Figure 1, where we fix the steps at every iteration and the total number of iterations for AltMin. One forward pass through the unrolled network is equivalent to K iterations of the AltMin algorithm using given illumination patterns. We can increase or decrease the number of iterations for better accuracy or faster run-time. To keep the computational complexity of the recovery algorithm low, we keep the total number of iterations small (e.g., K = 50). At the training stage, we optimize over the illumination patterns to minimize the error between the AltMin outputs after K iterations and the ground truth training images. At the test time, we solve the problem in (2) using K AltMin iterations with the learned illumination patterns (equivalent to one forward pass). We evaluated our method on different image datasets and compared against existing methods for coded diffraction imaging. We demonstrate that our proposed method of designing illumination patterns for a fixed-cost algorithm outperforms existing methods both in terms of accuracy and speed.

Figure 1. Pipeline of our proposed framework at inference time. Our framework contains two main components: (1) a learnable sensing system that updates the illumination patterns during training; at inference time, the learned illumination patterns are fixed; (2) a fixed unrolled network that runs the phase retrieval process to recover the original signal x from the measurements Y. The number of layers in the network is fixed to K. The steps at every iteration are fixed and depicted as an unrolled network (details can be found in Algorithm 1).
The main contributions of this paper are as follows.
• Low-cost inference: We learn illumination patterns for coded diffraction imaging using the unrolled network formulation of a classical AltMin method. We show that with our learned illumination patterns, the unrolled AltMin method outperforms other, computationally complex algorithms and provides superior image reconstruction in a much shorter time.
• Learning from a small dataset: We use only a small number of training samples and can learn illumination patterns that are highly effective for image reconstruction. This is crucial for real-life applications because obtaining training samples can be challenging in practice.
• Robust sensor design: The patterns learned on a given dataset generalize to different datasets and provide robust reconstruction for shifted and flipped versions of the target samples. Their performance does not drastically degrade under noisy measurements. Our learned illumination patterns can also help other algorithms achieve better performance, even though the patterns were not trained for those algorithms.

Related Work
Phase Retrieval and Coded Diffraction Patterns. A Fourier phase retrieval problem arises in a number of imaging systems because standard image sensors can only record the intensity of the observed measurements. This problem has been extensively studied over the last five decades in optics, signal processing, and optimization [3–5,25,26]. Coded diffraction imaging is a physically realistic setup in which we can first modulate the signal of interest and then collect the intensity measurements [18,27]. The modulation can be performed using a spatial light modulator or custom transparencies [10,11,28]. The recovery problem involves solving a phase retrieval problem; the presence of modulation patterns makes this more tractable than classical Fourier phase retrieval [18].
The algorithms for solving the phase retrieval problem can be broadly divided into non-convex and convex methods. Classical algorithms for phase retrieval rely on solving the underlying non-convex problem using alternating minimization. Amplitude flow [29,30], Wirtinger flow [31,32], and alternating minimization (AltMin) [22,23,33] are such methods that solve the non-convex problem directly. Convex methods usually lift the non-convex problem of signal recovery from quadratic measurements into a convex problem of low-rank matrix recovery from linear measurements. The PhaseLift algorithm [2] and its variations [18,21] fall under this class. Other algorithms, such as PhaseMax [34,35] and PhaseLin [36], use convex relaxation to solve the non-convex phase retrieval problem without lifting it to a higher dimension. We can also incorporate prior knowledge about the signal structure (e.g., sparsity, support, or positivity) as constraints in the recovery process [22,29,32,37,38].
Data-Driven Approaches for Phase Retrieval. Recently, the idea of replacing the classical (hand-designed) signal priors with deep generative priors for solving inverse problems has been explored in different works [39,40]. Refs. [23,26,41–44] focused especially on solving phase retrieval problems with generative priors. Another growing trend is learning the solution of inverse problems (including phase retrieval) in an end-to-end manner, where deep networks are trained to learn a mapping from sensor measurements to the signal of interest using a large number of measurement-signal pairs. A few examples demonstrating the benefit of the data-driven approaches include robust phase retrieval [20], Fourier ptychographic microscopy [45], holographic image reconstruction [46], and correlography for non-line-of-sight imaging [47].
Although our method is partially driven by data, our goal is not to learn a signal prior or a mapping from measurements to signal. We use a very small dataset (consisting of only 32 or 128 images) to learn the illumination patterns for a fixed recovery algorithm. Furthermore, the patterns we learn on one class of images provide good results on other types of images. Apart from this flexibility and generalization, our method uses a fixed number of iterations of the well-defined AltMin routine, which is parameter-free during inference (except for the step size), in contrast to end-to-end or generative prior-based approaches.
The approach we used for optimizing over the AltMin routine to learn illumination patterns is broadly known as unrolling networks. Iterative methods for solving the inverse problems, such as AltMin or other first-order methods, can be represented as unrolled networks. Every layer of such a network performs the same steps as a single iteration of the original method [48–57]. Some parameters of the iterative steps can be learned from data (e.g., step size, denoiser, or threshold parameters), but the basic structure and physical forward model are kept intact.
Learn to Sense. Data-driven deep learning methods have also been used to design the sensing system, especially in the context of compressive sensing and computational imaging [58–63]. The main objective of these methods is similar to ours: find the sensor parameters that recover the best possible signal/image from the sensor measurements. The sensor parameters may involve the selection of samples/frames, the design of sampling waveforms, or illumination patterns as we discuss in this paper. In contrast to most of the existing methods, which learn a deep network to solve the inverse problem, our method uses a predefined iterative method as an unrolled network while learning the illumination patterns from a small number of training images. Unrolled networks for solving non-linear inverse problems have been used in [45,64]. Ref. [45] proposes learning sensors for Fourier ptychographic microscopy, whereas [64] designs sensing patterns for coded illumination imaging. One might find a similarity between [64] and our problem formulation; in principle, the sensor can be treated as the first layer of the network with some physical constraints on the parameters [64]. However, the method in [64] uses an unrolled network to learn the sensing parameters for a quantitative phase imaging problem under the "weak object approximation", which turns the original nonlinear problem into a linear inverse problem. This assumption is only applicable where the target objects have a small scattering term (e.g., biological samples in closely index-matched fluid). In our setup, we do not make any such assumption on the target object and solve the original nonlinear coded diffraction imaging problem, which potentially makes our algorithm suitable for more general applications than [64].

Proposed Method
Our proposed method for learning illumination patterns can be divided into two parts. The first (inner) part involves solving the phase retrieval problem with given coded diffraction patterns using AltMin as an unrolled network (see block diagram in Figure 1). The second part is updating the illumination patterns based on backpropagating the image reconstruction loss. These two parts provide optimized image reconstruction and illumination patterns. Pseudocodes for both parts are listed in Algorithms 1 and 2.

Algorithm 2 Learning illumination patterns
We use N training images (x 1 , . . . , x N ) to learn T illumination patterns that provide the best reconstruction using a predefined (iterative) phase retrieval algorithm. Furthermore, to ensure that the illumination patterns are physically realizable, we constrain their values to be in the range [0, 1]. We use a sigmoid function over unconstrained parameters Θ = {θ 1 , . . . , θ T } to define the illumination patterns; that is, d t = sigmoid(θ t ) for all t = 1, . . . , T.
Phase retrieval with alternating minimization (AltMin). Given measurements Y = {y_1, . . . , y_T} and illumination patterns D = {d_1, . . . , d_T}, we seek to solve the CDP phase retrieval problem by minimizing the loss function defined in (2) as

L(x) = Σ_{t=1}^{T} || y_t − |F(d_t ⊙ x)| ||²_2.  (3)

Although the loss function in (3) is non-convex and non-smooth with respect to x, we can minimize it using the well-known alternating minimization (AltMin) approach with gradient descent [22,33]. In the AltMin formulation, we define a new variable for the estimated phase of the linear measurements as p_t = phase[F(d_t ⊙ x)] and reformulate the loss function in (3) as

L(x, p) = Σ_{t=1}^{T} || p_t ⊙ y_t − F(d_t ⊙ x) ||²_2.  (4)

The gradient with respect to x can be computed as

∇_x L(x, p) = Σ_{t=1}^{T} d*_t ⊙ F*(F(d_t ⊙ x) − p_t ⊙ y_t),  (5)

where F* denotes the inverse Fourier transform and d*_t is the conjugate of pattern d_t. We can update the estimate at every iteration as

x^k = x^{k−1} − α_{k−1} ∇_x L(x^{k−1}, p^{k−1}),  (6)

where α_{k−1} denotes the step size. Another approach is to directly solve for x^k such that ∇_x L(x, p) = 0, which has the closed-form solution

x^k = (Σ_{t=1}^{T} |d_t|²)^{−1} ⊙ Σ_{t=1}^{T} d*_t ⊙ F*(p^{k−1}_t ⊙ y_t),  (7)

where the inverse is applied element-wise. We compared these two strategies and found that single-step gradient descent works well in practice, and the closed-form solution does not show an advantage over it. In our implementation, we used the former strategy (Algorithm 1) with a fixed step size α for all iterations. The unrolled network has K layers that implement K iterations of gradient descent, and the final estimate is denoted as x^K.
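The phase update and gradient step described above can be sketched in NumPy (our own toy implementation, not the released code; we use `np.fft.ifft2` for F* and a conservative fixed step size rather than the paper's grid-searched α = 4/T):

```python
import numpy as np

def solve_cdp(Y, D, K=50, alpha=None):
    """Unrolled AltMin phase retrieval with a fixed number of iterations K.

    Y: (T, H, W) amplitude measurements; D: (T, H, W) illumination patterns.
    Starts from the cost-free zero initialization x^0 = 0.
    """
    T = D.shape[0]
    alpha = 1.0 / T if alpha is None else alpha    # conservative fixed step size
    x = np.zeros(D.shape[1:])
    for _ in range(K):
        Z = np.fft.fft2(D * x[None], axes=(-2, -1))   # F(d_t * x)
        P = np.exp(1j * np.angle(Z))                  # phase update p_t
        resid = Z - P * Y                             # F(d_t * x) - p_t * y_t
        # single gradient step; keep the real part since x is real-valued
        grad = np.real(np.conj(D) * np.fft.ifft2(resid, axes=(-2, -1))).sum(axis=0)
        x = x - alpha * grad
    return x

# Sanity check: the amplitude loss should decrease from the zero initialization
rng = np.random.default_rng(0)
x_true = rng.random((16, 16))
D = rng.random((3, 16, 16))
Y = np.abs(np.fft.fft2(D * x_true[None], axes=(-2, -1)))
x_hat = solve_cdp(Y, D, K=100)
loss0 = (Y ** 2).sum()                                # loss at x = 0
lossK = ((Y - np.abs(np.fft.fft2(D * x_hat[None], axes=(-2, -1)))) ** 2).sum()
```

Because the phase update exactly minimizes the surrogate loss over p, and the gradient step uses a small enough step size, each AltMin iteration is non-increasing in the loss.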
Choice of initialization is important, and our method can handle different types of initialization. Zero initialization, where every pixel of the initial estimate x^0 is 0, is the simplest and cost-free option. Many recent phase retrieval algorithms [30,31,33,35] use spectral initialization, which tries to find a good initial estimate; however, it requires computing the principal eigenvector of the positive semidefinite matrix

Z = Σ_{t=1}^{T} diag(d_t)* F* diag(y_t ⊙ y_t) F diag(d_t).

In our experiments, we observed that spectral initialization does not provide a significant improvement in terms of image reconstruction and that our algorithm performs very well using the overhead-free zero initialization.
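For reference, the principal eigenvector can be computed without ever forming the matrix. A matrix-free power-iteration sketch (our own construction, assuming the standard coded-diffraction spectral matrix built from the squared measurements, as above):

```python
import numpy as np

def spectral_init(Y, D, n_iter=50, seed=0):
    """Power iteration for the spectral initializer without forming the matrix.

    Each step applies v -> sum_t conj(d_t) * F*( y_t^2 * F(d_t * v) ),
    which is one matrix-vector product with the spectral matrix.
    Returns a unit-norm (real) estimate of the principal eigenvector.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(D.shape[1:])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        W = np.fft.fft2(D * v[None], axes=(-2, -1))
        W = np.fft.ifft2((Y ** 2) * W, axes=(-2, -1))
        v = np.real(np.conj(D) * W).sum(axis=0)
        v /= np.linalg.norm(v)
    return v
```

Each iteration costs T forward and T inverse FFTs, so the overhead grows with T and the number of power iterations, which is the cost that zero initialization avoids.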
Learning illumination patterns. To learn a set of illumination patterns that provides the best reconstruction with the predefined iterative method (or the unrolled network), we seek to minimize the difference between the original training images and their estimates. In this regard, we minimize the following quadratic loss function with respect to Θ:

L(Θ) = (1/N) Σ_{n=1}^{N} || x_n − x^K_n(Θ) ||²_2,  (8)

where x^K_n(Θ) denotes the solveCDP estimate of the n-th training image for the given values of Θ. Note that for given real values of Θ = {θ_1, . . . , θ_T}, we can define the illumination patterns as d_t = σ(θ_t), where σ(·) is the sigmoid function. We can write the sensor measurements for x_n as y_{t,n} = |F(d_t ⊙ x_n)| = p*_{t,n} ⊙ F(d_t ⊙ x_n) for t = 1, . . . , T and n = 1, . . . , N, where p_{t,n} = phase[F(d_t ⊙ x_n)] is the phase of the original complex-valued measurement.
We can use the recursive expression of the signal estimate in (6) and the gradient in (5) to represent the estimate of x_n at iteration/layer k for the given values of Θ as

x^k_n(Θ) = x^{k−1}_n(Θ) − α Σ_{t=1}^{T} d*_t ⊙ F*(F(d_t ⊙ x^{k−1}_n(Θ)) − p^{k−1}_{t,n} ⊙ y_{t,n}),  (9)

where p^k_{t,n} = phase[F(d_t ⊙ x^k_n(Θ))]. We can compute the gradient of the loss function in (8) with respect to any θ_τ in a recursive manner as follows:
∇_{θ_τ} L(Θ) = (2/N) Σ_{n=1}^{N} J_{θ_τ}(x^K_n(Θ))^T (x^K_n(Θ) − x_n),  (10)

where J_{θ_τ}(x^K_n(Θ)) denotes the Jacobian matrix of the signal estimate with respect to θ_τ. The product of the Jacobian matrix with a vector can be computed recursively by differentiating (9), with J_{θ_τ}(x^0_n) = 0 for all n, τ. Here, we assume the initial estimate x^0_n = 0 and α_k = α for k = 1, . . . , K. We also assume that the phase of the measurements and the signal estimates do not change with small changes in Θ. The overall gradient of the reconstruction loss with respect to the parameters Θ can thus be computed in a recursive manner (backpropagation) using element-wise products and forward/inverse Fourier transform operations at every iteration/layer.
We can use gradient descent to find the optimal Θ using the gradient in (10), updating the estimate at every iteration of gradient descent as

Θ ← Θ − β ∇_Θ L(Θ),  (11)

where β denotes the learning rate.
In practice, we can also compute the gradient using auto-differentiation. In our experiments, we used the Adam optimizer in PyTorch [65,66] to minimize the loss function in (8). A summary of the algorithm for learning the illumination patterns is also listed in Algorithm 2. Our code will be available at https://github.com/CSIPlab/learned-codeddiffraction (accessed on 12 December 2022).
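The training procedure described above can be sketched in PyTorch (our own minimal reimplementation, not the released code: tiny 8 × 8 images, T = 2 patterns, K = 10 unrolled iterations, and the phase detached from the graph, mirroring the assumption that the phase does not change with small changes in Θ):

```python
import torch

def solve_cdp_unrolled(Y, theta, K=10, alpha=1.0):
    """Differentiable unrolled AltMin; theta are the unconstrained parameters."""
    D = torch.sigmoid(theta)                    # d_t = sigmoid(theta_t) in (0, 1)
    step = alpha / D.shape[0]                   # conservative fixed step size
    x = torch.zeros(D.shape[1:], dtype=torch.float64)
    for _ in range(K):
        Z = torch.fft.fft2(D * x)
        P = torch.exp(1j * torch.angle(Z)).detach()   # phase treated as constant
        grad = (torch.conj(D) * torch.fft.ifft2(Z - P * Y)).real.sum(dim=0)
        x = x - step * grad
    return x

# Learn patterns on a tiny training set by backpropagating the recovery loss
torch.manual_seed(0)
X = torch.rand(4, 8, 8, dtype=torch.float64)          # N = 4 training images
theta = torch.randn(2, 8, 8, dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([theta], lr=0.1)
losses = []
for _ in range(20):
    opt.zero_grad()
    D = torch.sigmoid(theta)
    loss = sum(((x - solve_cdp_unrolled(torch.abs(torch.fft.fft2(D * x)), theta)) ** 2).mean()
               for x in X)
    loss.backward()
    opt.step()
    losses.append(float(loss))
```

Autodiff propagates gradients both through the unrolled iterations and through the simulated measurements, since both depend on the patterns; only the estimated phases are held fixed.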

Experiments
Datasets. We used the MNIST, Fashion MNIST (F. MNIST), CIFAR10, SVHN, and CelebA datasets for training and testing in our experiments. We used 128 images from each dataset for training and another 1000 images for testing. To make the tiny-image datasets uniform, we resized all of them to 32 × 32 grayscale images. Images in the CelebA dataset have 218 × 178 pixels; we first converted them to grayscale, cropped a 178 × 178 region at the center, and resized them to 200 × 200.
Measurements. We used the amplitude of the 2D Fourier transform of the images modulated with T illumination patterns as the measurements. Unless otherwise mentioned, we used noiseless measurements. We report results for measurements with Gaussian and Poisson noise in Section 4.7.
Computing platform. We performed all the experiments using a computer equipped with Intel Core i7-8700 CPU and NVIDIA TITAN Xp GPU. We learned the illumination patterns using a PyTorch implementation, but we also implemented our algorithm in Matlab to provide a fair runtime comparison with existing phase retrieval methods.

Setup and Hyper-Parameter Search
The hyper-parameters include the number of iterations (K), the step size α, and the number of training samples N. We set the default value of K = 50, but we show in the supplementary material that K can be adjusted as a trade-off between better reconstruction quality and shorter runtime. We tested all methods for T = {2, 3, 4, 8} to evaluate cases where signal recovery is hard, moderate, and easy. Through grid search, we found that α = 4/T provides the best results across all datasets. We also studied the effect of the number of training images and found that illumination patterns learned on 32 randomly selected images provide good recovery over the entire dataset; the test accuracy improves slightly as we increase the number of training samples. To be safe, we used 128 training images in all our experiments. Unless otherwise mentioned, the images are constrained to the [0, 1] range in our experiments.

Comparison of Random and Learned Patterns
To demonstrate the advantages of our learned illumination patterns, we compared the performance of learned and random illumination patterns on five different datasets. We learned a set of T = {2, 3, 4, 8} illumination patterns on 128 training images from a dataset and tested them on 1000 test images from the same dataset. For random patterns, we drew T independent patterns from the uniform (0, 1) distribution and tested their performance on the same 1000 samples that we used for the learned case. Unless otherwise mentioned, we repeated this process 30 times and chose the best result to compare with the results for the learned illumination patterns. The average peak signal-to-noise ratio (PSNR) over all 1000 test image reconstructions is presented in Table 1, which shows that the learned illumination patterns perform significantly better than the random patterns for all values of T. In addition, we observed a transition in the performance at T = 3, where random patterns provided poor quality reconstructions while learned patterns provided reasonably high quality reconstructions. Furthermore, the learned patterns provided very high quality reconstructions for T ≥ 4. To highlight this effect, we show a small set of reconstructed images and histograms of reconstruction PSNRs for the learned and random illumination patterns in Figure 2 for T = 4 patterns. The results suggest that the learned illumination patterns perform consistently better than random illumination patterns. We also show the corresponding learned illumination patterns in Figure 2. Visually, illumination patterns learned on the same dataset look similar, while patterns learned on different datasets look different.

Comparison with Existing Methods
We show a comparison with different existing methods using different datasets. Existing methods can be divided into four broad categories:
We compare the performance of our method with these methods in terms of reconstruction quality and computation time. For the algorithms in [1,25,30,31,35], we used the PhasePack [27] package. In our comparison, we used four illumination patterns and restricted all illumination patterns to the range [0, 1]. For all the PhasePack algorithms, we used the default spectral initialization. We observed that different algorithms have different computational complexity per iteration; thus, a comparison in terms of the maximum number of iterations is not fair. To overcome this issue, we set the error tolerance (tol = 10^−6) and customized the maximum number of iterations in each algorithm to have comparable computations or performance. Specifically, we set the maximum number of iterations to 100 for HIO and GS, and 2000 for Wirtinger Flow, Amplitude Flow, and PhaseMax. For our proposed method, we kept the number of iterations low (20, 50, or 100). To make our runtime comparable with the PhasePack algorithms, we reimplemented our original Python code in Matlab.
For deep generative models, we used a modified version of the publicly available code for [43]. The code only provided pretrained DCGAN models for MNIST and F. MNIST; therefore, we trained our own DCGAN models on the other datasets. This method is noticeably time-consuming because it optimizes over the latent vector of the deep model and uses 2000 iterations for each image, where each iteration requires a forward and backward pass through the deep model. Patterns drawn from the uniform (0, 1) range did not provide good reconstructions with the deep model; therefore, we tested this method using random patterns drawn uniformly from the [−1, 1] range and learned patterns that we manually scaled to [−1, 1]. The reconstruction results for the deep model also depend directly on the quality of the trained generative models. In our experiments, we were not able to generate images with PSNR higher than 30 dB using the generative models.
We tested all the methods using random illumination patterns and the learned illumination patterns, with K = 50 in our method. For the case of random illumination, we selected the best PSNR from five independent trials and report the average computation time for each experiment. In all cases, we tuned the parameters to obtain the best results.
The reconstruction PSNR (in dB) and runtime (in seconds) per image are reported in Tables 2 and 3, respectively. We observed that our proposed method with learned patterns performed significantly better than all other algorithms in terms of both reconstruction quality and runtime. We also observed that if we increase the number of iterations for other methods, their reconstruction quality improves beyond the numbers reported in Table 2, but at the expense of much longer computation time.

Generalization on Different Algorithms
An interesting attribute of our learned patterns is that they can be used with different algorithms. Although we learned our illumination patterns using the AltMin approach, they perform well with other algorithms. We observe in Table 2 that our learned patterns provide better results than random patterns with almost all the phase retrieval algorithms on all the datasets, even though the patterns were not optimized for those algorithms. These results demonstrate the robust performance of our learned illumination patterns.

Generalization on Different Datasets
To explore the generalizability of our learned illumination patterns, we used patterns learned on one dataset to recover images from another. The results are shown in Table 4. As we can see in the table, the diagonal numbers are generally the best, and the off-diagonal numbers are generally better than the ones obtained with random illumination patterns.

Effect of Number of Iterations

Figure 3 shows the performance of the learned and random illumination patterns as we increased K to 200 at test time using the patterns learned for K = 50, with T = 4 illumination patterns. Random illumination patterns were selected as the best out of 30 trials. The learned illumination patterns were trained on 128 training images with K = 50 iterations. We observed that with the learned patterns the image reconstruction process converges faster and is more stable (smaller variance) than with random patterns: the red curve in Figure 3 has a steeper slope and narrower shades.

Besides the default setting of K = 50, we also learned illumination patterns for different values of K. Figure 4 shows that we can recover images in a small number of iterations if we use learned illumination patterns. We also observe that we can perform better if we use more iterations at test time than during training. We chose K = 50 for most of the experiments as a trade-off between computational cost and reconstruction performance.

Figure 4. Reconstruction quality vs. number of iterations (layers) at test time (i.e., K is different for training and testing, with T = 4). We show an error bar of ±0.25σ for each dataset. In (a,b), we fixed the training K (K = 10, 20) and tested using different K. In (c), we trained and tested using the same number of layers.

Noise Response
To investigate the robustness of our method to noise, we trained our illumination patterns on noiseless measurements obtained from the training datasets. We then added Gaussian and Poisson noise at different levels to the measurements from the test datasets. Poisson noise (shot noise) is the most common noise in imaging systems; we add it following the approach in [20,68]. Let us denote the i-th element of the measurement vector corresponding to the t-th illumination pattern as

y_t(i) = |z_t(i)| + η_t(i), for i = 1, 2, . . . , m,

where η_t(i) ∼ N(0, λ|z_t(i)|) and z_t = F(d_t ⊙ x). We varied λ to generate noise at different signal-to-noise ratio (SNR) levels. Poisson noise affects larger measurement values more strongly than smaller ones. Since sensors can record only nonnegative values, we kept the measurements nonnegative by applying the ReLU function after noise addition. We expected the reconstruction to be affected by noise since we did not use any denoiser. Figure 5 shows the effect of noise with illumination patterns learned under a noiseless setup. Even though noise affects the reconstructions, we can obtain reasonable reconstructions up to a certain level of noise. The relationship between noise level and reconstruction performance also indicates that our phase retrieval system is quite stable. Here, we show a shaded error bar of ±0.25σ for each dataset.
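The noise model above can be simulated as follows (a sketch; the mapping from λ to a target SNR is our own convention, assuming SNR is measured as total signal power over total expected noise power):

```python
import numpy as np

def add_shot_noise(Y, snr_db, rng):
    """Add Poisson-like noise eta(i) ~ N(0, lam * |z(i)|) to amplitude measurements.

    Y: nonnegative amplitude measurements |z|. lam is set so that the expected
    noise power matches the target SNR; the output is clipped at zero (ReLU),
    since sensors record only nonnegative values.
    """
    # E[sum eta^2] = lam * Y.sum(), so solve lam from SNR = ||Y||^2 / E[||eta||^2]
    lam = (Y ** 2).sum() / (Y.sum() * 10.0 ** (snr_db / 10.0))
    eta = rng.normal(0.0, np.sqrt(lam * Y))   # noise variance grows with |z(i)|
    return np.maximum(Y + eta, 0.0)
```

Because the variance of each noise term scales with the signal amplitude, large measurement values receive proportionally stronger noise, matching the shot-noise behavior described above.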
We ran another set of experiments in which we learned a different set of illumination patterns at each noise level by introducing measurement noise during training. In Table 5, we report results for the MNIST and CIFAR10 datasets at different levels of Poisson noise introduced during training and testing. We also show the performance of some comparable approaches alongside our learned and random patterns. For random patterns, we report the best result out of five runs. We observe that even under high noise (0–20 dB), the illumination patterns learned using our approach performed reasonably well. We observed a performance boost with our learned patterns at 5 dB or higher SNR.

Mismatch in Training and Test Images
In our final experiment, we tested illumination patterns trained on upright images to recover shifted and rotated images. Our results in Figures 6 and 7 show that the learned patterns reliably recovered images regardless of the position or orientation. This is not surprising because we do not learn to represent images or solve the phase retrieval problem using the training data; instead, we only learned the illumination patterns using a predefined AltMin-based recovery algorithm. In contrast, data-driven methods that learn to solve the inverse problem may suffer if the distribution of test images differs significantly from the training images.

Conclusions
We presented a framework to learn the illumination patterns for coded diffraction imaging by formulating an iterative phase retrieval algorithm as a fixed unrolled network. We learned the illumination patterns using a small number of training images via backpropagation. Our results demonstrate that the learned patterns provide near-perfect reconstruction, whereas random patterns fail. The number of iterations in our algorithm provides a clear trade-off between reconstruction accuracy and runtime. In addition, the learning process of our illumination patterns is highly data efficient and requires only a small number of training samples. The learned patterns generalize to different datasets and algorithms that were not used during training.