From 7efcfaae0ca205a375c658eee8d9c65bd6a58810 Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 7 Mar 2025 19:07:55 +0800 Subject: [PATCH] feat: Add base tracer and rendering utilities for ray tracing - Implemented `BaseTracer` class with configurable tracing parameters - Created `RenderBuffer` dataclass for flexible rendering data management - Added `SphereTracer` with advanced ray marching and surface tracing methods - Introduced utility functions for parameter setting and gradient computation - Supported various tensor operations and transformations in render buffer --- code/tracer/base_tracer.py | 58 +++++++++ code/tracer/render_buffer.py | 152 ++++++++++++++++++++++ code/tracer/sphere_tracer.py | 246 +++++++++++++++++++++++++++++++++++ 3 files changed, 456 insertions(+) create mode 100644 code/tracer/base_tracer.py create mode 100644 code/tracer/render_buffer.py create mode 100644 code/tracer/sphere_tracer.py diff --git a/code/tracer/base_tracer.py b/code/tracer/base_tracer.py new file mode 100644 index 0000000..cbb73de --- /dev/null +++ b/code/tracer/base_tracer.py @@ -0,0 +1,58 @@ +# The MIT License (MIT) +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import numpy as np + +# This is the bridge between an argparse based approach and a non-argparse one +def setparam(args, param, paramstr): + argsparam = getattr(args, paramstr, None) + if param is not None or argsparam is None: + return param + else: + return argsparam + +class BaseTracer(object): + """Virtual base class for tracer""" + + def __init__(self, + args = None, + camera_clamp : list = None, + step_size : float = None, + grad_method : str = None, + num_steps : int = None, # samples for raymaching, iterations for sphere trace + min_dis : float = None): + + self.args = args + self.camera_clamp = setparam(args, camera_clamp, 'camera_clamp') + self.step_size = setparam(args, step_size, 'step_size') + self.grad_method = setparam(args, grad_method, 'grad_method') + self.num_steps = setparam(args, num_steps, 'num_steps') + self.min_dis = setparam(args, min_dis, 'min_dis') + + self.inv_num_steps = 1.0 / self.num_steps + self.diagonal = np.sqrt(3) * 2.0 + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, net, ray_o, ray_d): + """Base implementation for forward""" + raise NotImplementedError diff --git a/code/tracer/render_buffer.py b/code/tracer/render_buffer.py new file mode 100644 index 0000000..3c1756f --- /dev/null +++ b/code/tracer/render_buffer.py @@ -0,0 +1,152 @@ +# The MIT License (MIT) +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from dataclasses import asdict, astuple, fields, dataclass +from typing import Union, List + +import torch +import numpy as np + +@dataclass +class RenderBuffer: + x : Union[torch.Tensor, None] = None + min_x : Union[torch.Tensor, None] = None + hit : Union[torch.Tensor, None] = None + depth : Union[torch.Tensor, None] = None + relative_depth : Union[torch.Tensor, None] = None + normal : Union[torch.Tensor, None] = None + rgb : Union[torch.Tensor, None] = None + shadow : Union[torch.Tensor, None] = None + ao : Union[torch.Tensor, None] = None + view : Union[torch.Tensor, None] = None + err : Union[torch.Tensor, None] = None + albedo : Union[torch.Tensor, None] = None + + def __iter__(self): + return iter(astuple(self)) + + def __add__(self, other): + def _proc(pair): + if None not in pair: + return torch.cat(pair) + elif pair[0] is not None and pair[1] is None: + return pair[0] + elif pair[0] is None and pair[1] is not None: + return pair[1] + else: + return None + + return self.__class__(*(map(_proc, zip(self, other)))) + + def _apply(self, fn): + data = {} + for f in fields(self): + attr = getattr(self, f.name) + data[f.name] = None if attr is None else fn(attr) + return self.__class__(**data) + + def cuda(self): + fn = lambda x : x.cuda() + return self._apply(fn) + + def cpu(self): + fn = lambda x : x.cpu() + return self._apply(fn) + + def detach(self): + fn = lambda x : x.detach() + return self._apply(fn) + + def byte(self): + fn = lambda x : x.byte() + return self._apply(fn) + + def reshape(self, *dims : List[int]): + fn = lambda x : x.reshape(*dims) + return self._apply(fn) + + def transpose(self): + fn = lambda x : x.permute(1,0,2) + return self._apply(fn) + + def numpy(self): + fn = lambda x : x.numpy() + return self._apply(fn) + + def float(self): + fn = lambda x : x.float() + return self._apply(fn) + + def exrdict(self): + _dict = asdict(self) + _dict = {k:v for k,v in _dict.items() if v is not None} + if 'rgb' in _dict: + _dict['default'] = _dict['rgb'] + del _dict['rgb'] + return _dict + + def image(self): + # Unfinished + norm = lambda arr : ((arr + 1.0) / 2.0) if arr is not None else None + bwrgb = lambda arr : torch.cat([arr]*3, dim=-1) if arr is not None else None + rgb8 = lambda arr : (arr * 255.0) if arr is not None else None + + #x = rgb8(norm(self.x)) + #min_x = rgb8(norm(self.min_x)) + hit = rgb8(bwrgb(self.hit)) + depth = rgb8(bwrgb(self.relative_depth)) + normal = rgb8(norm(self.normal)) + #ao = 1.0 - rgb8(bwrgb(self.ao)) + rgb = rgb8(self.rgb) + #return self.__class__(x=x, min_x=min_x, hit=hit, depth=depth, normal=normal, ao=ao, rgb=rgb) + return self.__class__(hit=hit, normal=normal, rgb=rgb, depth=depth) + + @staticmethod + def mean(*rblst): + rb = RenderBuffer() + n = len(rblst) + for _rb in rblst: + for f in fields(_rb): + attr = getattr(_rb, f.name) + dest = getattr(rb, f.name) + if attr is not None and dest is not None: + setattr(rb, f.name, dest + attr) + elif dest is None: + setattr(rb, f.name, attr) + + for f in fields(rb): + dest = getattr(rb, f.name) + if dest is not None: + setattr(rb, f.name, dest / float(n)) + return rb + +if __name__ == '__main__': + hw = 1024 + hit = torch.zeros(hw, hw, 1).bool() + rgb = torch.zeros(hw, hw, 3).float() + + rb0 = RenderBuffer(hit=hit.clone(), rgb=rgb.clone()) + rb1 = RenderBuffer(hit=hit, rgb=rgb) + rb1.rgb += 1.0 + + avg = RenderBuffer.mean(rb0, rb1) + + import pdb; pdb.set_trace() diff --git a/code/tracer/sphere_tracer.py b/code/tracer/sphere_tracer.py new file mode 100644 index 0000000..7e475cc --- /dev/null +++ b/code/tracer/sphere_tracer.py @@ -0,0 +1,246 @@ +# The MIT License (MIT) +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +import torch +import torch.nn.functional as F +import torch.nn as nn +import numpy as np + +from lib.utils import PerfTimer +from lib.diffutils import gradient +from lib.geoutils import sample_unif_sphere +from lib.tracer.RenderBuffer import RenderBuffer + +from lib.PsDebugger import PsDebugger + +from base_tracer import BaseTracer + +from sol_nglod import aabb + +class SphereTracer(BaseTracer): + + def forward(self, net, ray_o, ray_d): + """Native implementation of sphere tracing.""" + timer = PerfTimer(activate=False) + nettimer = PerfTimer(activate=False) + + # Distanace from ray origin + t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) + + # Position in model space + x = torch.addcmul(ray_o, ray_d, t) + + cond = torch.ones_like(t).bool()[:,0] + x, t, cond = aabb(ray_o, ray_d) + + normal = torch.zeros_like(x) + # This function is in fact differentiable, but we treat it as if it's not, because + # it evaluates a very long chain of recursive neural networks (essentially a NN with depth of + # ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit + # locations, where additional quantities (normal, depth, segmentation) can be determined. The + # gradients will propagate only to these locations. + with torch.no_grad(): + + d = net(x) + + dprev = d.clone() + + # If cond is TRUE, then the corresponding ray has not hit yet. + # OR, the corresponding ray has exit the clipping plane. + #cond = torch.ones_like(d).bool()[:,0] + + # If miss is TRUE, then the corresponding ray has missed entirely. + hit = torch.zeros_like(d).byte() + + for i in range(self.num_steps): + timer.check("start") + # 1. Check if ray hits. + #hit = (torch.abs(d) < self._MIN_DIS)[:,0] + # 2. Check that the sphere tracing is not oscillating + #hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0] + + # 3. Check that the ray has not exit the far clipping plane. + #cond = (torch.abs(t) < self.clamp[1])[:,0] + + hit = (torch.abs(t) < self.camera_clamp[1])[:,0] + + # 1. not hit surface + cond = cond & (torch.abs(d) > self.min_dis)[:,0] + + # 2. not oscillating + cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] + + # 3. not a hit + cond = cond & hit + + #cond = cond & ~hit + + # If the sum is 0, that means that all rays have hit, or missed. + if not cond.any(): + break + + # Advance the x, by updating with a new t + x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) + + # Store the previous distance + dprev = torch.where(cond.unsqueeze(1), d, dprev) + + nettimer.check("nstart") + # Update the distance to surface at x + d[cond] = net(x[cond]) * self.step_size + + nettimer.check("nend") + + # Update the distance from origin + t = torch.where(cond.view(cond.shape[0], 1), t+d, t) + timer.check("end") + + # AABB cull + + hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1) + + # The function will return + # x: the final model-space coordinate of the render + # t: the final distance from origin + # d: the final distance value from + # miss: a vector containing bools of whether each ray was a hit or miss + #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) + + grad = gradient(x[hit], net, method=self.grad_method) + _normal = F.normalize(grad, p=2, dim=-1, eps=1e-5) + normal[hit] = _normal + + return RenderBuffer(x=x, depth=t, hit=hit, normal=normal) + + def get_min(self, net, ray_o, ray_d): + + timer = PerfTimer(activate=False) + nettimer = PerfTimer(activate=False) + + # Distance from ray origin + t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) + + # Position in model space + x = torch.addcmul(ray_o, ray_d, t) + + x, t, hit = aabb(ray_o, ray_d); + + normal = torch.zeros_like(x) + + with torch.no_grad(): + d = net(x) + dprev = d.clone() + mind = d.clone() + minx = x.clone() + + # If cond is TRUE, then the corresponding ray has not hit yet. + # OR, the corresponding ray has exit the clipping plane. + cond = torch.ones_like(d).bool()[:,0] + + # If miss is TRUE, then the corresponding ray has missed entirely. + hit = torch.zeros_like(d).byte() + + for i in range(self.num_steps): + timer.check("start") + + hit = (torch.abs(t) < self.camera_clamp[1])[:,0] + + # 1. not hit surface + cond = (torch.abs(d) > self.min_dis)[:,0] + + # 2. not oscillating + cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] + + # 3. not a hit + cond = cond & hit + + + #cond = cond & ~hit + + # If the sum is 0, that means that all rays have hit, or missed. + if not cond.any(): + break + + # Advance the x, by updating with a new t + x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) + + new_mins = (d 1.0).any(dim=-1) + #hit = torch.ones_like(d).byte()[...,0] + + # The function will return + # x: the final model-space coordinate of the render + # t: the final distance from origin + # d: the final distance value from + # miss: a vector containing bools of whether each ray was a hit or miss + #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) + _normal = gradient(x[hit], net, method=self.grad_method) + normal[hit] = _normal + + + return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, minx=minx) + + def sample_surface(self, n, net): + + # Sample surface using random tracing (resample until num_samples is reached) + + timer = PerfTimer(activate=True) + + with torch.no_grad(): + i = 0 + while i < 1000: + ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0 + # this really should just return a torch array in the first place + ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device) + rb = self.forward(net, ray_o, ray_d) + + #d = torch.abs(net(rb.x)[..., 0]) + #idx = torch.where(d < 0.0003) + #pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0) + + pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0) + if pts_pr.shape[0] >= n: + break + i += 1 + if i == 50: + print('Taking an unusually long time to sample desired # of points.') + + return pts_pr +