|
|
@@ -101,11 +101,11 @@ class Net(nn.Module):
 
         # Take the maximum for samples with operator == 0
         if mask_concave.any():
-            output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
+            output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values
 
         # Take the minimum for samples with operator == 1
         if mask_convex.any():
-            output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
+            output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values
 
         #logger.gpu_memory_stats("after combine")
         return output
|
|
@@ -187,6 +187,19 @@ class Net(nn.Module):
 
         return f_i.squeeze()
+
+    def freeze_stage1(self):
+        self.encoder.freeze_stage1()
+
+    def freeze_stage2(self):
+        self.encoder.freeze_stage2()
+        for param in self.decoder.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        self.encoder.unfreeze()
+        for param in self.decoder.parameters():
+            param.requires_grad = True
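The second hunk adds staged-freezing hooks: freeze_stage1 delegates to the encoder, freeze_stage2 additionally freezes the decoder, and unfreeze reverses both. A rough usage sketch follows; the optimizer choice, learning rates, and stage schedule are assumptions for illustration, not part of this diff:

import torch

model = Net()  # constructor arguments omitted; defined elsewhere in the file
model.freeze_stage2()  # per the diff: calls encoder.freeze_stage2() and freezes the decoder
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
# ... train the remaining trainable parameters ...

model.unfreeze()  # per the diff: encoder.unfreeze() plus requires_grad = True on the decoder
# Rebuild the optimizer so the newly unfrozen parameters are actually updated.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)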