From ca9a56b198f632b0a2510cb4b16da0e2f321a5e6 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 5 May 2025 19:21:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dmin=20max?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/network.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 0d82a43..babdddb 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -101,11 +101,11 @@ class Net(nn.Module): # 对 operator == 0 的样本取最大值 if mask_concave.any(): - output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values + output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values # 对 operator == 1 的样本取最小值 if mask_convex.any(): - output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values + output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values #logger.gpu_memory_stats("combine后") return output @@ -187,6 +187,19 @@ class Net(nn.Module): return f_i.squeeze() + def freeze_stage1(self): + self.encoder.freeze_stage1() + + def freeze_stage2(self): + self.encoder.freeze_stage2() + for param in self.decoder.parameters(): + param.requires_grad = False + + def unfreeze(self): + self.encoder.unfreeze() + for param in self.decoder.parameters(): + param.requires_grad = True +