Browse Source
			
			
			
			
				
		- Implemented `BaseTracer` class with configurable tracing parameters - Created `RenderBuffer` dataclass for flexible rendering data management - Added `SphereTracer` with advanced ray marching and surface tracing methods - Introduced utility functions for parameter setting and gradient computation - Supported various tensor operations and transformations in render bufferNH-Rep
				 3 changed files with 456 additions and 0 deletions
			
			
		@ -0,0 +1,58 @@ | 
				
			|||||
 | 
					# The MIT License (MIT) | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Copyright (c) 2021, NVIDIA CORPORATION. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy of | 
				
			||||
 | 
					# this software and associated documentation files (the "Software"), to deal in | 
				
			||||
 | 
					# the Software without restriction, including without limitation the rights to | 
				
			||||
 | 
					# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | 
				
			||||
 | 
					# the Software, and to permit persons to whom the Software is furnished to do so, | 
				
			||||
 | 
					# subject to the following conditions: | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# The above copyright notice and this permission notice shall be included in all | 
				
			||||
 | 
					# copies or substantial portions of the Software. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | 
				
			||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | 
				
			||||
 | 
					# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | 
				
			||||
 | 
					# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | 
				
			||||
 | 
					# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | 
				
			||||
 | 
					# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					# This is the bridge between an argparse based approach and a non-argparse one | 
				
			||||
 | 
					def setparam(args, param, paramstr): | 
				
			||||
 | 
					    argsparam = getattr(args, paramstr, None) | 
				
			||||
 | 
					    if param is not None or argsparam is None: | 
				
			||||
 | 
					        return param | 
				
			||||
 | 
					    else: | 
				
			||||
 | 
					        return argsparam | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class BaseTracer(object): | 
				
			||||
 | 
					    """Virtual base class for tracer""" | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def __init__(self, | 
				
			||||
 | 
					        args                 = None, | 
				
			||||
 | 
					        camera_clamp : list  = None, | 
				
			||||
 | 
					        step_size    : float = None, | 
				
			||||
 | 
					        grad_method  : str   = None, | 
				
			||||
 | 
					        num_steps    : int   = None, # samples for raymaching, iterations for sphere trace | 
				
			||||
 | 
					        min_dis      : float = None):  | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        self.args = args | 
				
			||||
 | 
					        self.camera_clamp = setparam(args, camera_clamp, 'camera_clamp') | 
				
			||||
 | 
					        self.step_size = setparam(args, step_size, 'step_size') | 
				
			||||
 | 
					        self.grad_method = setparam(args, grad_method, 'grad_method') | 
				
			||||
 | 
					        self.num_steps = setparam(args, num_steps, 'num_steps') | 
				
			||||
 | 
					        self.min_dis = setparam(args, min_dis, 'min_dis') | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        self.inv_num_steps = 1.0 / self.num_steps | 
				
			||||
 | 
					        self.diagonal = np.sqrt(3) * 2.0 | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def __call__(self, *args, **kwargs): | 
				
			||||
 | 
					        return self.forward(*args, **kwargs) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def forward(self, net, ray_o, ray_d): | 
				
			||||
 | 
					        """Base implementation for forward""" | 
				
			||||
 | 
					        raise NotImplementedError | 
				
			||||
@ -0,0 +1,152 @@ | 
				
			|||||
 | 
					# The MIT License (MIT) | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Copyright (c) 2021, NVIDIA CORPORATION. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy of | 
				
			||||
 | 
					# this software and associated documentation files (the "Software"), to deal in | 
				
			||||
 | 
					# the Software without restriction, including without limitation the rights to | 
				
			||||
 | 
					# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | 
				
			||||
 | 
					# the Software, and to permit persons to whom the Software is furnished to do so, | 
				
			||||
 | 
					# subject to the following conditions: | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# The above copyright notice and this permission notice shall be included in all | 
				
			||||
 | 
					# copies or substantial portions of the Software. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | 
				
			||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | 
				
			||||
 | 
					# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | 
				
			||||
 | 
					# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | 
				
			||||
 | 
					# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | 
				
			||||
 | 
					# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					from dataclasses import asdict, astuple, fields, dataclass | 
				
			||||
 | 
					from typing import Union, List | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					import torch | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					@dataclass | 
				
			||||
 | 
					class RenderBuffer: | 
				
			||||
 | 
					    x : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    min_x : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    hit : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    depth : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    relative_depth : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    normal : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    rgb : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    shadow : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    ao : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    view : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    err : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					    albedo : Union[torch.Tensor, None] = None | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def __iter__(self):  | 
				
			||||
 | 
					        return iter(astuple(self)) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def __add__(self, other): | 
				
			||||
 | 
					        def _proc(pair): | 
				
			||||
 | 
					            if None not in pair: | 
				
			||||
 | 
					                return torch.cat(pair) | 
				
			||||
 | 
					            elif pair[0] is not None and pair[1] is None: | 
				
			||||
 | 
					                return pair[0] | 
				
			||||
 | 
					            elif pair[0] is None and pair[1] is not None: | 
				
			||||
 | 
					                return pair[1] | 
				
			||||
 | 
					            else: | 
				
			||||
 | 
					                return None | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        return self.__class__(*(map(_proc, zip(self, other)))) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def _apply(self, fn): | 
				
			||||
 | 
					        data = {}  | 
				
			||||
 | 
					        for f in fields(self): | 
				
			||||
 | 
					            attr = getattr(self, f.name) | 
				
			||||
 | 
					            data[f.name] = None if attr is None else fn(attr) | 
				
			||||
 | 
					        return self.__class__(**data) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def cuda(self): | 
				
			||||
 | 
					        fn = lambda x : x.cuda() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					          | 
				
			||||
 | 
					    def cpu(self): | 
				
			||||
 | 
					        fn = lambda x : x.cpu() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def detach(self): | 
				
			||||
 | 
					        fn = lambda x : x.detach() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def byte(self): | 
				
			||||
 | 
					        fn = lambda x : x.byte() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def reshape(self, *dims : List[int]): | 
				
			||||
 | 
					        fn = lambda x : x.reshape(*dims) | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def transpose(self): | 
				
			||||
 | 
					        fn = lambda x : x.permute(1,0,2) | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def numpy(self): | 
				
			||||
 | 
					        fn = lambda x : x.numpy() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def float(self): | 
				
			||||
 | 
					        fn = lambda x : x.float() | 
				
			||||
 | 
					        return self._apply(fn) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def exrdict(self): | 
				
			||||
 | 
					        _dict = asdict(self) | 
				
			||||
 | 
					        _dict = {k:v for k,v in _dict.items() if v is not None} | 
				
			||||
 | 
					        if 'rgb' in _dict: | 
				
			||||
 | 
					            _dict['default'] = _dict['rgb'] | 
				
			||||
 | 
					            del _dict['rgb'] | 
				
			||||
 | 
					        return _dict | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def image(self): | 
				
			||||
 | 
					        # Unfinished | 
				
			||||
 | 
					        norm = lambda arr : ((arr + 1.0) / 2.0) if arr is not None else None | 
				
			||||
 | 
					        bwrgb = lambda arr : torch.cat([arr]*3, dim=-1) if arr is not None else None | 
				
			||||
 | 
					        rgb8 = lambda arr : (arr * 255.0) if arr is not None else None | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        #x = rgb8(norm(self.x)) | 
				
			||||
 | 
					        #min_x = rgb8(norm(self.min_x)) | 
				
			||||
 | 
					        hit = rgb8(bwrgb(self.hit)) | 
				
			||||
 | 
					        depth = rgb8(bwrgb(self.relative_depth)) | 
				
			||||
 | 
					        normal = rgb8(norm(self.normal)) | 
				
			||||
 | 
					        #ao = 1.0 - rgb8(bwrgb(self.ao)) | 
				
			||||
 | 
					        rgb = rgb8(self.rgb) | 
				
			||||
 | 
					        #return self.__class__(x=x, min_x=min_x, hit=hit, depth=depth, normal=normal, ao=ao, rgb=rgb) | 
				
			||||
 | 
					        return self.__class__(hit=hit, normal=normal, rgb=rgb, depth=depth) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    @staticmethod | 
				
			||||
 | 
					    def mean(*rblst): | 
				
			||||
 | 
					        rb = RenderBuffer() | 
				
			||||
 | 
					        n = len(rblst) | 
				
			||||
 | 
					        for _rb in rblst: | 
				
			||||
 | 
					            for f in fields(_rb): | 
				
			||||
 | 
					                attr = getattr(_rb, f.name) | 
				
			||||
 | 
					                dest = getattr(rb, f.name) | 
				
			||||
 | 
					                if attr is not None and dest is not None: | 
				
			||||
 | 
					                    setattr(rb, f.name, dest + attr) | 
				
			||||
 | 
					                elif dest is None: | 
				
			||||
 | 
					                    setattr(rb, f.name, attr) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        for f in fields(rb): | 
				
			||||
 | 
					            dest = getattr(rb, f.name) | 
				
			||||
 | 
					            if dest is not None: | 
				
			||||
 | 
					                setattr(rb, f.name, dest / float(n)) | 
				
			||||
 | 
					        return rb | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					if __name__ == '__main__': | 
				
			||||
 | 
					    hw = 1024 | 
				
			||||
 | 
					    hit = torch.zeros(hw, hw, 1).bool() | 
				
			||||
 | 
					    rgb = torch.zeros(hw, hw, 3).float() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    rb0 = RenderBuffer(hit=hit.clone(), rgb=rgb.clone()) | 
				
			||||
 | 
					    rb1 = RenderBuffer(hit=hit, rgb=rgb) | 
				
			||||
 | 
					    rb1.rgb += 1.0 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    avg = RenderBuffer.mean(rb0, rb1) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    import pdb; pdb.set_trace() | 
				
			||||
@ -0,0 +1,246 @@ | 
				
			|||||
 | 
					# The MIT License (MIT) | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Copyright (c) 2021, NVIDIA CORPORATION. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy of | 
				
			||||
 | 
					# this software and associated documentation files (the "Software"), to deal in | 
				
			||||
 | 
					# the Software without restriction, including without limitation the rights to | 
				
			||||
 | 
					# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | 
				
			||||
 | 
					# the Software, and to permit persons to whom the Software is furnished to do so, | 
				
			||||
 | 
					# subject to the following conditions: | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# The above copyright notice and this permission notice shall be included in all | 
				
			||||
 | 
					# copies or substantial portions of the Software. | 
				
			||||
 | 
					# | 
				
			||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | 
				
			||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | 
				
			||||
 | 
					# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | 
				
			||||
 | 
					# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | 
				
			||||
 | 
					# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | 
				
			||||
 | 
					# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					import torch | 
				
			||||
 | 
					import torch.nn.functional as F | 
				
			||||
 | 
					import torch.nn as nn | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					from lib.utils import PerfTimer | 
				
			||||
 | 
					from lib.diffutils import gradient | 
				
			||||
 | 
					from lib.geoutils import sample_unif_sphere | 
				
			||||
 | 
					from lib.tracer.RenderBuffer import RenderBuffer | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					from lib.PsDebugger import PsDebugger | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					from base_tracer import BaseTracer | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					from sol_nglod import aabb | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class SphereTracer(BaseTracer): | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def forward(self, net, ray_o, ray_d): | 
				
			||||
 | 
					        """Native implementation of sphere tracing.""" | 
				
			||||
 | 
					        timer = PerfTimer(activate=False) | 
				
			||||
 | 
					        nettimer = PerfTimer(activate=False) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # Distanace from ray origin | 
				
			||||
 | 
					        t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # Position in model space | 
				
			||||
 | 
					        x = torch.addcmul(ray_o, ray_d, t) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        cond = torch.ones_like(t).bool()[:,0] | 
				
			||||
 | 
					        x, t, cond = aabb(ray_o, ray_d) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        normal = torch.zeros_like(x) | 
				
			||||
 | 
					        # This function is in fact differentiable, but we treat it as if it's not, because | 
				
			||||
 | 
					        # it evaluates a very long chain of recursive neural networks (essentially a NN with depth of | 
				
			||||
 | 
					        # ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit | 
				
			||||
 | 
					        # locations, where additional quantities (normal, depth, segmentation) can be determined. The | 
				
			||||
 | 
					        # gradients will propagate only to these locations.  | 
				
			||||
 | 
					        with torch.no_grad(): | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            d = net(x) | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            dprev = d.clone() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # If cond is TRUE, then the corresponding ray has not hit yet. | 
				
			||||
 | 
					            # OR, the corresponding ray has exit the clipping plane. | 
				
			||||
 | 
					            #cond = torch.ones_like(d).bool()[:,0] | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # If miss is TRUE, then the corresponding ray has missed entirely. | 
				
			||||
 | 
					            hit = torch.zeros_like(d).byte() | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            for i in range(self.num_steps): | 
				
			||||
 | 
					                timer.check("start") | 
				
			||||
 | 
					                # 1. Check if ray hits. | 
				
			||||
 | 
					                #hit = (torch.abs(d) < self._MIN_DIS)[:,0]  | 
				
			||||
 | 
					                # 2. Check that the sphere tracing is not oscillating | 
				
			||||
 | 
					                #hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # 3. Check that the ray has not exit the far clipping plane. | 
				
			||||
 | 
					                #cond = (torch.abs(t) < self.clamp[1])[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                hit = (torch.abs(t) < self.camera_clamp[1])[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # 1. not hit surface | 
				
			||||
 | 
					                cond = cond & (torch.abs(d) > self.min_dis)[:,0]  | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                # 2. not oscillating | 
				
			||||
 | 
					                cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # 3. not a hit | 
				
			||||
 | 
					                cond = cond & hit | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                #cond = cond & ~hit | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # If the sum is 0, that means that all rays have hit, or missed. | 
				
			||||
 | 
					                if not cond.any(): | 
				
			||||
 | 
					                    break | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                # Advance the x, by updating with a new t | 
				
			||||
 | 
					                x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # Store the previous distance | 
				
			||||
 | 
					                dprev = torch.where(cond.unsqueeze(1), d, dprev) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                nettimer.check("nstart") | 
				
			||||
 | 
					                # Update the distance to surface at x | 
				
			||||
 | 
					                d[cond] = net(x[cond]) * self.step_size | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                nettimer.check("nend") | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # Update the distance from origin  | 
				
			||||
 | 
					                t = torch.where(cond.view(cond.shape[0], 1), t+d, t) | 
				
			||||
 | 
					                timer.check("end") | 
				
			||||
 | 
					     | 
				
			||||
 | 
					        # AABB cull  | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # The function will return  | 
				
			||||
 | 
					        #  x: the final model-space coordinate of the render | 
				
			||||
 | 
					        #  t: the final distance from origin | 
				
			||||
 | 
					        #  d: the final distance value from | 
				
			||||
 | 
					        #  miss: a vector containing bools of whether each ray was a hit or miss | 
				
			||||
 | 
					        #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        grad = gradient(x[hit], net, method=self.grad_method) | 
				
			||||
 | 
					        _normal = F.normalize(grad, p=2, dim=-1, eps=1e-5) | 
				
			||||
 | 
					        normal[hit] = _normal | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        return RenderBuffer(x=x, depth=t, hit=hit, normal=normal) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def get_min(self, net, ray_o, ray_d): | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        timer = PerfTimer(activate=False) | 
				
			||||
 | 
					        nettimer = PerfTimer(activate=False) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # Distance from ray origin | 
				
			||||
 | 
					        t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # Position in model space | 
				
			||||
 | 
					        x = torch.addcmul(ray_o, ray_d, t) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        x, t, hit = aabb(ray_o, ray_d); | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        normal = torch.zeros_like(x) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        with torch.no_grad(): | 
				
			||||
 | 
					            d = net(x) | 
				
			||||
 | 
					            dprev = d.clone() | 
				
			||||
 | 
					            mind = d.clone() | 
				
			||||
 | 
					            minx = x.clone() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # If cond is TRUE, then the corresponding ray has not hit yet. | 
				
			||||
 | 
					            # OR, the corresponding ray has exit the clipping plane. | 
				
			||||
 | 
					            cond = torch.ones_like(d).bool()[:,0] | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # If miss is TRUE, then the corresponding ray has missed entirely. | 
				
			||||
 | 
					            hit = torch.zeros_like(d).byte() | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            for i in range(self.num_steps): | 
				
			||||
 | 
					                timer.check("start") | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                hit = (torch.abs(t) < self.camera_clamp[1])[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # 1. not hit surface | 
				
			||||
 | 
					                cond = (torch.abs(d) > self.min_dis)[:,0]  | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                # 2. not oscillating | 
				
			||||
 | 
					                cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0] | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # 3. not a hit | 
				
			||||
 | 
					                cond = cond & hit | 
				
			||||
 | 
					         | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                #cond = cond & ~hit | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # If the sum is 0, that means that all rays have hit, or missed. | 
				
			||||
 | 
					                if not cond.any(): | 
				
			||||
 | 
					                    break | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                # Advance the x, by updating with a new t | 
				
			||||
 | 
					                x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x) | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                new_mins = (d<mind)[...,0] | 
				
			||||
 | 
					                mind[new_mins] = d[new_mins] | 
				
			||||
 | 
					                minx[new_mins] = x[new_mins] | 
				
			||||
 | 
					             | 
				
			||||
 | 
					                # Store the previous distance | 
				
			||||
 | 
					                dprev = torch.where(cond.unsqueeze(1), d, dprev) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                nettimer.check("nstart") | 
				
			||||
 | 
					                # Update the distance to surface at x | 
				
			||||
 | 
					                d[cond] = net(x[cond]) * self.step_size | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                nettimer.check("nend") | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                # Update the distance from origin  | 
				
			||||
 | 
					                t = torch.where(cond.view(cond.shape[0], 1), t+d, t) | 
				
			||||
 | 
					                timer.check("end") | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # AABB cull  | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1) | 
				
			||||
 | 
					        #hit = torch.ones_like(d).byte()[...,0] | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # The function will return  | 
				
			||||
 | 
					        #  x: the final model-space coordinate of the render | 
				
			||||
 | 
					        #  t: the final distance from origin | 
				
			||||
 | 
					        #  d: the final distance value from | 
				
			||||
 | 
					        #  miss: a vector containing bools of whether each ray was a hit or miss | 
				
			||||
 | 
					        #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5) | 
				
			||||
 | 
					        _normal = gradient(x[hit], net, method=self.grad_method) | 
				
			||||
 | 
					        normal[hit] = _normal | 
				
			||||
 | 
					         | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, minx=minx) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def sample_surface(self, n, net): | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # Sample surface using random tracing (resample until num_samples is reached) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        timer = PerfTimer(activate=True) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        with torch.no_grad(): | 
				
			||||
 | 
					            i = 0 | 
				
			||||
 | 
					            while i < 1000: | 
				
			||||
 | 
					                ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0 | 
				
			||||
 | 
					                # this really should just return a torch array in the first place | 
				
			||||
 | 
					                ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device) | 
				
			||||
 | 
					                rb = self.forward(net, ray_o, ray_d) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					                #d = torch.abs(net(rb.x)[..., 0]) | 
				
			||||
 | 
					                #idx = torch.where(d < 0.0003) | 
				
			||||
 | 
					                #pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0) | 
				
			||||
 | 
					                 | 
				
			||||
 | 
					                pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0) | 
				
			||||
 | 
					                if pts_pr.shape[0] >= n: | 
				
			||||
 | 
					                    break | 
				
			||||
 | 
					                i += 1 | 
				
			||||
 | 
					                if i == 50: | 
				
			||||
 | 
					                    print('Taking an unusually long time to sample desired # of points.') | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        return pts_pr | 
				
			||||
 | 
					
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue