You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
9.2 KiB
246 lines
9.2 KiB
# The MIT License (MIT)
|
|
#
|
|
# Copyright (c) 2021, NVIDIA CORPORATION.
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
# this software and associated documentation files (the "Software"), to deal in
|
|
# the Software without restriction, including without limitation the rights to
|
|
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
# subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
from lib.utils import PerfTimer
|
|
from lib.diffutils import gradient
|
|
from lib.geoutils import sample_unif_sphere
|
|
from lib.tracer.RenderBuffer import RenderBuffer
|
|
|
|
from lib.PsDebugger import PsDebugger
|
|
|
|
from base_tracer import BaseTracer
|
|
|
|
from sol_nglod import aabb
|
|
|
|
class SphereTracer(BaseTracer):
|
|
|
|
def forward(self, net, ray_o, ray_d):
|
|
"""Native implementation of sphere tracing."""
|
|
timer = PerfTimer(activate=False)
|
|
nettimer = PerfTimer(activate=False)
|
|
|
|
# Distanace from ray origin
|
|
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)
|
|
|
|
# Position in model space
|
|
x = torch.addcmul(ray_o, ray_d, t)
|
|
|
|
cond = torch.ones_like(t).bool()[:,0]
|
|
x, t, cond = aabb(ray_o, ray_d)
|
|
|
|
normal = torch.zeros_like(x)
|
|
# This function is in fact differentiable, but we treat it as if it's not, because
|
|
# it evaluates a very long chain of recursive neural networks (essentially a NN with depth of
|
|
# ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit
|
|
# locations, where additional quantities (normal, depth, segmentation) can be determined. The
|
|
# gradients will propagate only to these locations.
|
|
with torch.no_grad():
|
|
|
|
d = net(x)
|
|
|
|
dprev = d.clone()
|
|
|
|
# If cond is TRUE, then the corresponding ray has not hit yet.
|
|
# OR, the corresponding ray has exit the clipping plane.
|
|
#cond = torch.ones_like(d).bool()[:,0]
|
|
|
|
# If miss is TRUE, then the corresponding ray has missed entirely.
|
|
hit = torch.zeros_like(d).byte()
|
|
|
|
for i in range(self.num_steps):
|
|
timer.check("start")
|
|
# 1. Check if ray hits.
|
|
#hit = (torch.abs(d) < self._MIN_DIS)[:,0]
|
|
# 2. Check that the sphere tracing is not oscillating
|
|
#hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0]
|
|
|
|
# 3. Check that the ray has not exit the far clipping plane.
|
|
#cond = (torch.abs(t) < self.clamp[1])[:,0]
|
|
|
|
hit = (torch.abs(t) < self.camera_clamp[1])[:,0]
|
|
|
|
# 1. not hit surface
|
|
cond = cond & (torch.abs(d) > self.min_dis)[:,0]
|
|
|
|
# 2. not oscillating
|
|
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]
|
|
|
|
# 3. not a hit
|
|
cond = cond & hit
|
|
|
|
#cond = cond & ~hit
|
|
|
|
# If the sum is 0, that means that all rays have hit, or missed.
|
|
if not cond.any():
|
|
break
|
|
|
|
# Advance the x, by updating with a new t
|
|
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)
|
|
|
|
# Store the previous distance
|
|
dprev = torch.where(cond.unsqueeze(1), d, dprev)
|
|
|
|
nettimer.check("nstart")
|
|
# Update the distance to surface at x
|
|
d[cond] = net(x[cond]) * self.step_size
|
|
|
|
nettimer.check("nend")
|
|
|
|
# Update the distance from origin
|
|
t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
|
|
timer.check("end")
|
|
|
|
# AABB cull
|
|
|
|
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)
|
|
|
|
# The function will return
|
|
# x: the final model-space coordinate of the render
|
|
# t: the final distance from origin
|
|
# d: the final distance value from
|
|
# miss: a vector containing bools of whether each ray was a hit or miss
|
|
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)
|
|
|
|
grad = gradient(x[hit], net, method=self.grad_method)
|
|
_normal = F.normalize(grad, p=2, dim=-1, eps=1e-5)
|
|
normal[hit] = _normal
|
|
|
|
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal)
|
|
|
|
def get_min(self, net, ray_o, ray_d):
|
|
|
|
timer = PerfTimer(activate=False)
|
|
nettimer = PerfTimer(activate=False)
|
|
|
|
# Distance from ray origin
|
|
t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)
|
|
|
|
# Position in model space
|
|
x = torch.addcmul(ray_o, ray_d, t)
|
|
|
|
x, t, hit = aabb(ray_o, ray_d);
|
|
|
|
normal = torch.zeros_like(x)
|
|
|
|
with torch.no_grad():
|
|
d = net(x)
|
|
dprev = d.clone()
|
|
mind = d.clone()
|
|
minx = x.clone()
|
|
|
|
# If cond is TRUE, then the corresponding ray has not hit yet.
|
|
# OR, the corresponding ray has exit the clipping plane.
|
|
cond = torch.ones_like(d).bool()[:,0]
|
|
|
|
# If miss is TRUE, then the corresponding ray has missed entirely.
|
|
hit = torch.zeros_like(d).byte()
|
|
|
|
for i in range(self.num_steps):
|
|
timer.check("start")
|
|
|
|
hit = (torch.abs(t) < self.camera_clamp[1])[:,0]
|
|
|
|
# 1. not hit surface
|
|
cond = (torch.abs(d) > self.min_dis)[:,0]
|
|
|
|
# 2. not oscillating
|
|
cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]
|
|
|
|
# 3. not a hit
|
|
cond = cond & hit
|
|
|
|
|
|
#cond = cond & ~hit
|
|
|
|
# If the sum is 0, that means that all rays have hit, or missed.
|
|
if not cond.any():
|
|
break
|
|
|
|
# Advance the x, by updating with a new t
|
|
x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)
|
|
|
|
new_mins = (d<mind)[...,0]
|
|
mind[new_mins] = d[new_mins]
|
|
minx[new_mins] = x[new_mins]
|
|
|
|
# Store the previous distance
|
|
dprev = torch.where(cond.unsqueeze(1), d, dprev)
|
|
|
|
nettimer.check("nstart")
|
|
# Update the distance to surface at x
|
|
d[cond] = net(x[cond]) * self.step_size
|
|
|
|
nettimer.check("nend")
|
|
|
|
# Update the distance from origin
|
|
t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
|
|
timer.check("end")
|
|
|
|
# AABB cull
|
|
|
|
hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)
|
|
#hit = torch.ones_like(d).byte()[...,0]
|
|
|
|
# The function will return
|
|
# x: the final model-space coordinate of the render
|
|
# t: the final distance from origin
|
|
# d: the final distance value from
|
|
# miss: a vector containing bools of whether each ray was a hit or miss
|
|
#_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)
|
|
_normal = gradient(x[hit], net, method=self.grad_method)
|
|
normal[hit] = _normal
|
|
|
|
|
|
return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, minx=minx)
|
|
|
|
def sample_surface(self, n, net):
|
|
|
|
# Sample surface using random tracing (resample until num_samples is reached)
|
|
|
|
timer = PerfTimer(activate=True)
|
|
|
|
with torch.no_grad():
|
|
i = 0
|
|
while i < 1000:
|
|
ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0
|
|
# this really should just return a torch array in the first place
|
|
ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device)
|
|
rb = self.forward(net, ray_o, ray_d)
|
|
|
|
#d = torch.abs(net(rb.x)[..., 0])
|
|
#idx = torch.where(d < 0.0003)
|
|
#pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0)
|
|
|
|
pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0)
|
|
if pts_pr.shape[0] >= n:
|
|
break
|
|
i += 1
|
|
if i == 50:
|
|
print('Taking an unusually long time to sample desired # of points.')
|
|
|
|
return pts_pr
|
|
|
|
|