Browse Source

修复min max

final
mckay 1 month ago
parent
commit
ca9a56b198
  1. 17
      brep2sdf/networks/network.py

17
brep2sdf/networks/network.py

@ -101,11 +101,11 @@ class Net(nn.Module):
# 对 operator == 0 的样本取最大值 # 对 operator == 0 的样本取最大值
if mask_concave.any(): if mask_concave.any():
output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values
# 对 operator == 1 的样本取最小值 # 对 operator == 1 的样本取最小值
if mask_convex.any(): if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values
#logger.gpu_memory_stats("combine后") #logger.gpu_memory_stats("combine后")
return output return output
@ -187,6 +187,19 @@ class Net(nn.Module):
return f_i.squeeze() return f_i.squeeze()
def freeze_stage1(self):
self.encoder.freeze_stage1()
def freeze_stage2(self):
self.encoder.freeze_stage2()
for param in self.decoder.parameters():
param.requires_grad = False
def unfreeze(self):
self.encoder.unfreeze()
for param in self.decoder.parameters():
param.requires_grad = True

Loading…
Cancel
Save