
Zero level-set surface extraction works; normal quality degraded

final
mckay 2 weeks ago
parent
commit
73e7dc9fd9
  1. brep2sdf/IsoSurfacing.py (9)
  2. brep2sdf/config/default_config.py (6)
  3. brep2sdf/data/sampler.py (2)
  4. brep2sdf/networks/decoder.py (10)
  5. brep2sdf/networks/encoder.py (60)
  6. brep2sdf/networks/feature_volume.py (73)
  7. brep2sdf/networks/network.py (65)
  8. brep2sdf/scripts/farward_speed.py (81)
  9. brep2sdf/scripts/npz2points.py (6)
  10. brep2sdf/test.py (5)
  11. brep2sdf/train.py (11)

brep2sdf/IsoSurfacing.py (9)

@@ -35,23 +35,28 @@ def predict_sdf(model, points, device, use_bk=False):
"""
points_t = torch.from_numpy(points).float().to(device)
logger.print_tensor_stats("input poitns", points_t)
#logger.info(f"points_t:{points_t.shape}")
with torch.no_grad():
if use_bk:
print("only background")
sdf = model.forward_background(points_t)
else:
batch_size = 8192*4 # batch size
batch_size = 8192*128 # batch size
sdf_list = [] # holds per-batch predictions
model.octree_module = model.octree_module.to(points_t.device)
for i in range(0, len(points), batch_size):
batch_points = points[i:i + batch_size]
points_t = torch.from_numpy(batch_points).float().to(device)
logger.print_tensor_stats("input points", points_t)
logger.print_tensor_stats("points_t", points_t)
batch_sdf = model(points_t)
logger.print_tensor_stats("batch_sdf", batch_sdf)
sdf_list.append(batch_sdf.cpu())
sdf = torch.cat(sdf_list) # concatenate all batch results
#logger.info(f"sdf:{sdf.shape}")
logger.print_tensor_stats("sdf", sdf)
sdf = sdf.cpu().numpy().flatten()
#logger.info(f"sdf:{sdf.shape}")
return sdf
def extract_surface(sdf, xx, yy, zz, method='MC', bbox_size=1.0, feature_angle=30.0, voxel_size=0.01):
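Aside: the loop above is the standard chunked-inference pattern for keeping peak GPU memory bounded. A minimal self-contained sketch of the same idea (`predict_sdf_batched` is a hypothetical name; `model` is any module mapping (N, 3) points to SDF values):

```python
import numpy as np
import torch

def predict_sdf_batched(model, points: np.ndarray, device, batch_size: int = 8192 * 128) -> np.ndarray:
    """Evaluate an SDF model over a large point set in fixed-size chunks.

    Each batch result is moved to the CPU immediately, so only one batch
    of activations lives on the GPU at a time.
    """
    chunks = []
    with torch.no_grad():
        for i in range(0, len(points), batch_size):
            batch = torch.from_numpy(points[i:i + batch_size]).float().to(device)
            chunks.append(model(batch).cpu())
    return torch.cat(chunks).numpy().flatten()
```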

brep2sdf/config/default_config.py (6)

@@ -49,9 +49,9 @@ class TrainConfig:
# basic training parameters
batch_size: int = 8
num_workers: int = 4
num_epochs1: int = 10000
num_epochs2: int = 0000
num_epochs3: int = 0000
num_epochs1: int = 0000
num_epochs2: int = 000
num_epochs3: int = 100
learning_rate: float = 0.1
learning_rate_schedule: List = field(default_factory=lambda: [{
"Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子

brep2sdf/data/sampler.py (2)

@@ -438,7 +438,7 @@ def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh):
np.ndarray | None: array of shape (N, 7) [x, y, z, nx, ny, nz, sdf]
Returns None if sampling or SDF computation fails
"""
grid_size = 2**5 + 1
grid_size = 2**4 + 1
start = -0.5
end = 0.5
x = np.linspace(start, end, grid_size)

brep2sdf/networks/decoder.py (10)

@@ -54,12 +54,8 @@ class Decoder(nn.Module):
self.norm_layers.append(nn.LayerNorm(dim))
if geometric_init:
self.activation = nn.Sequential(
nn.LayerNorm(out_dim), # add layer normalization
nn.Softplus(beta=beta)
)
if beta > 0:
self.activation = nn.SiLU()
self.activation = nn.Softplus(beta=beta)
# vanilla relu
else:
self.activation = nn.ReLU()
@@ -68,7 +64,7 @@ class Decoder(nn.Module):
self.activation = Sine()
self.final_activation = nn.Tanh()
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
def forward_patch(self, feature_matrix: torch.Tensor) -> torch.Tensor:
'''
:param feature_matrix: feature matrix of shape (B, P, D)
B: batch size
@@ -98,7 +94,7 @@ class Decoder(nn.Module):
return f_i
@torch.jit.export
def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor:
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
'''
:param feature_matrix: feature matrix of shape (S, D)
S: number of samples
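One design note on `nn.Softplus(beta=beta)` under geometric initialization (see also the `beta=100` change in network.py below): Softplus approaches ReLU as beta grows while keeping a smooth, everywhere-nonzero gradient, which geometrically initialized SDF decoders typically rely on. A quick check in plain PyTorch, independent of the repo code:

```python
import torch
import torch.nn as nn

# softplus(x; beta) = log(1 + exp(beta * x)) / beta -> relu(x) as beta -> inf,
# but stays differentiable everywhere (no dead zone for x < 0).
x = torch.linspace(-1.0, 1.0, 5)
for beta in (1, 5, 100):
    print(beta, nn.Softplus(beta=beta)(x).tolist())
print("relu", torch.relu(x).tolist())
```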

brep2sdf/networks/encoder.py (60)

@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import numpy as np
import time
from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
from brep2sdf.utils.logger import logger
@@ -89,38 +90,43 @@ class Encoder(nn.Module):
volume_indices_mask: associated volume index mask (B, P)
Returns:
feature tensor (B, P, D)
# take the first two valid feature indices
# Note: when a row has fewer than 2 True entries:
# 1. with exactly one True, that feature is gathered twice
# 2. with no True, the background feature is gathered twice (the mask is padded with a trailing True column)
feature tensor (B, 2, D)
"""
batch_size, num_volumes = volume_indices_mask.shape
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
all_features = torch.zeros(batch_size, num_volumes+1, self.feature_dim,
device=query_points.device)
background_features = self.background.forward(query_points) # (B, D)
start_time = time.time()
# create CUDA streams
streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))]
features_list = [None] * len(self.feature_volumes)
# compute in parallel
# iterate over each volume index
for vol_id, volume in enumerate(self.feature_volumes):
mask = volume_indices_mask[:, vol_id].squeeze()
if not mask.any():
continue
with torch.cuda.stream(streams[vol_id]):
features = volume(query_points[mask])
features_list[vol_id] = (mask, features)
# synchronize the streams
torch.cuda.synchronize()
# write back the results
for vol_id, item in enumerate(features_list):
if item is None:
continue
mask, features = item
all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features
end_time = time.time()
logger.debug(f"duration:{end_time-start_time}")
return all_features
#logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}")
if mask.any():
# gather this volume's features (M, D)
features = volume.forward(query_points[mask])
all_features[mask, vol_id] = features
# the last slot serves as the background field
all_features[:,num_volumes] = background_features
#all_features[:, :] = background_features.unsqueeze(1)
features = torch.zeros(batch_size, 2, self.feature_dim,
device=query_points.device)
# extend mask from volume_indices_mask (B, P) to (B, P+1) by appending a True column
mask = torch.cat([
volume_indices_mask,
torch.ones(batch_size, 1, dtype=torch.bool, device=volume_indices_mask.device)
], dim=1)
# per sample, take the first two valid features; the padded background column fills in when fewer exist
_, valid_indices = torch.topk(mask.float(), 2, dim=1) # (B, 2)
# gather the selected features
features = all_features.gather(1, valid_indices.unsqueeze(-1).expand(-1, -1, self.feature_dim))
return features
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
"""
@@ -146,10 +152,10 @@ class Encoder(nn.Module):
"""
# get the patch features
patch_features = self.feature_volumes[patch_id].forward(surf_points)
background_features = self.background.forward(surf_points) # (B, D)
#background_features = self.background.forward(surf_points) # (B, D)
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
#dot.render("feature_extraction", format="png") # save the computation graph to an image file
return 0.1 * background_features + 0.9 * patch_features
return patch_features
def to(self, device):
super().to(device)
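The rewritten encoder pads the mask with an always-true background column and then picks each row's first two valid feature slots with `topk` + `gather`. A self-contained sketch of that indexing trick (all shapes hypothetical):

```python
import torch

B, P, D = 4, 3, 8                                  # hypothetical sizes
all_features = torch.randn(B, P + 1, D)            # slot P holds the background features
mask = torch.rand(B, P) > 0.5                      # per-point patch membership

# Append an always-true column so every row has at least one valid slot.
mask = torch.cat([mask, torch.ones(B, 1, dtype=torch.bool)], dim=1)

# topk on the float mask picks two True columns per row; among equal values
# current CPU/CUDA kernels appear to favor lower indices, which the
# first-two-True behavior in the hunk above leans on.
_, idx = torch.topk(mask.float(), 2, dim=1)        # (B, 2)
feats = all_features.gather(1, idx.unsqueeze(-1).expand(-1, -1, D))
print(feats.shape)                                 # torch.Size([4, 2, 8])
```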

brep2sdf/networks/feature_volume.py (73)

@@ -45,36 +45,53 @@ class PatchFeatureVolume(nn.Module):
return self._batched_trilinear(normalized)
def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor:
"""批量处理的三线性插值"""
# 计算8个顶点的权重
"""
修复后的批量三线性插值
Args:
normalized (Tensor): [B, 3]归一化坐标范围 [0, 1]
Returns:
Tensor: [B, feature_dim]
"""
B = normalized.shape[0]
device = normalized.device
# map normalized coords to the grid index range [0, resolution - 1]
uvw = normalized * (self.resolution - 1)
indices = torch.floor(uvw).long() # (B,3)
weights = uvw - indices.float() # (B,3)
# weight combinations for the 8 corners (B, 8)
weights = torch.stack([
(1 - weights[...,0]) * (1 - weights[...,1]) * (1 - weights[...,2]),
(1 - weights[...,0]) * (1 - weights[...,1]) * weights[...,2],
(1 - weights[...,0]) * weights[...,1] * (1 - weights[...,2]),
(1 - weights[...,0]) * weights[...,1] * weights[...,2],
weights[...,0] * (1 - weights[...,1]) * (1 - weights[...,2]),
weights[...,0] * (1 - weights[...,1]) * weights[...,2],
weights[...,0] * weights[...,1] * (1 - weights[...,2]),
weights[...,0] * weights[...,1] * weights[...,2],
], dim=-1) # (B,8)
# features of the 8 corners (B, 8, D)
indices = indices.unsqueeze(1).expand(-1,8,-1) + torch.tensor([
[0,0,0], [0,0,1], [0,1,0], [0,1,1],
[1,0,0], [1,0,1], [1,1,0], [1,1,1]
], device=indices.device)
indices = torch.clamp(indices, 0, self.resolution-1)
features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D)
# weighted sum (B, D)
return torch.einsum('bnd,bn->bd', features, weights)
indices = torch.floor(uvw).long()
weights = uvw - indices.float()
# keep all dimensions aligned
indices = torch.clamp(indices, 0, self.resolution - 2) # clamp to resolution-2 so the +1 neighbor stays in bounds
# coordinates of the 8 corners
x, y, z = indices.unbind(dim=-1)
w_x, w_y, w_z = weights.unbind(dim=-1)
# reshape the weights for broadcasting
w_z = w_z.unsqueeze(-1)
w_y = w_y.unsqueeze(-1)
w_x = w_x.unsqueeze(-1)
# fetch the corner features
c00 = self.feature_volume[x, y, z ]
c01 = self.feature_volume[x, y, z + 1]
c10 = self.feature_volume[x, y + 1, z ]
c11 = self.feature_volume[x, y + 1, z + 1]
c20 = self.feature_volume[x + 1, y, z ]
c21 = self.feature_volume[x + 1, y, z + 1]
c30 = self.feature_volume[x + 1, y + 1, z ]
c31 = self.feature_volume[x + 1, y + 1, z + 1]
# interpolate
c0 = c00 * (1 - w_z) + c01 * w_z
c1 = c10 * (1 - w_z) + c11 * w_z
c2 = c20 * (1 - w_z) + c21 * w_z
c3 = c30 * (1 - w_z) + c31 * w_z
c_top = c0 * (1 - w_y) + c1 * w_y
c_bot = c2 * (1 - w_y) + c3 * w_y
return c_top * (1 - w_x) + c_bot * w_x
class SimpleFeatureEncoder(nn.Module):
def __init__(self, input_dim=3, feature_dim=64):
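The hand-rolled interpolation above can be sanity-checked against PyTorch's built-in `F.grid_sample`, which applies the same trilinear rule when `align_corners=True`. A sketch of such a check (hypothetical shapes; unlike the hunk, the weights here are recomputed after clamping, which only differs at the exact upper boundary):

```python
import torch
import torch.nn.functional as F

R, Df, B = 8, 4, 16                        # hypothetical resolution, feature dim, batch
vol = torch.randn(R, R, R, Df)             # indexed vol[x, y, z], features last
p = torch.rand(B, 3)                       # query coordinates normalized to [0, 1]

def manual_trilinear(vol, p, R):
    # Same corner/lerp scheme as _batched_trilinear above.
    uvw = p * (R - 1)
    idx = torch.clamp(uvw.floor().long(), 0, R - 2)   # keep the +1 neighbor in bounds
    w = uvw - idx.float()
    x, y, z = idx.unbind(-1)
    wx, wy, wz = (w[:, i:i + 1] for i in range(3))
    c = lambda dx, dy, dz: vol[x + dx, y + dy, z + dz]
    c0 = c(0, 0, 0) * (1 - wz) + c(0, 0, 1) * wz
    c1 = c(0, 1, 0) * (1 - wz) + c(0, 1, 1) * wz
    c2 = c(1, 0, 0) * (1 - wz) + c(1, 0, 1) * wz
    c3 = c(1, 1, 0) * (1 - wz) + c(1, 1, 1) * wz
    return (c0 * (1 - wy) + c1 * wy) * (1 - wx) + (c2 * (1 - wy) + c3 * wy) * wx

# Cross-check with grid_sample: channels first, grid coords flipped to (z, y, x)
# and mapped to [-1, 1]; align_corners=True matches the (R - 1) index scaling.
inp = vol.permute(3, 0, 1, 2).unsqueeze(0)             # (1, Df, X, Y, Z)
grid = (2 * p - 1).flip(-1).view(1, B, 1, 1, 3)
ref = F.grid_sample(inp, grid, mode='bilinear', align_corners=True).view(Df, B).t()
print(torch.allclose(manual_trilinear(vol, p, R), ref, atol=1e-5))  # expect True
```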

brep2sdf/networks/network.py (65)

@@ -79,36 +79,25 @@ class Net(nn.Module):
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
#skip_in=(3,),
geometric_init=True,
beta=5
beta=100
)
#self.csg_combiner = CSGCombiner(flag_convex=True)
def process_sdf(self,f_i, face_indices_mask, operator):
def process_sdf(self, f_i, operator):
output = f_i[:,0]
# extract the valid values and pad to a fixed size (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # set invalid entries to inf
# per sample, take the first max_patches valid values (B, max_patches)
valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # take the first two valid values
# pad to the fixed size (B, max_patches)
padded_f_i[:, :2] = valid_values # (B, max_patches)
# find the rows that need combining
mask_concave = (operator == 0)
mask_convex = (operator == 1)
# for operator == 0 samples, take the
# for operator == 0 samples, take the min
if mask_concave.any():
output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values
output[mask_concave] = torch.min(f_i[mask_concave], dim=1).values
# for operator == 1 samples, take the
# for operator == 1 samples, take the max
if mask_convex.any():
output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values
output[mask_convex] = torch.max(f_i[mask_convex], dim=1).values
#logger.gpu_memory_stats("combine后")
return output
@torch.jit.export
@@ -125,18 +114,7 @@ class Net(nn.Module):
#logger.debug("step octree")
_,face_indices_mask,operator = self.octree_module.forward(query_points)
#logger.debug("step encode")
# encode
feature_vectors = self.encoder.forward(query_points,face_indices_mask)
#print("feature_vector:", feature_vectors.shape)
# decode
#logger.debug("step decode")
#logger.gpu_memory_stats("after encoder forward")
f_i = self.decoder(feature_vectors) # (B, P)
#logger.gpu_memory_stats("decoder farward后")
#logger.debug("step combine")
return f_i[:,0]
return self.process_sdf(f_i, face_indices_mask, operator)
return self.forward_without_octree(query_points,face_indices_mask,operator)
@torch.jit.export
def forward_background(self, query_points):
@@ -152,31 +130,36 @@ class Net(nn.Module):
# encode
feature_vectors = self.encoder.forward_background(query_points)
# decode
h = self.decoder.forward_training_volumes(feature_vectors) # (B, D)
h = self.decoder.forward(feature_vectors) # (B, D)
return h
@torch.jit.ignore
def forward_without_octree(self, query_points,face_indices_mask,operator):
@torch.jit.export
def forward_without_octree(self, query_points, face_indices_mask, operator):
"""
Forward pass
Args:
query_point: query point coordinates
query_points: query point coordinates (B, 3)
face_indices_mask: face index mask (B, P)
operator: operator codes (B,)
Returns:
output: decoded output
output: decoded SDF values (B,)
"""
# batch-query indices and bounding boxes for all points
#logger.debug("step encode")
# encode
feature_vectors = self.encoder(query_points,face_indices_mask)
feature_vectors = self.encoder(query_points,face_indices_mask) # (B, 2, D)
feature_dim = feature_vectors.size(-1) # feature dimension
flatten_feature_vectors = feature_vectors.reshape(-1, feature_dim) # (B*2, D)
#print("feature_vector:", feature_vectors.shape)
# decode
f_i = self.decoder(feature_vectors) # (B, P)
f_i = self.decoder(flatten_feature_vectors) # (B*2, )
B = feature_vectors.size(0) # batch size
f_i = f_i.reshape(B, 2) # reshape the output to (B, 2)
#logger.gpu_memory_stats("after decoder forward")
#logger.debug("step combine")
return f_i[:,0]
return self.process_sdf(f_i, face_indices_mask, operator)
return self.process_sdf(f_i, operator)
@torch.jit.ignore
def forward_training_volumes(self, surf_points, patch_id:int):
@@ -186,7 +169,7 @@ class Net(nn.Module):
return (P, S)
"""
feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id)
f_i = self.decoder.forward_training_volumes(feature_mat)
f_i = self.decoder.forward(feature_mat)
return f_i.squeeze()
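The slimmed-down `process_sdf` reduces each point's two candidate SDFs with a per-row min or max, the usual hard CSG combine (min of SDFs behaves like a union, max like an intersection). A standalone sketch of the same masked reduction, with the operator codes as in the hunk and the fallthrough keeping the first column; `.clone()` is added here to avoid aliasing `f_i`:

```python
import torch

def combine_sdf(f_i: torch.Tensor, operator: torch.Tensor) -> torch.Tensor:
    """f_i: (B, 2) candidate SDF values; operator: (B,) integer codes."""
    out = f_i[:, 0].clone()                  # default: keep the primary patch value
    m0, m1 = operator == 0, operator == 1
    if m0.any():
        out[m0] = f_i[m0].min(dim=1).values  # min of SDFs ~ union of shapes
    if m1.any():
        out[m1] = f_i[m1].max(dim=1).values  # max of SDFs ~ intersection
    return out

f = torch.tensor([[0.2, -0.1], [0.3, 0.5], [0.4, 0.1]])
print(combine_sdf(f, torch.tensor([0, 1, 2])))  # tensor([-0.1000, 0.5000, 0.4000])
```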

brep2sdf/scripts/farward_speed.py (81)

@@ -0,0 +1,81 @@
import re
for_log_data = """
2025-05-21 17:35:04,438 | DEBUG  | encoder.py:forward:108 - duration:0.02291393280029297
2025-05-21 17:35:04,487 | DEBUG  | encoder.py:forward:108 - duration:0.013659954071044922
2025-05-21 17:35:05,096 | DEBUG  | encoder.py:forward:108 - duration:0.013151884078979492
2025-05-21 17:35:05,128 | DEBUG  | encoder.py:forward:108 - duration:0.012245893478393555
2025-05-21 17:35:05,667 | DEBUG  | encoder.py:forward:108 - duration:0.012324810028076172
2025-05-21 17:35:05,698 | DEBUG  | encoder.py:forward:108 - duration:0.011858940124511719
2025-05-21 17:35:06,247 | DEBUG  | encoder.py:forward:108 - duration:0.013593196868896484
2025-05-21 17:35:06,278 | DEBUG  | encoder.py:forward:108 - duration:0.012105226516723633
2025-05-21 17:35:06,829 | DEBUG  | encoder.py:forward:108 - duration:0.012081146240234375
2025-05-21 17:35:06,859 | DEBUG  | encoder.py:forward:108 - duration:0.011334419250488281
2025-05-21 17:35:07,404 | DEBUG  | encoder.py:forward:108 - duration:0.013489246368408203
2025-05-21 17:35:07,436 | DEBUG  | encoder.py:forward:108 - duration:0.01230931282043457
2025-05-21 17:35:07,983 | DEBUG  | encoder.py:forward:108 - duration:0.01315164566040039
2025-05-21 17:35:08,015 | DEBUG  | encoder.py:forward:108 - duration:0.012539148330688477
2025-05-21 17:35:08,569 | DEBUG  | encoder.py:forward:108 - duration:0.014146566390991211
2025-05-21 17:35:08,602 | DEBUG  | encoder.py:forward:108 - duration:0.013015508651733398
2025-05-21 17:35:09,156 | DEBUG  | encoder.py:forward:108 - duration:0.01263570785522461
2025-05-21 17:35:09,186 | DEBUG  | encoder.py:forward:108 - duration:0.011255264282226562
2025-05-21 17:35:09,722 | DEBUG  | encoder.py:forward:108 - duration:0.014206647872924805
2025-05-21 17:35:09,754 | DEBUG  | encoder.py:forward:108 - duration:0.012360095977783203
2025-05-21 17:35:10,307 | DEBUG  | encoder.py:forward:108 - duration:0.013350963592529297
2025-05-21 17:35:10,339 | DEBUG  | encoder.py:forward:108 - duration:0.012225151062011719
2025-05-21 17:35:10,894 | DEBUG  | encoder.py:forward:108 - duration:0.014019250869750977
2025-05-21 17:35:10,925 | DEBUG  | encoder.py:forward:108 - duration:0.012645483016967773
2025-05-21 17:35:11,477 | DEBUG  | encoder.py:forward:108 - duration:0.010942935943603516
2025-05-21 17:35:11,494 | DEBUG  | encoder.py:forward:108 - duration:0.010617733001708984
"""
parallel_log_data = """
2025-05-21 17:25:30,716 | DEBUG  | encoder.py:forward:122 - duration:0.014799833297729492
2025-05-21 17:25:30,748 | DEBUG  | encoder.py:forward:122 - duration:0.013928413391113281
2025-05-21 17:25:31,318 | DEBUG  | encoder.py:forward:122 - duration:0.020897626876831055
2025-05-21 17:25:31,352 | DEBUG  | encoder.py:forward:122 - duration:0.013567924499511719
2025-05-21 17:25:31,929 | DEBUG  | encoder.py:forward:122 - duration:0.020887374877929688
2025-05-21 17:25:31,965 | DEBUG  | encoder.py:forward:122 - duration:0.014947652816772461
2025-05-21 17:25:32,550 | DEBUG  | encoder.py:forward:122 - duration:0.02316737174987793
2025-05-21 17:25:32,586 | DEBUG  | encoder.py:forward:122 - duration:0.01513051986694336
2025-05-21 17:25:33,172 | DEBUG  | encoder.py:forward:122 - duration:0.021285295486450195
2025-05-21 17:25:33,207 | DEBUG  | encoder.py:forward:122 - duration:0.015576839447021484
2025-05-21 17:25:33,790 | DEBUG  | encoder.py:forward:122 - duration:0.02099466323852539
2025-05-21 17:25:33,826 | DEBUG  | encoder.py:forward:122 - duration:0.015471696853637695
2025-05-21 17:25:34,406 | DEBUG  | encoder.py:forward:122 - duration:0.021028518676757812
2025-05-21 17:25:34,441 | DEBUG  | encoder.py:forward:122 - duration:0.015815019607543945
2025-05-21 17:25:35,034 | DEBUG  | encoder.py:forward:122 - duration:0.020988941192626953
2025-05-21 17:25:35,070 | DEBUG  | encoder.py:forward:122 - duration:0.01592278480529785
2025-05-21 17:25:35,662 | DEBUG  | encoder.py:forward:122 - duration:0.019669532775878906
2025-05-21 17:25:35,698 | DEBUG  | encoder.py:forward:122 - duration:0.015323638916015625
2025-05-21 17:25:36,276 | DEBUG  | encoder.py:forward:122 - duration:0.02336907386779785
2025-05-21 17:25:36,311 | DEBUG  | encoder.py:forward:122 - duration:0.015668869018554688
2025-05-21 17:25:36,896 | DEBUG  | encoder.py:forward:122 - duration:0.022051572799682617
2025-05-21 17:25:36,932 | DEBUG  | encoder.py:forward:122 - duration:0.015897512435913086
2025-05-21 17:25:37,526 | DEBUG  | encoder.py:forward:122 - duration:0.020981311798095703
2025-05-21 17:25:37,560 | DEBUG  | encoder.py:forward:122 - duration:0.015113353729248047
2025-05-21 17:25:38,137 | DEBUG  | encoder.py:forward:122 - duration:0.018566131591796875
2025-05-21 17:25:38,157 | DEBUG  | encoder.py:forward:122 - duration:0.013988733291625977
"""
# compute the average duration
def run(log_data):
# extract all duration values with a regex
durations = re.findall(r'duration:(\d+\.\d+)', log_data)
# convert to a list of floats
durations = [float(d) for d in durations]
# compute the mean
average_duration = sum(durations) / len(durations) if durations else 0
# print the results
print(f"found {len(durations)} duration entries")
print(f"average duration: {average_duration:.6f}")
return average_duration
speed1 = run(for_log_data)
speed2 = run(parallel_log_data)
print(f"for speed: {speed1:.6f}")
print(f"parallel speed: {speed2:.6f}")

brep2sdf/scripts/npz2points.py (6)

@@ -11,10 +11,10 @@ def load_brep_file(brep_path):
if __name__ == "__main__":
data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000003.xyz")
surfs =data["train_surf_ncs"]
data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz")
surfs =data["sampled_points_normals_sdf"]
print(surfs)
with open("0003_t.xyz","w") as f:
with open("0031_t.xyz","w") as f:
for point in surfs:
#f.write(f"{point[0]} {point[1]} {point[2]}\n")
f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n")

brep2sdf/test.py (5)

@@ -164,7 +164,7 @@ def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh):
np.ndarray | None: array of shape (N, 7) [x, y, z, nx, ny, nz, sdf]
Returns None if sampling or SDF computation fails
"""
grid_size = 2**5 + 1
grid_size = 2**4 + 1
start = -1
end = 1
x = np.linspace(start, end, grid_size)
@@ -199,6 +199,7 @@ def test2(obj_file):
# convert the point coordinates and SDF values to grid format
grid_size = int(np.cbrt(len(points))) # assume the samples form a cubic grid
print(f"grid size:{grid_size}")
sdf_grid = sdf_values.reshape((grid_size, grid_size, grid_size))
# extract the zero level set with Marching Cubes
@@ -243,5 +244,5 @@ def main():
if __name__ == "__main__":
#main()
#test()
test2("/home/wch/brep2sdf/data/gt_mesh/00000003.obj")
test2("/home/wch/brep2sdf/data/gt_mesh/00000031.obj")
# python test.py -i /home/wch/brep2sdf/data/gt_mesh/00000003.obj -o output.ply --depth 6 --box_size 2.0 --method MC

brep2sdf/train.py (11)

@@ -604,7 +604,7 @@ class Trainer:
self.model.train()
total_loss = 0.0
step = 0 # if training is batched, this should be the batch index
batch_size = 4096*5 # pick a suitable batch size
batch_size = 50000 # pick a suitable batch size
# data processing
# manfld
@@ -640,7 +640,7 @@ class Trainer:
_nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
_nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
#logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
@@ -747,10 +747,11 @@ class Trainer:
# log training progress (only valid losses)
if epoch % 10 == 0:
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {total_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
#self.validate(epoch,total_loss)
self.validate(epoch,total_loss)
return total_loss # 对于单批次训练,直接返回当前损失
@@ -932,8 +933,8 @@ class Trainer:
#stage 3
self.scheduler.reset()
self.model.freeze_stage2()
#self.model.unfreeze()
#self.model.freeze_stage2()
self.model.unfreeze()
for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1):
# train one epoch
train_loss = self.train_epoch_stage3(epoch)
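`freeze_stage2` and `unfreeze` are model methods not shown in this diff; a plausible minimal sketch of the pattern such methods usually wrap, with the submodule names purely hypothetical:

```python
import torch.nn as nn

def set_trainable(module: nn.Module, trainable: bool) -> None:
    # Toggle gradient tracking for every parameter of a (sub)module;
    # frozen parameters are skipped by the optimizer's update step.
    for p in module.parameters():
        p.requires_grad = trainable

# Hypothetical stage-3 setup mirroring the change above: train all parameters
# instead of freezing part of the network as stage 2 presumably did.
# set_trainable(model, True)             # what model.unfreeze() might do
# set_trainable(model.encoder, False)    # what model.freeze_stage2() might do
```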
