| 
						
						
							
								
							
						
						
					 | 
					@ -20,7 +20,7 @@ class Decoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        skip_in: Tuple[int, ...] = (), | 
					 | 
					 | 
					        skip_in: Tuple[int, ...] = (), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        flag_convex: bool = True, | 
					 | 
					 | 
					        flag_convex: bool = True, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        geometric_init: bool = True, | 
					 | 
					 | 
					        geometric_init: bool = True, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        radius_init: float = 1, | 
					 | 
					 | 
					        radius_init: float = 0.5, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        beta: float = 100, | 
					 | 
					 | 
					        beta: float = 100, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    ) -> None: | 
					 | 
					 | 
					    ) -> None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        super().__init__() | 
					 | 
					 | 
					        super().__init__() | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -66,7 +66,7 @@ class Decoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        else: | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            #siren | 
					 | 
					 | 
					            #siren | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.activation = Sine() | 
					 | 
					 | 
					            self.activation = Sine() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.final_activation = nn.ReLU() | 
					 | 
					 | 
					        self.final_activation = nn.Tanh() | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: | 
					 | 
					 | 
					    def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ''' | 
					 | 
					 | 
					        ''' | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -86,7 +86,7 @@ class Decoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for layer, lin in enumerate(self.sdf_modules): | 
					 | 
					 | 
					        for layer, lin in enumerate(self.sdf_modules): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if layer in self.skip_in: | 
					 | 
					 | 
					            if layer in self.skip_in: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # 使用 torch.sqrt | 
					 | 
					 | 
					                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # 使用 torch.sqrt | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					            #logger.print_tensor_stats(f"layer-{layer}>x", x) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            x = lin(x) | 
					 | 
					 | 
					            x = lin(x) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if layer < self.sdf_layers - 2: | 
					 | 
					 | 
					            if layer < self.sdf_layers - 2: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                x = self.activation(x) | 
					 | 
					 | 
					                x = self.activation(x) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -112,7 +112,7 @@ class Decoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for layer, lin in enumerate(self.sdf_modules): | 
					 | 
					 | 
					        for layer, lin in enumerate(self.sdf_modules): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if layer in self.skip_in: | 
					 | 
					 | 
					            if layer in self.skip_in: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # 使用 torch.sqrt | 
					 | 
					 | 
					                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # 使用 torch.sqrt | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					            #logger.print_tensor_stats(f"layer-{layer}>x", x) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            x = lin(x) | 
					 | 
					 | 
					            x = lin(x) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if layer < self.sdf_layers - 2: | 
					 | 
					 | 
					            if layer < self.sdf_layers - 2: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                x = self.activation(x) | 
					 | 
					 | 
					                x = self.activation(x) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |