Browse Source

feat: Add base tracer and rendering utilities for ray tracing

- Implemented `BaseTracer` class with configurable tracing parameters
- Created `RenderBuffer` dataclass for flexible rendering data management
- Added `SphereTracer` with advanced ray marching and surface tracing methods
- Introduced utility functions for parameter setting and gradient computation
- Supported various tensor operations and transformations in render buffer
NH-Rep
mckay 3 months ago
parent
commit
7efcfaae0c
  1. 58
      code/tracer/base_tracer.py
  2. 152
      code/tracer/render_buffer.py
  3. 246
      code/tracer/sphere_tracer.py

58
code/tracer/base_tracer.py

@ -0,0 +1,58 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy as np
# This is the bridge between an argparse based approach and a non-argparse one
def setparam(args, param, paramstr):
argsparam = getattr(args, paramstr, None)
if param is not None or argsparam is None:
return param
else:
return argsparam
class BaseTracer(object):
"""Virtual base class for tracer"""
def __init__(self,
args = None,
camera_clamp : list = None,
step_size : float = None,
grad_method : str = None,
num_steps : int = None, # samples for raymaching, iterations for sphere trace
min_dis : float = None):
self.args = args
self.camera_clamp = setparam(args, camera_clamp, 'camera_clamp')
self.step_size = setparam(args, step_size, 'step_size')
self.grad_method = setparam(args, grad_method, 'grad_method')
self.num_steps = setparam(args, num_steps, 'num_steps')
self.min_dis = setparam(args, min_dis, 'min_dis')
self.inv_num_steps = 1.0 / self.num_steps
self.diagonal = np.sqrt(3) * 2.0
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, net, ray_o, ray_d):
"""Base implementation for forward"""
raise NotImplementedError

152
code/tracer/render_buffer.py

@ -0,0 +1,152 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from dataclasses import asdict, astuple, fields, dataclass
from typing import Union, List
import torch
import numpy as np
@dataclass
class RenderBuffer:
x : Union[torch.Tensor, None] = None
min_x : Union[torch.Tensor, None] = None
hit : Union[torch.Tensor, None] = None
depth : Union[torch.Tensor, None] = None
relative_depth : Union[torch.Tensor, None] = None
normal : Union[torch.Tensor, None] = None
rgb : Union[torch.Tensor, None] = None
shadow : Union[torch.Tensor, None] = None
ao : Union[torch.Tensor, None] = None
view : Union[torch.Tensor, None] = None
err : Union[torch.Tensor, None] = None
albedo : Union[torch.Tensor, None] = None
def __iter__(self):
return iter(astuple(self))
def __add__(self, other):
def _proc(pair):
if None not in pair:
return torch.cat(pair)
elif pair[0] is not None and pair[1] is None:
return pair[0]
elif pair[0] is None and pair[1] is not None:
return pair[1]
else:
return None
return self.__class__(*(map(_proc, zip(self, other))))
def _apply(self, fn):
data = {}
for f in fields(self):
attr = getattr(self, f.name)
data[f.name] = None if attr is None else fn(attr)
return self.__class__(**data)
def cuda(self):
fn = lambda x : x.cuda()
return self._apply(fn)
def cpu(self):
fn = lambda x : x.cpu()
return self._apply(fn)
def detach(self):
fn = lambda x : x.detach()
return self._apply(fn)
def byte(self):
fn = lambda x : x.byte()
return self._apply(fn)
def reshape(self, *dims : List[int]):
fn = lambda x : x.reshape(*dims)
return self._apply(fn)
def transpose(self):
fn = lambda x : x.permute(1,0,2)
return self._apply(fn)
def numpy(self):
fn = lambda x : x.numpy()
return self._apply(fn)
def float(self):
fn = lambda x : x.float()
return self._apply(fn)
def exrdict(self):
_dict = asdict(self)
_dict = {k:v for k,v in _dict.items() if v is not None}
if 'rgb' in _dict:
_dict['default'] = _dict['rgb']
del _dict['rgb']
return _dict
def image(self):
# Unfinished
norm = lambda arr : ((arr + 1.0) / 2.0) if arr is not None else None
bwrgb = lambda arr : torch.cat([arr]*3, dim=-1) if arr is not None else None
rgb8 = lambda arr : (arr * 255.0) if arr is not None else None
#x = rgb8(norm(self.x))
#min_x = rgb8(norm(self.min_x))
hit = rgb8(bwrgb(self.hit))
depth = rgb8(bwrgb(self.relative_depth))
normal = rgb8(norm(self.normal))
#ao = 1.0 - rgb8(bwrgb(self.ao))
rgb = rgb8(self.rgb)
#return self.__class__(x=x, min_x=min_x, hit=hit, depth=depth, normal=normal, ao=ao, rgb=rgb)
return self.__class__(hit=hit, normal=normal, rgb=rgb, depth=depth)
@staticmethod
def mean(*rblst):
rb = RenderBuffer()
n = len(rblst)
for _rb in rblst:
for f in fields(_rb):
attr = getattr(_rb, f.name)
dest = getattr(rb, f.name)
if attr is not None and dest is not None:
setattr(rb, f.name, dest + attr)
elif dest is None:
setattr(rb, f.name, attr)
for f in fields(rb):
dest = getattr(rb, f.name)
if dest is not None:
setattr(rb, f.name, dest / float(n))
return rb
if __name__ == '__main__':
hw = 1024
hit = torch.zeros(hw, hw, 1).bool()
rgb = torch.zeros(hw, hw, 3).float()
rb0 = RenderBuffer(hit=hit.clone(), rgb=rgb.clone())
rb1 = RenderBuffer(hit=hit, rgb=rgb)
rb1.rgb += 1.0
avg = RenderBuffer.mean(rb0, rb1)
import pdb; pdb.set_trace()

246
code/tracer/sphere_tracer.py

@ -0,0 +1,246 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from lib.utils import PerfTimer
from lib.diffutils import gradient
from lib.geoutils import sample_unif_sphere
from lib.tracer.RenderBuffer import RenderBuffer
from lib.PsDebugger import PsDebugger
from base_tracer import BaseTracer
from sol_nglod import aabb
class SphereTracer(BaseTracer):
def forward(self, net, ray_o, ray_d):
"""Native implementation of sphere tracing."""
timer = PerfTimer(activate=False)
nettimer = PerfTimer(activate=False)
# Distanace from ray origin
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)
# Position in model space
x = torch.addcmul(ray_o, ray_d, t)
cond = torch.ones_like(t).bool()[:,0]
x, t, cond = aabb(ray_o, ray_d)
normal = torch.zeros_like(x)
# This function is in fact differentiable, but we treat it as if it's not, because
# it evaluates a very long chain of recursive neural networks (essentially a NN with depth of
# ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit
# locations, where additional quantities (normal, depth, segmentation) can be determined. The
# gradients will propagate only to these locations.
with torch.no_grad():
d = net(x)
dprev = d.clone()
# If cond is TRUE, then the corresponding ray has not hit yet.
# OR, the corresponding ray has exit the clipping plane.
#cond = torch.ones_like(d).bool()[:,0]
# If miss is TRUE, then the corresponding ray has missed entirely.
hit = torch.zeros_like(d).byte()
for i in range(self.num_steps):
timer.check("start")
# 1. Check if ray hits.
#hit = (torch.abs(d) < self._MIN_DIS)[:,0]
# 2. Check that the sphere tracing is not oscillating
#hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0]
# 3. Check that the ray has not exit the far clipping plane.
#cond = (torch.abs(t) < self.clamp[1])[:,0]
hit = (torch.abs(t) < self.camera_clamp[1])[:,0]
# 1. not hit surface
cond = cond & (torch.abs(d) > self.min_dis)[:,0]
# 2. not oscillating
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]
# 3. not a hit
cond = cond & hit
#cond = cond & ~hit
# If the sum is 0, that means that all rays have hit, or missed.
if not cond.any():
break
# Advance the x, by updating with a new t
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)
# Store the previous distance
dprev = torch.where(cond.unsqueeze(1), d, dprev)
nettimer.check("nstart")
# Update the distance to surface at x
d[cond] = net(x[cond]) * self.step_size
nettimer.check("nend")
# Update the distance from origin
t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
timer.check("end")
# AABB cull
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)
# The function will return
# x: the final model-space coordinate of the render
# t: the final distance from origin
# d: the final distance value from
# miss: a vector containing bools of whether each ray was a hit or miss
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)
grad = gradient(x[hit], net, method=self.grad_method)
_normal = F.normalize(grad, p=2, dim=-1, eps=1e-5)
normal[hit] = _normal
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal)
def get_min(self, net, ray_o, ray_d):
timer = PerfTimer(activate=False)
nettimer = PerfTimer(activate=False)
# Distance from ray origin
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)
# Position in model space
x = torch.addcmul(ray_o, ray_d, t)
x, t, hit = aabb(ray_o, ray_d);
normal = torch.zeros_like(x)
with torch.no_grad():
d = net(x)
dprev = d.clone()
mind = d.clone()
minx = x.clone()
# If cond is TRUE, then the corresponding ray has not hit yet.
# OR, the corresponding ray has exit the clipping plane.
cond = torch.ones_like(d).bool()[:,0]
# If miss is TRUE, then the corresponding ray has missed entirely.
hit = torch.zeros_like(d).byte()
for i in range(self.num_steps):
timer.check("start")
hit = (torch.abs(t) < self.camera_clamp[1])[:,0]
# 1. not hit surface
cond = (torch.abs(d) > self.min_dis)[:,0]
# 2. not oscillating
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]
# 3. not a hit
cond = cond & hit
#cond = cond & ~hit
# If the sum is 0, that means that all rays have hit, or missed.
if not cond.any():
break
# Advance the x, by updating with a new t
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)
new_mins = (d<mind)[...,0]
mind[new_mins] = d[new_mins]
minx[new_mins] = x[new_mins]
# Store the previous distance
dprev = torch.where(cond.unsqueeze(1), d, dprev)
nettimer.check("nstart")
# Update the distance to surface at x
d[cond] = net(x[cond]) * self.step_size
nettimer.check("nend")
# Update the distance from origin
t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
timer.check("end")
# AABB cull
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)
#hit = torch.ones_like(d).byte()[...,0]
# The function will return
# x: the final model-space coordinate of the render
# t: the final distance from origin
# d: the final distance value from
# miss: a vector containing bools of whether each ray was a hit or miss
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)
_normal = gradient(x[hit], net, method=self.grad_method)
normal[hit] = _normal
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, minx=minx)
def sample_surface(self, n, net):
# Sample surface using random tracing (resample until num_samples is reached)
timer = PerfTimer(activate=True)
with torch.no_grad():
i = 0
while i < 1000:
ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0
# this really should just return a torch array in the first place
ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device)
rb = self.forward(net, ray_o, ray_d)
#d = torch.abs(net(rb.x)[..., 0])
#idx = torch.where(d < 0.0003)
#pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0)
pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0)
if pts_pr.shape[0] >= n:
break
i += 1
if i == 50:
print('Taking an unusually long time to sample desired # of points.')
return pts_pr
Loading…
Cancel
Save