@@ -30,6 +30,9 @@ class Decoder(nn.Module):
         dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch]
         self.sdf_layers = len(dims_sdf)
+
+        # Use a ModuleList to store the sdf layers
+        self.sdf_modules = nn.ModuleList()
         for layer in range(0, len(dims_sdf) - 1):
             if layer + 1 in skip_in:
                 out_dim = dims_sdf[layer + 1] - d_in
@@ -43,7 +46,8 @@ class Decoder(nn.Module):
                 else:
                     torch.nn.init.constant_(lin.bias, 0.0)
                     torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
-            setattr(self, "sdf_"+str(layer), lin)
+            self.sdf_modules.append(lin)
+
         if geometric_init:
             if beta > 0:
                 self.activation = nn.Softplus(beta=beta)
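
For reference, the construction change above replaces attribute-based registration (`setattr` in `__init__`, `getattr` in `forward`) with an `nn.ModuleList`. Both approaches register the linear layers as submodules, so their parameters appear in `parameters()` and `state_dict()`. A minimal sketch of the two equivalent patterns (the layer widths are placeholders, not this model's `dims_sdf`):

import torch.nn as nn

class SetattrStyle(nn.Module):
    def __init__(self, dims=(3, 64, 64, 1)):  # placeholder widths
        super().__init__()
        for layer in range(len(dims) - 1):
            lin = nn.Linear(dims[layer], dims[layer + 1])
            setattr(self, "sdf_" + str(layer), lin)  # registered through the attribute name

class ModuleListStyle(nn.Module):
    def __init__(self, dims=(3, 64, 64, 1)):  # placeholder widths
        super().__init__()
        self.sdf_modules = nn.ModuleList()  # registered through the ModuleList
        for layer in range(len(dims) - 1):
            self.sdf_modules.append(nn.Linear(dims[layer], dims[layer + 1]))

# Both variants expose the same number of parameter tensors.
assert len(list(SetattrStyle().parameters())) == len(list(ModuleListStyle().parameters()))
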
@@ -55,10 +59,6 @@ class Decoder(nn.Module):
             self.activation = Sine()
         self.final_activation = nn.ReLU()
 
-    # composite f_i to h
-
-
-
     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
         '''
         :param feature_matrix: feature matrix of shape (B, P, D)
@@ -73,10 +73,10 @@ class Decoder(nn.Module):
         # Flatten to (B*P, D)
         x = feature_matrix.view(-1, D)
 
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        # Iterate over sdf_modules with enumerate
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2)  # Fix undefined 'input'
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
 
             x = lin(x)
             if layer < self.sdf_layers - 2:
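
This hunk and the next apply the same refactor to the two forward paths: `enumerate(self.sdf_modules)` replaces the index-plus-`getattr` lookup, and the skip-connection scaling uses `torch.sqrt` instead of `np.sqrt`. A minimal, self-contained sketch of that loop; the widths, `skip_in`, and Softplus beta below are placeholders, not this model's configuration:

import torch
import torch.nn as nn

class SkipMLP(nn.Module):
    """Sketch of the decoder's SDF trunk with one skip connection."""
    def __init__(self, dims=(8, 16, 16, 1), skip_in=(1,)):  # placeholder sizes
        super().__init__()
        self.skip_in = skip_in
        self.sdf_layers = len(dims)
        self.sdf_modules = nn.ModuleList()
        for layer in range(len(dims) - 1):
            # The concatenation below doubles the width at skip layers.
            in_dim = dims[layer] * 2 if layer in skip_in else dims[layer]
            self.sdf_modules.append(nn.Linear(in_dim, dims[layer + 1]))
        self.activation = nn.Softplus(beta=100)

    def forward(self, x):
        for layer, lin in enumerate(self.sdf_modules):
            if layer in self.skip_in:
                # Same scaling as in the diff: concatenate, then divide by sqrt(2).
                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))
            x = lin(x)
            if layer < self.sdf_layers - 2:
                x = self.activation(x)
        return x

print(SkipMLP()(torch.randn(4, 8)).shape)  # torch.Size([4, 1])
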
@@ -100,10 +100,9 @@ class Decoder(nn.Module):
         # Use the input feature matrix directly; its shape is already (S, D)
         x = feature_matrix
 
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2)
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
 
             x = lin(x)
             if layer < self.sdf_layers - 2:
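
A side note on the scaling change: dividing by `np.sqrt(2)` and by `torch.sqrt(torch.tensor(2.0))` produce the same values; the swap simply keeps the constant inside torch. A quick illustrative check:

import numpy as np
import torch

x = torch.randn(4, 8)
a = torch.cat([x, x], -1) / np.sqrt(2)
b = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))
print(torch.allclose(a, b))  # True (up to float32 rounding)
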