- Implemented `BaseTracer` class with configurable tracing parameters
- Created `RenderBuffer` dataclass for flexible rendering data management
- Added `SphereTracer` with advanced ray marching and surface tracing methods
- Introduced utility functions for parameter setting and gradient computation
- Supported various tensor operations and transformations in render buffer
3 changed files with 456 additions and 0 deletions
@@ -0,0 +1,58 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import numpy as np

# This is the bridge between an argparse based approach and a non-argparse one
def setparam(args, param, paramstr):
    argsparam = getattr(args, paramstr, None)
    if param is not None or argsparam is None:
        return param
    else:
        return argsparam

class BaseTracer(object):
    """Virtual base class for tracers."""

    def __init__(self,
        args                 = None,
        camera_clamp : list  = None,
        step_size    : float = None,
        grad_method  : str   = None,
        num_steps    : int   = None,  # samples for raymarching, iterations for sphere tracing
        min_dis      : float = None):

        self.args = args
        self.camera_clamp = setparam(args, camera_clamp, 'camera_clamp')
        self.step_size = setparam(args, step_size, 'step_size')
        self.grad_method = setparam(args, grad_method, 'grad_method')
        self.num_steps = setparam(args, num_steps, 'num_steps')
        self.min_dis = setparam(args, min_dis, 'min_dis')

        self.inv_num_steps = 1.0 / self.num_steps
        self.diagonal = np.sqrt(3) * 2.0

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, net, ray_o, ray_d):
        """Base implementation for forward; subclasses must override."""
        raise NotImplementedError
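For orientation, here is a minimal sketch of how a concrete tracer plugs into `BaseTracer`. The subclass name `FixedStepTracer` and the SDF callable `net` are illustrative assumptions, not part of this change; only the `BaseTracer` and `setparam` machinery above is real.

import torch

class FixedStepTracer(BaseTracer):
    """Toy subclass: marches each ray forward in equal increments of step_size."""

    def forward(self, net, ray_o, ray_d):
        # t is the distance travelled along each ray; x = ray_o + ray_d * t
        t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)
        x = torch.addcmul(ray_o, ray_d, t)
        for _ in range(self.num_steps):
            t = t + self.step_size
            x = torch.addcmul(ray_o, ray_d, t)
        return x

# Parameters may be passed explicitly or picked up from an argparse Namespace held in
# `args`; setparam prefers the explicit keyword when both are available.
tracer = FixedStepTracer(camera_clamp=[-5.0, 10.0], step_size=0.01,
                         grad_method='finitediff', num_steps=64, min_dis=1e-3)
# __call__ dispatches to forward(), so the instance is invoked as tracer(net, ray_o, ray_d).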
@@ -0,0 +1,152 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from dataclasses import asdict, astuple, fields, dataclass
from typing import Union, List

import torch
import numpy as np

@dataclass
class RenderBuffer:
    x : Union[torch.Tensor, None] = None
    min_x : Union[torch.Tensor, None] = None
    hit : Union[torch.Tensor, None] = None
    depth : Union[torch.Tensor, None] = None
    relative_depth : Union[torch.Tensor, None] = None
    normal : Union[torch.Tensor, None] = None
    rgb : Union[torch.Tensor, None] = None
    shadow : Union[torch.Tensor, None] = None
    ao : Union[torch.Tensor, None] = None
    view : Union[torch.Tensor, None] = None
    err : Union[torch.Tensor, None] = None
    albedo : Union[torch.Tensor, None] = None

    def __iter__(self):
        return iter(astuple(self))

    def __add__(self, other):
        def _proc(pair):
            if None not in pair:
                return torch.cat(pair)
            elif pair[0] is not None and pair[1] is None:
                return pair[0]
            elif pair[0] is None and pair[1] is not None:
                return pair[1]
            else:
                return None

        return self.__class__(*(map(_proc, zip(self, other))))

    def _apply(self, fn):
        data = {}
        for f in fields(self):
            attr = getattr(self, f.name)
            data[f.name] = None if attr is None else fn(attr)
        return self.__class__(**data)

    def cuda(self):
        fn = lambda x : x.cuda()
        return self._apply(fn)

    def cpu(self):
        fn = lambda x : x.cpu()
        return self._apply(fn)

    def detach(self):
        fn = lambda x : x.detach()
        return self._apply(fn)

    def byte(self):
        fn = lambda x : x.byte()
        return self._apply(fn)

    def reshape(self, *dims : List[int]):
        fn = lambda x : x.reshape(*dims)
        return self._apply(fn)

    def transpose(self):
        fn = lambda x : x.permute(1,0,2)
        return self._apply(fn)

    def numpy(self):
        fn = lambda x : x.numpy()
        return self._apply(fn)

    def float(self):
        fn = lambda x : x.float()
        return self._apply(fn)

    def exrdict(self):
        _dict = asdict(self)
        _dict = {k:v for k,v in _dict.items() if v is not None}
        if 'rgb' in _dict:
            _dict['default'] = _dict['rgb']
            del _dict['rgb']
        return _dict

    def image(self):
        # Unfinished
        norm = lambda arr : ((arr + 1.0) / 2.0) if arr is not None else None
        bwrgb = lambda arr : torch.cat([arr]*3, dim=-1) if arr is not None else None
        rgb8 = lambda arr : (arr * 255.0) if arr is not None else None

        #x = rgb8(norm(self.x))
        #min_x = rgb8(norm(self.min_x))
        hit = rgb8(bwrgb(self.hit))
        depth = rgb8(bwrgb(self.relative_depth))
        normal = rgb8(norm(self.normal))
        #ao = 1.0 - rgb8(bwrgb(self.ao))
        rgb = rgb8(self.rgb)
        #return self.__class__(x=x, min_x=min_x, hit=hit, depth=depth, normal=normal, ao=ao, rgb=rgb)
        return self.__class__(hit=hit, normal=normal, rgb=rgb, depth=depth)

    @staticmethod
    def mean(*rblst):
        rb = RenderBuffer()
        n = len(rblst)
        for _rb in rblst:
            for f in fields(_rb):
                attr = getattr(_rb, f.name)
                dest = getattr(rb, f.name)
                if attr is not None and dest is not None:
                    setattr(rb, f.name, dest + attr)
                elif dest is None:
                    setattr(rb, f.name, attr)

        for f in fields(rb):
            dest = getattr(rb, f.name)
            if dest is not None:
                setattr(rb, f.name, dest / float(n))
        return rb

if __name__ == '__main__':
    hw = 1024
    hit = torch.zeros(hw, hw, 1).bool()
    rgb = torch.zeros(hw, hw, 3).float()

    rb0 = RenderBuffer(hit=hit.clone(), rgb=rgb.clone())
    rb1 = RenderBuffer(hit=hit, rgb=rgb)
    rb1.rgb += 1.0

    avg = RenderBuffer.mean(rb0, rb1)

    import pdb; pdb.set_trace()
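One detail of `RenderBuffer` worth calling out: `__add__` does not sum buffers, it concatenates them field by field with `torch.cat`, which is what makes results traced in ray chunks easy to stitch back together. A small illustrative example (the chunk and image sizes are arbitrary, not from this diff):

chunk_a = RenderBuffer(hit=torch.zeros(4096, 1).bool(), rgb=torch.zeros(4096, 3))
chunk_b = RenderBuffer(hit=torch.ones(4096, 1).bool(), rgb=torch.ones(4096, 3))
merged = chunk_a + chunk_b            # every non-None field now has 8192 rows
image = merged.reshape(64, 128, -1)   # reshape applies to each non-None field in turn

Fields that are None in both operands stay None, while `RenderBuffer.mean` (exercised in the `__main__` block above) averages buffers field-wise instead of concatenating them.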
@@ -0,0 +1,246 @@
# The MIT License (MIT)
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from lib.utils import PerfTimer
from lib.diffutils import gradient
from lib.geoutils import sample_unif_sphere
from lib.tracer.RenderBuffer import RenderBuffer

from lib.PsDebugger import PsDebugger

from base_tracer import BaseTracer

from sol_nglod import aabb

class SphereTracer(BaseTracer):

    def forward(self, net, ray_o, ray_d):
        """Native implementation of sphere tracing."""
        timer = PerfTimer(activate=False)
        nettimer = PerfTimer(activate=False)

        # Distance from ray origin
        t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)

        # Position in model space
        x = torch.addcmul(ray_o, ray_d, t)

        cond = torch.ones_like(t).bool()[:,0]
        x, t, cond = aabb(ray_o, ray_d)

        normal = torch.zeros_like(x)
        # This function is in fact differentiable, but we treat it as if it's not, because
        # it evaluates a very long chain of recursive neural networks (essentially a NN with a depth of
        # ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit
        # locations, where additional quantities (normal, depth, segmentation) can be determined. The
        # gradients will propagate only to these locations.
        with torch.no_grad():

            d = net(x)

            dprev = d.clone()

            # If cond is TRUE, then the corresponding ray has not hit yet,
            # OR the corresponding ray has exited the clipping plane.
            #cond = torch.ones_like(d).bool()[:,0]

            # If miss is TRUE, then the corresponding ray has missed entirely.
            hit = torch.zeros_like(d).byte()

            for i in range(self.num_steps):
                timer.check("start")
                # 1. Check if ray hits.
                #hit = (torch.abs(d) < self._MIN_DIS)[:,0]
                # 2. Check that the sphere tracing is not oscillating
                #hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0]

                # 3. Check that the ray has not exited the far clipping plane.
                #cond = (torch.abs(t) < self.clamp[1])[:,0]

                hit = (torch.abs(t) < self.camera_clamp[1])[:,0]

                # 1. not hit surface
                cond = cond & (torch.abs(d) > self.min_dis)[:,0]

                # 2. not oscillating
                cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]

                # 3. not a hit
                cond = cond & hit

                #cond = cond & ~hit

                # If the sum is 0, that means that all rays have hit, or missed.
                if not cond.any():
                    break

                # Advance the x, by updating with a new t
                x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)

                # Store the previous distance
                dprev = torch.where(cond.unsqueeze(1), d, dprev)

                nettimer.check("nstart")
                # Update the distance to surface at x
                d[cond] = net(x[cond]) * self.step_size

                nettimer.check("nend")

                # Update the distance from origin
                t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
                timer.check("end")

        # AABB cull

        hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)

        # The function will return
        #  x: the final model-space coordinate of the render
        #  t: the final distance from origin
        #  d: the final distance value from
        #  miss: a vector containing bools of whether each ray was a hit or miss
        #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)

        grad = gradient(x[hit], net, method=self.grad_method)
        _normal = F.normalize(grad, p=2, dim=-1, eps=1e-5)
        normal[hit] = _normal

        return RenderBuffer(x=x, depth=t, hit=hit, normal=normal)

    def get_min(self, net, ray_o, ray_d):

        timer = PerfTimer(activate=False)
        nettimer = PerfTimer(activate=False)

        # Distance from ray origin
        t = torch.zeros(ray_o.shape[0], 1, device=ray_o.device)

        # Position in model space
        x = torch.addcmul(ray_o, ray_d, t)

        x, t, hit = aabb(ray_o, ray_d)

        normal = torch.zeros_like(x)

        with torch.no_grad():
            d = net(x)
            dprev = d.clone()
            mind = d.clone()
            minx = x.clone()

            # If cond is TRUE, then the corresponding ray has not hit yet,
            # OR the corresponding ray has exited the clipping plane.
            cond = torch.ones_like(d).bool()[:,0]

            # If miss is TRUE, then the corresponding ray has missed entirely.
            hit = torch.zeros_like(d).byte()

            for i in range(self.num_steps):
                timer.check("start")

                hit = (torch.abs(t) < self.camera_clamp[1])[:,0]

                # 1. not hit surface
                cond = (torch.abs(d) > self.min_dis)[:,0]

                # 2. not oscillating
                cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)[:,0]

                # 3. not a hit
                cond = cond & hit

                #cond = cond & ~hit

                # If the sum is 0, that means that all rays have hit, or missed.
                if not cond.any():
                    break

                # Advance the x, by updating with a new t
                x = torch.where(cond.view(cond.shape[0], 1), torch.addcmul(ray_o, ray_d, t), x)

                # Track the minimum distance seen along each ray and where it occurred
                new_mins = (d<mind)[...,0]
                mind[new_mins] = d[new_mins]
                minx[new_mins] = x[new_mins]

                # Store the previous distance
                dprev = torch.where(cond.unsqueeze(1), d, dprev)

                nettimer.check("nstart")
                # Update the distance to surface at x
                d[cond] = net(x[cond]) * self.step_size

                nettimer.check("nend")

                # Update the distance from origin
                t = torch.where(cond.view(cond.shape[0], 1), t+d, t)
                timer.check("end")

        # AABB cull

        hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1)
        #hit = torch.ones_like(d).byte()[...,0]

        # The function will return
        #  x: the final model-space coordinate of the render
        #  t: the final distance from origin
        #  d: the final distance value from
        #  miss: a vector containing bools of whether each ray was a hit or miss
        #_normal = F.normalize(gradient(x[hit], net, method='finitediff'), p=2, dim=-1, eps=1e-5)
        _normal = gradient(x[hit], net, method=self.grad_method)
        normal[hit] = _normal

        return RenderBuffer(x=x, depth=t, hit=hit, normal=normal, min_x=minx)

    def sample_surface(self, n, net):

        # Sample surface using random tracing (resample until num_samples is reached)

        timer = PerfTimer(activate=True)

        with torch.no_grad():
            i = 0
            while i < 1000:
                ray_o = torch.rand((n, 3), device=self.device) * 2.0 - 1.0
                # this really should just return a torch array in the first place
                ray_d = torch.from_numpy(sample_unif_sphere(n)).float().to(self.device)
                rb = self.forward(net, ray_o, ray_d)

                #d = torch.abs(net(rb.x)[..., 0])
                #idx = torch.where(d < 0.0003)
                #pts_pr = rb.x[idx] if i == 0 else torch.cat([pts_pr, rb.x[idx]], dim=0)

                pts_pr = rb.x[rb.hit] if i == 0 else torch.cat([pts_pr, rb.x[rb.hit]], dim=0)
                if pts_pr.shape[0] >= n:
                    break
                i += 1
                if i == 50:
                    print('Taking an unusually long time to sample desired # of points.')

        return pts_pr
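To close the loop, a rough sketch of driving `SphereTracer` end to end. The SDF module, camera setup, and parameter values below are placeholder assumptions; only `SphereTracer` and `RenderBuffer` come from this diff, and the repo-side helpers it imports (`aabb`, `gradient`, `PerfTimer`) must be importable for `forward` to run.

import torch
import torch.nn as nn

class ToySDF(nn.Module):
    # Analytic stand-in for a neural SDF: a sphere of radius 0.5 centred at the origin.
    def forward(self, x):
        return x.norm(dim=-1, keepdim=True) - 0.5

h = w = 256
net = ToySDF()
tracer = SphereTracer(camera_clamp=[-5.0, 10.0], step_size=1.0,
                      grad_method='finitediff', num_steps=128, min_dis=1e-3)

# Orthographic rays looking down -z; a real pipeline would substitute its own camera model
# and place the tensors on the same device as the network.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
ray_o = torch.stack([xs, ys, torch.full_like(xs, 2.0)], dim=-1).reshape(-1, 3)
ray_d = torch.tensor([0.0, 0.0, -1.0]).expand(h * w, 3).contiguous()

rb = tracer(net, ray_o, ray_d)              # RenderBuffer with x, depth, hit, normal
img = rb.reshape(h, w, -1).cpu().detach()   # per-pixel hit mask, depth, and normals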