
Can now be compiled with torch.jit.script

Branch: final
mckay committed 1 month ago · commit 58e000dad3
2 changed files:
  1. brep2sdf/networks/decoder.py (21 changes)
  2. brep2sdf/networks/network.py (7 changes)
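For context, a minimal sketch of what this commit enables. It assumes Net is importable from brep2sdf.networks.network; the real constructor arguments are not shown in this diff, so the no-argument call below is a placeholder:

import torch
from brep2sdf.networks.network import Net

model = Net()  # hypothetical: the real constructor likely takes config arguments
# Before this commit, scripting failed on the dynamic setattr/getattr layer
# lookup in Decoder and on the np.sqrt calls in forward; both are fixed below.
scripted = torch.jit.script(model)
scripted.save("net_scripted.pt")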

brep2sdf/networks/decoder.py (21 changes)

@@ -30,6 +30,9 @@ class Decoder(nn.Module):
         dims_sdf = [d_in] + dims_sdf + [1]  # [self.n_branch]
         self.sdf_layers = len(dims_sdf)
+        # store the sdf layers in a ModuleList
+        self.sdf_modules = nn.ModuleList()
         for layer in range(0, len(dims_sdf) - 1):
             if layer + 1 in skip_in:
                 out_dim = dims_sdf[layer + 1] - d_in
@@ -43,7 +46,8 @@ class Decoder(nn.Module):
             else:
                 torch.nn.init.constant_(lin.bias, 0.0)
                 torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
-            setattr(self, "sdf_" + str(layer), lin)
+            self.sdf_modules.append(lin)
         if geometric_init:
             if beta > 0:
                 self.activation = nn.Softplus(beta=beta)
@@ -55,10 +59,6 @@ class Decoder(nn.Module):
             self.activation = Sine()
         self.final_activation = nn.ReLU()
-    # composite f_i to h
     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
         '''
         :param feature_matrix: feature matrix of shape (B, P, D)
@@ -73,10 +73,10 @@ class Decoder(nn.Module):
         # flatten to (B*P, D)
         x = feature_matrix.view(-1, D)
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        # iterate over sdf_modules with enumerate
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2)  # Fix undefined 'input'
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
             x = lin(x)
             if layer < self.sdf_layers - 2:
@@ -100,10 +100,9 @@ class Decoder(nn.Module):
         # use the input feature matrix directly, since its shape is already (S, D)
         x = feature_matrix
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2)
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
             x = lin(x)
             if layer < self.sdf_layers - 2:
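Why this change matters: torch.jit.script cannot compile attribute lookups built from dynamic strings, so the old getattr(self, "sdf_" + str(layer)) pattern is a script-time error, and it cannot call into numpy either, which rules out np.sqrt(2) inside forward. Registering the layers in an nn.ModuleList and switching to torch.sqrt fixes both. A minimal self-contained sketch of the same pattern (a toy stand-in, not the actual Decoder):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # Toy stand-in for Decoder: a few linear layers with one
    # self-concatenation skip, stored in an nn.ModuleList so that
    # TorchScript can iterate them with enumerate().
    def __init__(self):
        super().__init__()
        self.skip_in = [2]
        dims = [3, 64, 64, 1]
        self.layers = nn.ModuleList()
        for i in range(len(dims) - 1):
            in_dim = dims[i] * 2 if i in self.skip_in else dims[i]
            self.layers.append(nn.Linear(in_dim, dims[i + 1]))
        self.activation = nn.Softplus(beta=100)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = len(self.layers)
        for i, lin in enumerate(self.layers):
            if i in self.skip_in:
                # np.sqrt(2) would break scripting; torch.sqrt is fine
                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))
            x = lin(x)
            if i < n - 1:
                x = self.activation(x)
        return x

scripted = torch.jit.script(TinyDecoder())  # compiles without errors
print(scripted(torch.randn(8, 3)).shape)    # torch.Size([8, 1])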

brep2sdf/networks/network.py (7 changes)

@@ -100,9 +100,9 @@ class Net(nn.Module):
         feature_vectors = self.encoder.forward(query_points, face_indices_mask)
         print("feature_vector:", feature_vectors.shape)
         # decode
-        logger.gpu_memory_stats("after encoder forward")
+        #logger.gpu_memory_stats("after encoder forward")
         f_i = self.decoder(feature_vectors)  # (B, P)
-        logger.gpu_memory_stats("after decoder forward")
+        #logger.gpu_memory_stats("after decoder forward")
         output = f_i[:, 0]
@@ -127,7 +127,7 @@ class Net(nn.Module):
         output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
-        logger.gpu_memory_stats("after combine")
+        #logger.gpu_memory_stats("after combine")
         return output
     @torch.jit.export
@@ -137,6 +137,7 @@ class Net(nn.Module):
         surf_points (P, S):
         return (P, S)
         """
+        logger.debug(surf_points)
         feature_mat = self.encoder.forward_training_volumes(surf_points, patch_id)
         f_i = self.decoder.forward_training_volumes(feature_mat)
         return f_i.squeeze()
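On the network.py side, the logger.gpu_memory_stats calls are commented out, presumably because they are not scriptable, while forward_training_volumes stays callable on the scripted module thanks to the @torch.jit.export decorator: torch.jit.script compiles forward by default, and exported methods are compiled and exposed as well. A toy illustration (TinyNet is hypothetical, not the real Net):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for Net with one extra entry point.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x).squeeze(-1)

    @torch.jit.export
    def forward_training_volumes(self, x: torch.Tensor) -> torch.Tensor:
        # Simplified: the real method also takes a patch_id argument.
        return self.lin(x).squeeze(-1)

scripted = torch.jit.script(TinyNet())
pts = torch.randn(16, 3)
# Both the default forward and the exported method are callable:
assert torch.allclose(scripted(pts), scripted.forward_training_volumes(pts))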
