Browse Source
- Implemented `BaseTracer` class with configurable tracing parameters - Created `RenderBuffer` dataclass for flexible rendering data management - Added `SphereTracer` with advanced ray marching and surface tracing methods - Introduced utility functions for parameter setting and gradient computation - Supported various tensor operations and transformations in render bufferNH-Rep
3 changed files with 456 additions and 0 deletions
@ -0,0 +1,58 @@ |
|||||
|
# The MIT License (MIT) |
||||
|
# |
||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. |
||||
|
# |
||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
|
# this software and associated documentation files (the "Software"), to deal in |
||||
|
# the Software without restriction, including without limitation the rights to |
||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
|
# the Software, and to permit persons to whom the Software is furnished to do so, |
||||
|
# subject to the following conditions: |
||||
|
# |
||||
|
# The above copyright notice and this permission notice shall be included in all |
||||
|
# copies or substantial portions of the Software. |
||||
|
# |
||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
||||
|
|
||||
|
import numpy as np |
||||
|
|
||||
|
# This is the bridge between an argparse based approach and a non-argparse one |
||||
|
def setparam(args, param, paramstr): |
||||
|
argsparam = getattr(args, paramstr, None) |
||||
|
if param is not None or argsparam is None: |
||||
|
return param |
||||
|
else: |
||||
|
return argsparam |
||||
|
|
||||
|
class BaseTracer(object): |
||||
|
"""Virtual base class for tracer""" |
||||
|
|
||||
|
def __init__(self, |
||||
|
args = None, |
||||
|
camera_clamp : list = None, |
||||
|
step_size : float = None, |
||||
|
grad_method : str = None, |
||||
|
num_steps : int = None, # samples for raymaching, iterations for sphere trace |
||||
|
min_dis : float = None): |
||||
|
|
||||
|
self.args = args |
||||
|
self.camera_clamp = setparam(args, camera_clamp, 'camera_clamp') |
||||
|
self.step_size = setparam(args, step_size, 'step_size') |
||||
|
self.grad_method = setparam(args, grad_method, 'grad_method') |
||||
|
self.num_steps = setparam(args, num_steps, 'num_steps') |
||||
|
self.min_dis = setparam(args, min_dis, 'min_dis') |
||||
|
|
||||
|
self.inv_num_steps = 1.0 / self.num_steps |
||||
|
self.diagonal = np.sqrt(3) * 2.0 |
||||
|
|
||||
|
def __call__(self, *args, **kwargs): |
||||
|
return self.forward(*args, **kwargs) |
||||
|
|
||||
|
def forward(self, net, ray_o, ray_d): |
||||
|
"""Base implementation for forward""" |
||||
|
raise NotImplementedError |
@ -0,0 +1,152 @@ |
|||||
|
# The MIT License (MIT) |
||||
|
# |
||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. |
||||
|
# |
||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
|
# this software and associated documentation files (the "Software"), to deal in |
||||
|
# the Software without restriction, including without limitation the rights to |
||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
|
# the Software, and to permit persons to whom the Software is furnished to do so, |
||||
|
# subject to the following conditions: |
||||
|
# |
||||
|
# The above copyright notice and this permission notice shall be included in all |
||||
|
# copies or substantial portions of the Software. |
||||
|
# |
||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
||||
|
|
||||
|
from dataclasses import asdict, astuple, fields, dataclass |
||||
|
from typing import Union, List |
||||
|
|
||||
|
import torch |
||||
|
import numpy as np |
||||
|
|
||||
|
@dataclass |
||||
|
class RenderBuffer: |
||||
|
x : Union[torch.Tensor, None] = None |
||||
|
min_x : Union[torch.Tensor, None] = None |
||||
|
hit : Union[torch.Tensor, None] = None |
||||
|
depth : Union[torch.Tensor, None] = None |
||||
|
relative_depth : Union[torch.Tensor, None] = None |
||||
|
normal : Union[torch.Tensor, None] = None |
||||
|
rgb : Union[torch.Tensor, None] = None |
||||
|
shadow : Union[torch.Tensor, None] = None |
||||
|
ao : Union[torch.Tensor, None] = None |
||||
|
view : Union[torch.Tensor, None] = None |
||||
|
err : Union[torch.Tensor, None] = None |
||||
|
albedo : Union[torch.Tensor, None] = None |
||||
|
|
||||
|
def __iter__(self): |
||||
|
return iter(astuple(self)) |
||||
|
|
||||
|
def __add__(self, other): |
||||
|
def _proc(pair): |
||||
|
if None not in pair: |
||||
|
return torch.cat(pair) |
||||
|
elif pair[0] is not None and pair[1] is None: |
||||
|
return pair[0] |
||||
|
elif pair[0] is None and pair[1] is not None: |
||||
|
return pair[1] |
||||
|
else: |
||||
|
return None |
||||
|
|
||||
|
return self.__class__(*(map(_proc, zip(self, other)))) |
||||
|
|
||||
|
def _apply(self, fn): |
||||
|
data = {} |
||||
|
for f in fields(self): |
||||
|
attr = getattr(self, f.name) |
||||
|
data[f.name] = None if attr is None else fn(attr) |
||||
|
return self.__class__(**data) |
||||
|
|
||||
|
def cuda(self): |
||||
|
fn = lambda x : x.cuda() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def cpu(self): |
||||
|
fn = lambda x : x.cpu() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def detach(self): |
||||
|
fn = lambda x : x.detach() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def byte(self): |
||||
|
fn = lambda x : x.byte() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def reshape(self, *dims : List[int]): |
||||
|
fn = lambda x : x.reshape(*dims) |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def transpose(self): |
||||
|
fn = lambda x : x.permute(1,0,2) |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def numpy(self): |
||||
|
fn = lambda x : x.numpy() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def float(self): |
||||
|
fn = lambda x : x.float() |
||||
|
return self._apply(fn) |
||||
|
|
||||
|
def exrdict(self): |
||||
|
_dict = asdict(self) |
||||
|
_dict = {k:v for k,v in _dict.items() if v is not None} |
||||
|
if 'rgb' in _dict: |
||||
|
_dict['default'] = _dict['rgb'] |
||||
|
del _dict['rgb'] |
||||
|
return _dict |
||||
|
|
||||
|
def image(self): |
||||
|
# Unfinished |
||||
|
norm = lambda arr : ((arr + 1.0) / 2.0) if arr is not None else None |
||||
|
bwrgb = lambda arr : torch.cat([arr]*3, dim=-1) if arr is not None else None |
||||
|
rgb8 = lambda arr : (arr * 255.0) if arr is not None else None |
||||
|
|
||||
|
#x = rgb8(norm(self.x)) |
||||
|
#min_x = rgb8(norm(self.min_x)) |
||||
|
hit = rgb8(bwrgb(self.hit)) |
||||
|
depth = rgb8(bwrgb(self.relative_depth)) |
||||
|
normal = rgb8(norm(self.normal)) |
||||
|
#ao = 1.0 - rgb8(bwrgb(self.ao)) |
||||
|
rgb = rgb8(self.rgb) |
||||
|
#return self.__class__(x=x, min_x=min_x, hit=hit, depth=depth, normal=normal, ao=ao, rgb=rgb) |
||||
|
return self.__class__(hit=hit, normal=normal, rgb=rgb, depth=depth) |
||||
|
|
||||
|
@staticmethod |
||||
|
def mean(*rblst): |
||||
|
rb = RenderBuffer() |
||||
|
n = len(rblst) |
||||
|
for _rb in rblst: |
||||
|
for f in fields(_rb): |
||||
|
attr = getattr(_rb, f.name) |
||||
|
dest = getattr(rb, f.name) |
||||
|
if attr is not None and dest is not None: |
||||
|
setattr(rb, f.name, dest + attr) |
||||
|
elif dest is None: |
||||
|
setattr(rb, f.name, attr) |
||||
|
|
||||
|
for f in fields(rb): |
||||
|
dest = getattr(rb, f.name) |
||||
|
if dest is not None: |
||||
|
setattr(rb, f.name, dest / float(n)) |
||||
|
return rb |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
hw = 1024 |
||||
|
hit = torch.zeros(hw, hw, 1).bool() |
||||
|
rgb = torch.zeros(hw, hw, 3).float() |
||||
|
|
||||
|
rb0 = RenderBuffer(hit=hit.clone(), rgb=rgb.clone()) |
||||
|
rb1 = RenderBuffer(hit=hit, rgb=rgb) |
||||
|
rb1.rgb += 1.0 |
||||
|
|
||||
|
avg = RenderBuffer.mean(rb0, rb1) |
||||
|
|
||||
|
import pdb; pdb.set_trace() |
@ -0,0 +1,246 @@ |
|||||
|
# The MIT License (MIT) |
||||
|
# |
||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. |
||||
|
# |
||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of |
||||
|
# this software and associated documentation files (the "Software"), to deal in |
||||
|
# the Software without restriction, including without limitation the rights to |
||||
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
||||
|
# the Software, and to permit persons to whom the Software is furnished to do so, |
||||
|
# subject to the following conditions: |
||||
|
# |
||||
|
# The above copyright notice and this permission notice shall be included in all |
||||
|
# copies or substantial portions of the Software. |
||||
|
# |
||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
||||
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
||||
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
||||
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
||||
|
|
||||
|
|
||||
|
import torch |
||||
|
import torch.nn.functional as F |
||||
|
import torch.nn as nn |
||||
|
import numpy as np |
||||
|
|
||||
|
from lib.utils import PerfTimer |
||||
|
from lib.diffutils import gradient |
||||
|
from lib.geoutils import sample_unif_sphere |
||||
|
from lib.tracer.RenderBuffer import RenderBuffer |
||||
|
|
||||
|
from lib.PsDebugger import PsDebugger |
||||
|
|
||||
|
from base_tracer import BaseTracer |
||||
|
|
||||
|
from sol_nglod import aabb |
||||
|
|
||||
|
class SphereTracer(BaseTracer): |
||||
|
|
||||
|
def forward(self, net, ray_o, ray_d): |
||||
|
"""Native implementation of sphere tracing.""" |
||||
|
timer = PerfTimer(activate=False) |
||||
|
nettimer = PerfTimer(activate=False) |
||||
|
|
||||
|
# Distanace from ray origin |
||||
|
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) |
||||
|
|
||||
|
# Position in model space |
||||
|
x = torch.addcmul(ray_o, ray_d, t) |
||||
|
|
||||
|
cond = torch.ones_like(t).bool()[:,0] |
||||
|
x, t, cond = aabb(ray_o, ray_d) |
||||
|
|
||||
|
normal = torch.zeros_like(x) |
||||
|
# This function is in fact differentiable, but we treat it as if it's not, because |
||||
|
# it evaluates a very long chain of recursive neural networks (essentially a NN with depth of |
||||
|
# ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit |
||||
|
# locations, where additional quantities (normal, depth, segmentation) can be determined. The |
||||
|
# gradients will propagate only to these locations. |
||||
|
with torch.no_grad(): |
||||
|
|
||||
|
d = net(x) |
||||
|
|
||||
|
dprev = d.clone() |
||||
|
|
||||
|
# If cond is TRUE, then the corresponding ray has not hit yet. |
||||
|
# OR, the corresponding ray has exit the clipping plane. |
||||
|
#cond = torch.ones_like(d).bool()[:,0] |
||||
|
|
||||
|
# If miss is TRUE, then the corresponding ray has missed entirely. |
||||
|
hit = torch.zeros_like(d).byte() |
||||
|
|
||||
|
for i in range(self.num_steps): |
||||
|
timer.check("start") |
||||
|
# 1. Check if ray hits. |
||||
|
#hit = (torch.abs(d) < self._MIN_DIS)[:,0] |
||||
|
# 2. Check that the sphere tracing is not oscillating |
||||
|
#hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0] |
||||
|
|
||||
|
# 3. Check that the ray has not exit the far clipping plane. |
||||
|
#cond = (torch.abs(t) < self.clamp[1])[:,0] |
||||
|
|
||||
|
hit = (torch.abs(t) < self.camera_clamp[1])[:,0] |
||||
|
|
||||
|
# 1. not hit surface |
||||
|
cond = cond & (torch.abs(d) > self.min_dis)[:,0] |
||||
|
|
||||
|
# 2. not oscillating |
||||
|
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] |
||||
|
|
||||
|
# 3. not a hit |
||||
|
cond = cond & hit |
||||
|
|
||||
|
#cond = cond & ~hit |
||||
|
|
||||
|
# If the sum is 0, that means that all rays have hit, or missed. |
||||
|
if not cond.any(): |
||||
|
break |
||||
|
|
||||
|
# Advance the x, by updating with a new t |
||||
|
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) |
||||
|
|
||||
|
# Store the previous distance |
||||
|
dprev = torch.where(cond.unsqueeze(1), d, dprev) |
||||
|
|
||||
|
nettimer.check("nstart") |
||||
|
# Update the distance to surface at x |
||||
|
d[cond] = net(x[cond]) * self.step_size |
||||
|
|
||||
|
nettimer.check("nend") |
||||
|
|
||||
|
# Update the distance from origin |
||||
|
t = torch.where(cond.view(cond.shape[0], 1), t+d, t) |
||||
|
timer.check("end") |
||||
|
|
||||
|
# AABB cull |
||||
|
|
||||
|
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1) |
||||
|
|
||||
|
# The function will return |
||||
|
# x: the final model-space coordinate of the render |
||||
|
# t: the final distance from origin |
||||
|
# d: the final distance value from |
||||
|
# miss: a vector containing bools of whether each ray was a hit or miss |
||||
|
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) |
||||
|
|
||||
|
grad = gradient(x[hit], net, method=self.grad_method) |
||||
|
_normal = F.normalize(grad, p=2, dim=-1, eps=1e-5) |
||||
|
normal[hit] = _normal |
||||
|
|
||||
|
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal) |
||||
|
|
||||
|
def get_min(self, net, ray_o, ray_d): |
||||
|
|
||||
|
timer = PerfTimer(activate=False) |
||||
|
nettimer = PerfTimer(activate=False) |
||||
|
|
||||
|
# Distance from ray origin |
||||
|
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) |
||||
|
|
||||
|
# Position in model space |
||||
|
x = torch.addcmul(ray_o, ray_d, t) |
||||
|
|
||||
|
x, t, hit = aabb(ray_o, ray_d); |
||||
|
|
||||
|
normal = torch.zeros_like(x) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
d = net(x) |
||||
|
dprev = d.clone() |
||||
|
mind = d.clone() |
||||
|
minx = x.clone() |
||||
|
|
||||
|
# If cond is TRUE, then the corresponding ray has not hit yet. |
||||
|
# OR, the corresponding ray has exit the clipping plane. |
||||
|
cond = torch.ones_like(d).bool()[:,0] |
||||
|
|
||||
|
# If miss is TRUE, then the corresponding ray has missed entirely. |
||||
|
hit = torch.zeros_like(d).byte() |
||||
|
|
||||
|
for i in range(self.num_steps): |
||||
|
timer.check("start") |
||||
|
|
||||
|
hit = (torch.abs(t) < self.camera_clamp[1])[:,0] |
||||
|
|
||||
|
# 1. not hit surface |
||||
|
cond = (torch.abs(d) > self.min_dis)[:,0] |
||||
|
|
||||
|
# 2. not oscillating |
||||
|
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] |
||||
|
|
||||
|
# 3. not a hit |
||||
|
cond = cond & hit |
||||
|
|
||||
|
|
||||
|
#cond = cond & ~hit |
||||
|
|
||||
|
# If the sum is 0, that means that all rays have hit, or missed. |
||||
|
if not cond.any(): |
||||
|
break |
||||
|
|
||||
|
# Advance the x, by updating with a new t |
||||
|
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) |
||||
|
|
||||
|
new_mins = (d<mind)[...,0] |
||||
|
mind[new_mins] = d[new_mins] |
||||
|
minx[new_mins] = x[new_mins] |
||||
|
|
||||
|
# Store the previous distance |
||||
|
dprev = torch.where(cond.unsqueeze(1), d, dprev) |
||||
|
|
||||
|
nettimer.check("nstart") |
||||
|
# Update the distance to surface at x |
||||
|
d[cond] = net(x[cond]) * self.step_size |
||||
|
|
||||
|
nettimer.check("nend") |
||||
|
|
||||
|
# Update the distance from origin |
||||
|
t = torch.where(cond.view(cond.shape[0], 1), t+d, t) |
||||
|
timer.check("end") |
||||
|
|
||||
|
# AABB cull |
||||
|
|
||||
|
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1) |
||||
|
#hit = torch.ones_like(d).byte()[...,0] |
||||
|
|
||||
|
# The function will return |
||||
|
# x: the final model-space coordinate of the render |
||||
|
# t: the final distance from origin |
||||
|
# d: the final distance value from |
||||
|
# miss: a vector containing bools of whether each ray was a hit or miss |
||||
|
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) |
||||
|
_normal = gradient(x[hit], net, method=self.grad_method) |
||||
|
normal[hit] = _normal |
||||
|
|
||||
|
|
||||
|
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, minx=minx) |
||||
|
|
||||
|
def sample_surface(self, n, net): |
||||
|
|
||||
|
# Sample surface using random tracing (resample until num_samples is reached) |
||||
|
|
||||
|
timer = PerfTimer(activate=True) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
i = 0 |
||||
|
while i < 1000: |
||||
|
ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0 |
||||
|
# this really should just return a torch array in the first place |
||||
|
ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device) |
||||
|
rb = self.forward(net, ray_o, ray_d) |
||||
|
|
||||
|
#d = torch.abs(net(rb.x)[..., 0]) |
||||
|
#idx = torch.where(d < 0.0003) |
||||
|
#pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0) |
||||
|
|
||||
|
pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0) |
||||
|
if pts_pr.shape[0] >= n: |
||||
|
break |
||||
|
i += 1 |
||||
|
if i == 50: |
||||
|
print('Taking an unusually long time to sample desired # of points.') |
||||
|
|
||||
|
return pts_pr |
||||
|
|
Loading…
Reference in new issue