Source code for SPCSim.data_loaders.perpixel_loaders

import numpy as np
import cv2
import torch

class PerPixelLoader:
    r"""Build a synthetic per-pixel RGB-D "image" for sweeping distance and SBR.

    Each row of the generated image corresponds to one combination of a
    distance value and an average (signal, background) photon pair, and each
    of the R columns is one run captured for that combination, where R is
    ``num_runs``.

    .. note::
       If A is the output transient matrix then A[i,j] gives readings for
       ``dist_idx = i % num_dists``, ``SBR_Pair_idx = i // num_dists``, and
       ``Run_idx = j``. :meth:`get_row` computes the inverse mapping.

    Args:
        num_dists (int): Number of evenly spaced distance values per
            signal-background pair.
        min_dist (float): Smallest distance, expressed as a fraction of
            ``dmax``.
        max_dist (float): Largest distance, expressed as a fraction of
            ``dmax``.
        sig_bkg_list (list): List of ``[signal, background]`` average photon
            pairs. Defaults to ``[[1, 1]]`` when omitted.
        tmax (float): Time period in nanoseconds.
        num_runs (int): Number of runs (columns) per row.
        device (str): Torch device on which the returned tensors are placed.
    """

    def __init__(self,
                 num_dists=1,
                 min_dist=0,
                 max_dist=1,
                 sig_bkg_list=None,
                 tmax=100,  # time period in nano-seconds
                 num_runs=1,
                 device='cpu',
                 ):
        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every default-constructed loader.
        if sig_bkg_list is None:
            sig_bkg_list = [[1, 1]]

        # Scene distance in meters: dmax = (c * tmax) / 2, with tmax in ns.
        self.dmax = 3 * 1e8 * tmax * 1e-9 / 2
        self.num_runs = num_runs
        self.num_dists = num_dists
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.sig_bkg_list = sig_bkg_list
        self.device = device
        self.tmax = tmax
        # Image dimensions; populated by get_data().
        self.Nr = None
        self.Nc = None

    def get_data(self):
        r"""Generate the rgb-d data along with the average signal and
        background photon flux per cycle.

        The image has ``len(sig_bkg_list) * num_dists`` rows and ``num_runs``
        columns. Rows are grouped by signal-background pair; inside each
        group the distance sweeps linearly from ``min_dist`` to ``max_dist``.
        Calling this method also sets ``self.Nr`` and ``self.Nc``.

        Returns:
            data (dictionary): Dictionary containing the generated RGB,
            distance image, albedo, signal and bkg flux, under the keys
            ``'rgb'``, ``'albedo'``, ``'gt_dist'``, ``'alpha_sig'``,
            ``'alpha_bkg'`` and ``'loader_id'``.
        """
        num_dists = self.num_dists
        nr = len(self.sig_bkg_list) * num_dists
        nc = self.num_runs

        # Distance fractions as a (num_dists, 1, 1) column so that they
        # broadcast across the run axis when written into a row group.
        dist_fracs = torch.tensor(
            np.linspace(self.min_dist, self.max_dist, num=num_dists)
        ).view(-1, 1, 1)

        alpha_sig = torch.zeros(nr, nc, 1)
        alpha_bkg = torch.zeros(nr, nc, 1)
        dist_f = torch.zeros(nr, nc, 1)

        for pair_idx, pair in enumerate(self.sig_bkg_list):
            rows = slice(pair_idx * num_dists, (pair_idx + 1) * num_dists)
            alpha_sig[rows, :] = pair[0]
            alpha_bkg[rows, :] = pair[1]
            dist_f[rows, :] = dist_fracs

        # rgb and albedo are all-ones placeholders. They are kept in float64
        # (the dtype the previous numpy-based construction produced) so that
        # callers see unchanged dtypes.
        albedo = torch.ones(nr, nc, dtype=torch.float64, device=self.device)
        rgb = torch.ones(nr, nc, 3, dtype=torch.float64, device=self.device)

        # Scale the distance fractions to meters.
        gt_dist = dist_f.view(nr, nc).to(self.device) * self.dmax

        self.Nr = nr
        self.Nc = nc

        return {
            'rgb': rgb,
            'albedo': albedo,
            'gt_dist': gt_dist,
            'alpha_sig': alpha_sig.to(self.device),
            'alpha_bkg': alpha_bkg.to(self.device),
            'loader_id': "perpixel",
        }

    def get_row(self, sbr_idx=0, dist_idx=0):
        r"""Return the image row for a signal-background pair and distance.

        Args:
            sbr_idx (int): Index into ``sig_bkg_list``.
            dist_idx (int): Index of the distance value within the sweep.

        Returns:
            int: ``sbr_idx * num_dists + dist_idx``.
        """
        return sbr_idx * self.num_dists + dist_idx