6 changed files with 507 additions and 60 deletions
@@ -1,66 +1,161 @@
import os
import subprocess
from tqdm import tqdm
import numpy as np
import torch
import argparse
from skimage import measure
import time
import trimesh

# Processing is done by a C++ program; this script only invokes it.
# Note: run it inside Docker (compilation fails on the host machine).
def create_grid(depth, box_size):
    """
    Create a 3D grid of sample points.
    :param depth: grid depth (determines the resolution)
    :param box_size: bounding-box edge length
    :return: flattened point array and the coordinate grids
    """
    grid_size = 2**depth + 1
    start = -box_size / 2
    end = box_size / 2
    x = np.linspace(start, end, grid_size)
    y = np.linspace(start, end, grid_size)
    z = np.linspace(start, end, grid_size)
    xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
    points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
    return points, xx, yy, zz
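
# Illustrative sanity check (not part of the original script): at depth 3 the
# grid has (2**3 + 1)**3 = 729 points spanning [-box_size/2, box_size/2] along
# each axis.
#   pts, xx, yy, zz = create_grid(3, 2.0)
#   assert pts.shape == (729, 3) and xx.shape == (9, 9, 9)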

def predict_sdf(model, points, device):
    """
    Predict SDF values with the model.
    :param model: PyTorch model
    :param points: input point coordinates (N, 3)
    :param device: device (CPU/GPU)
    :return: array of SDF values (N,)
    """
    points_t = torch.from_numpy(points).float().to(device)

    with torch.no_grad():
        sdf = model(points_t).cpu().numpy().flatten()
    return sdf
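
# A chunked variant (a sketch, not in the original script): at depth 8 the grid
# has 257**3 (about 17M) points, and a single forward pass over all of them can
# exhaust GPU memory; evaluating fixed-size batches avoids that.
def predict_sdf_chunked(model, points, device, chunk_size=2**18):
    sdfs = []
    with torch.no_grad():
        for i in range(0, len(points), chunk_size):
            chunk = torch.from_numpy(points[i:i + chunk_size]).float().to(device)
            sdfs.append(model(chunk).cpu().numpy().flatten())
    return np.concatenate(sdfs)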

def extract_surface(sdf, xx, yy, zz, method='MC'):
    """
    Extract the zero iso-surface.
    :param sdf: 3D array of SDF values
    :param xx/yy/zz: coordinate grids
    :param method: extraction method (MC: Marching Cubes)
    :return: vertices and faces
    """
    if method == 'MC':
        # marching_cubes works in voxel-index coordinates; supply the grid
        # spacing and offset by the grid origin to obtain world coordinates.
        spacing = (
            xx[1, 0, 0] - xx[0, 0, 0],
            yy[0, 1, 0] - yy[0, 0, 0],
            zz[0, 0, 1] - zz[0, 0, 0],
        )
        verts, faces, _, _ = measure.marching_cubes(sdf, level=0, spacing=spacing)
        verts += np.array([xx[0, 0, 0], yy[0, 0, 0], zz[0, 0, 0]])
    else:
        raise NotImplementedError("Only the Marching Cubes method is supported")
    return verts, faces

def save_ply(vertices, faces, filename):
    """
    Save vertices and faces to a PLY file.
    :param vertices: vertex array (N, 3)
    :param faces: face array (M, 3)
    :param filename: output file name
    """
    with open(filename, 'w') as f:
        f.write("ply\n")
        f.write("format ascii 1.0\n")
        f.write(f"element vertex {len(vertices)}\n")
        f.write("property float x\n")
        f.write("property float y\n")
        f.write("property float z\n")
        f.write(f"element face {len(faces)}\n")
        f.write("property list uchar int vertex_indices\n")
        f.write("end_header\n")
        for v in vertices:
            f.write(f"{v[0]} {v[1]} {v[2]}\n")
        for face in faces:
            f.write(f"3 {face[0]} {face[1]} {face[2]}\n")
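
# An equivalent one-line export via trimesh (an alternative sketch; trimesh is
# already imported above):
#   trimesh.Trimesh(vertices=vertices, faces=faces).export(filename)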

def compute_sdf_error(model, gt_mesh, res, device):
    """
    Compute the error between the predicted SDF and the GT mesh.
    :param model: PyTorch model
    :param gt_mesh: GT mesh (trimesh format)
    :param res: resolution for the error computation
    :param device: device
    :return: mean and maximum errors
    """
    # Generate uniformly sampled points
    box_size = max(gt_mesh.extents)
    start = -box_size / 2
    end = box_size / 2
    x = np.linspace(start, end, res)
    y = np.linspace(start, end, res)
    z = np.linspace(start, end, res)
    points = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3)

    # Predict the SDF
    pred_sdf = predict_sdf(model, points, device)

    # Compute GT distances (unsigned, via the closest point on the surface)
    distances = gt_mesh.nearest.on_surface(points)[1]
    gt_sdf = np.abs(distances)

    # Compute the errors
    abs_error = np.abs(pred_sdf - gt_sdf)
    rel_error = abs_error / (np.abs(gt_sdf) + 1e-9)
    avg_abs = np.mean(abs_error)
    avg_rel = np.mean(rel_error)
    max_abs = np.max(abs_error)
    max_rel = np.max(rel_error)

    return avg_abs, avg_rel, max_abs, max_rel


def main():
    parser = argparse.ArgumentParser(description='IsoSurface Generator')
    parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)')
    parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
    parser.add_argument('--depth', type=int, default=7, help='Grid depth (resolution)')
    parser.add_argument('--box_size', type=float, default=2.0, help='Bounding-box size')
    parser.add_argument('--method', type=str, default='MC', choices=['MC'], help='Surface extraction method')
    parser.add_argument('--use-gpu', action='store_true', help='Use the GPU')
    parser.add_argument('--compare', type=str, help='GT mesh file (.ply)')
    parser.add_argument('--compres', type=int, default=32, help='Resolution for error computation')
    args = parser.parse_args()
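
    # Example invocation (illustrative paths):
    #   python IsoSurfacing.py -i ./test/teaser.pt -o outputmesh.ply --depth 8 --use-gpu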

    # Select the device
    device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = torch.jit.load(args.input).to(device)
    #model = torch.load(args.input).to(device)
    model.eval()

    # Build the grid and predict the SDF
    points, xx, yy, zz = create_grid(args.depth, args.box_size)
    sdf = predict_sdf(model, points, device)
    print(points.shape)
    print(sdf.shape)
    print(sdf)
    sdf_grid = sdf.reshape(xx.shape)

    # Extract the surface
    print("Extracting surface...")
    start_time = time.time()
    verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method)
    print(f"Surface extraction took {time.time() - start_time:.2f} seconds")

    # Save the mesh
    save_ply(verts, faces, args.output)
    print(f"Mesh saved to {args.output}")

    # Error evaluation (optional)
    if args.compare:
        print("Computing SDF error...")
        gt_mesh = trimesh.load(args.compare)
        avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error(
            model, gt_mesh, args.compres, device
        )
print(f"Successfully processed '{name}'") |
|||
print("STDOUT:", result.stdout) |
|||
print("STDERR:", result.stderr) |
|||
except subprocess.CalledProcessError as e: |
|||
print(f"Error processing '{name}': Command failed with return code {e.returncode}") |
|||
print(f"Command: {e.cmd}") |
|||
print(f"Error type: {type(e).__name__}") |
|||
print("STDOUT:", e.stdout) |
|||
print("STDERR:", e.stderr) |
|||
print("Traceback:", e.__traceback__) |
|||
except Exception as e: |
|||
print(f"Unexpected error processing '{name}': {str(e)}") |
|||
print(f"Command: {command}") |
|||
print("Traceback:", traceback.format_exc()) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
print(f"Average Absolute Error: {avg_abs:.4f}") |
|||
print(f"Average Relative Error: {avg_rel:.4f}") |
|||
print(f"Max Absolute Error: {max_abs:.4f}") |
|||
print(f"Max Relative Error: {max_rel:.4f}") |
|||


if __name__ == "__main__":
    main()
@@ -0,0 +1,67 @@
import os
import subprocess
from tqdm import tqdm


def main():
    # Define the STEP file root directory and the path of the name-list file
    step_root_dir = "/home/wch/brep2sdf/data/step"
    name_list_path = "/home/wch/brep2sdf/data/name_list.txt"

    # Read the name list
    try:
        with open(name_list_path, 'r') as f:
            names = [line.strip() for line in f if line.strip()]  # skip blank lines
    except FileNotFoundError:
        print(f"Error: File '{name_list_path}' not found.")
        return
    except Exception as e:
        print(f"Error reading file '{name_list_path}': {e}")
        return

    # Iterate over the name list and process each STEP file
    for name in tqdm(names, desc="Processing STEP files"):
        step_dir = os.path.join(step_root_dir, name)

        # Dynamically build the STEP file paths (assuming there is only one file)
        step_files = [
            os.path.join(step_dir, f)
            for f in os.listdir(step_dir)
            if f.endswith(".step") and f.startswith(name)
        ]

        if not step_files:
            print(f"Warning: No STEP files found in directory '{step_dir}'. Skipping...")
            continue

        # Only the first matching file is processed
        input_step = step_files[0]

        # Build the subprocess command
        command = [
            "python", "train.py",
            "--use-normal",
            "-i", input_step,  # input file path
        ]
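
        # Passing argv as a list (without shell=True) sidesteps shell-quoting
        # issues if input_step contains spaces.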

        # Run the command in a subprocess
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                check=True  # raise CalledProcessError on a non-zero exit code
            )
            print(f"Processed '{input_step}' successfully.")
            print("STDOUT:", result.stdout)
            print("STDERR:", result.stderr)
        except subprocess.CalledProcessError as e:
            print(f"Error processing '{input_step}': {e}")
            print("STDOUT:", e.stdout)
            print("STDERR:", e.stderr)
        except Exception as e:
            print(f"Unexpected error processing '{input_step}': {e}")


if __name__ == '__main__':
    main()
@@ -0,0 +1,280 @@
import os
import sys

# Import the logging system
from brep2sdf.utils.logger import logger

import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import directed_hausdorff
import trimesh
import pandas as pd
import csv
import math
import pickle

import argparse

project_dir = "/home/wch/brep2sdf"

# Parse args first and set the GPU id
parser = argparse.ArgumentParser()
parser.add_argument('--gt_path', type=str,
                    default=os.path.join(project_dir, 'data/gt_point'),
                    help='ground truth data path')
parser.add_argument('--pred_path', type=str,
                    default=os.path.join(project_dir, 'data/output_data'),
                    help='converted data path')
parser.add_argument('--name_list', type=str, default='name_list.txt',
                    help='names of the models to be evaluated; to evaluate the whole dataset, set it to all_names.txt')
parser.add_argument('--nsample', type=int, default=50000, help='point batch size')
parser.add_argument('--regen', default=False, action='store_true', help='regenerate feature curves')
parser.add_argument('--csv_name', type=str, default='eval_results.csv', help='csv file name')
args = parser.parse_args()
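
# Example invocation (illustrative; the script file name is an assumption):
#   python eval.py --name_list name_list.txt --nsample 50000 --regen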

def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
    ''' Computes the minimal distance of each point in points_src to points_tgt.

    Args:
        points_src (numpy array [N, 3]): source points
        normals_src (numpy array [N, 3]): source normals
        points_tgt (numpy array [M, 3]): target points
        normals_tgt (numpy array [M, 3]): target normals
    Returns:
        dist (numpy array [N]): minimal distance of each point in points_src to points_tgt
        normals_dot_product (numpy array [N]): dot product of the normals of points_src and points_tgt
    '''
    kdtree = cKDTree(points_tgt)
    dist, idx = kdtree.query(points_src)

    if normals_src is not None and normals_tgt is not None:
        normals_src = \
            normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
        normals_tgt = \
            normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)

        normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
        # Handle normals that point in the wrong direction gracefully
        # (mostly because the method did not care about this during generation)
        normals_dot_product = np.abs(normals_dot_product)
    else:
        # Without normals there is no consistency to measure; report 1.0
        normals_dot_product = np.ones(dist.shape[0])
    return dist, normals_dot_product

def distance_feature2mesh(points, mesh):
    prox = trimesh.proximity.ProximityQuery(mesh)
    signed_distance = prox.signed_distance(points)
    return np.abs(signed_distance)

def distance_p2mesh(points_src, normals_src, mesh):
    points_tgt, idx = mesh.sample(args.nsample, return_index=True)
    points_tgt = points_tgt.astype(np.float32)
    normals_tgt = mesh.face_normals[idx]
    cd1, nc1 = distance_p2p(points_src, normals_src, points_tgt, normals_tgt)  # pred2gt
    hd1 = cd1.max()
    cd1 = cd1.mean()

    nc1 = np.clip(nc1, -1.0, 1.0)
    angles1 = np.arccos(nc1) / math.pi * 180.0
    angles1_mean = angles1.mean()
    angles1_std = np.std(angles1)

    cd2, nc2 = distance_p2p(points_tgt, normals_tgt, points_src, normals_src)  # gt2pred
    hd2 = cd2.max()
    cd2 = cd2.mean()
    nc2 = np.clip(nc2, -1.0, 1.0)

    angles2 = np.arccos(nc2) / math.pi * 180.0
    angles2_mean = angles2.mean()
    angles2_std = np.std(angles2)

    cd = 0.5 * (cd1 + cd2)
    hd = max(hd1, hd2)
    angles_mean = 0.5 * (angles1_mean + angles2_mean)
    angles_std = 0.5 * (angles1_std + angles2_std)
    return cd, hd, angles_mean, angles_std, hd1, hd2
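
# For reference: the symmetric chamfer distance above is
#   CD = 0.5 * (mean_i d(pred_i, GT) + mean_j d(gt_j, Pred)),
# the Hausdorff distance is
#   HD = max(max_i d(pred_i, GT), max_j d(gt_j, Pred)),
# and the angle terms average the per-point normal deviation in degrees.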


def distance_fea(gt_pa, pred_pa):
    """Compute distances and angle differences between feature points.
    Args:
        gt_pa: ground-truth feature points and angles [N, 4]
        pred_pa: predicted feature points and angles [M, 4]
    Returns:
        dfg2p: gt-to-pred distance
        dfp2g: pred-to-gt distance
        fag2p: gt-to-pred angle difference
        fap2g: pred-to-gt angle difference
    """
    gt_points = gt_pa[:, :3]
    pred_points = pred_pa[:, :3]
    gt_angle = gt_pa[:, 3]
    pred_angle = pred_pa[:, 3]

    pred_kdtree = cKDTree(pred_points)
    dist1, idx1 = pred_kdtree.query(gt_points)
    dfg2p = dist1.mean()
    assert idx1.shape[0] == gt_points.shape[0]
    fag2p = np.abs(gt_angle - pred_angle[idx1])

    gt_kdtree = cKDTree(gt_points)
    dist2, idx2 = gt_kdtree.query(pred_points)
    dfp2g = dist2.mean()
    fap2g = np.abs(pred_angle - gt_angle[idx2])

    fag2p = fag2p.mean()
    fap2g = fap2g.mean()

    return dfg2p, dfp2g, fag2p, fap2g

def load_and_process_single_model(line, gt_path, pred_mesh_path, args):
    """Evaluate a single model.
    Args:
        line (str): model name
        gt_path (str): ground-truth data path
        pred_mesh_path (str): predicted-mesh path
        args: parsed configuration
    Returns:
        dict: all evaluation metrics for this model
    """
    try:
        #line = line.strip()[:-4]  # no need to strip the _50k suffix here
        result = {'name': line}

        # Load the point-cloud data
        test_xyz = os.path.join(gt_path, line + '_50k.xyz')
        try:
            ptnormal = np.loadtxt(test_xyz)
        except FileNotFoundError:
            logger.error(f"XYZ file not found: {test_xyz}")
            return None
        except IOError as e:
            logger.error(f"Error reading XYZ file {test_xyz}: {str(e)}")
            return None
        except ValueError as e:
            logger.error(f"Invalid data format in XYZ file {test_xyz}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error loading {test_xyz}: {str(e)}")
            return None
        logger.debug("Successfully loaded GT points.")

        # Load the predicted mesh
        meshfile = os.path.join(pred_mesh_path, '{}.ply'.format(line))
        if not os.path.exists(meshfile):
            logger.warning(f'File does not exist: {meshfile}, trying to generate it...')
            pt_file = os.path.join(pred_mesh_path, '{}.pt'.format(line))
            try:
                # Log the command before executing it
                logger.debug(f"Executing IsoSurfacing: python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu")

                # Run the command and check the return code
                ret = os.system(f"python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu")

                if ret != 0:
                    raise RuntimeError(f"IsoSurfacing failed with return code {ret}")

                # Check that the output file was actually created
                if not os.path.exists(meshfile):
                    raise FileNotFoundError(f"Output mesh file not created: {meshfile}")

                logger.debug("IsoSurfacing completed successfully")

            except FileNotFoundError as e:
                logger.error(f"IsoSurfacing input file not found: {str(e)}")
                return None
            except RuntimeError as e:
                logger.error(f"IsoSurfacing execution failed: {str(e)}")
                return None
            except Exception as e:
                logger.error(f"Unexpected error in IsoSurfacing: {str(e)}")
                return None

        # Check the cache
        stat_file = meshfile + "_stat"
        if not args.regen and os.path.exists(stat_file) and os.path.getsize(stat_file) > 0:
            with open(stat_file, 'rb') as f:
                return pickle.load(f)

        # Compute the mesh distance metrics
        mesh = trimesh.load(meshfile)
        logger.debug("Successfully loaded the predicted mesh.")
        cd, hd, adm, ads, hd_pred2gt, hd_gt2pred = distance_p2mesh(
            ptnormal[:, :3], ptnormal[:, 3:6], mesh)

        result.update({
            'CD': cd, 'HD': hd, 'HDpred2gt': hd_pred2gt,
            'HDgt2pred': hd_gt2pred, 'AngleDiffMean': adm,
            'AngleDiffStd': ads
        })

        # Compute the feature-point metrics
        gt_ptangle = np.loadtxt(os.path.join(gt_path, line + '.ptangle'))
        pred_ptangle_path = meshfile[:-4] + '.ptangle'

        if not os.path.exists(pred_ptangle_path) or args.regen:
            os.system('/home/wch/brep2sdf/data/scripts/MeshFeatureSample/build/SimpleSample -i {} -o {} -s 4e-3'.format(meshfile, pred_ptangle_path))

        pred_ptangle = np.loadtxt(pred_ptangle_path).reshape(-1, 4)

        # Record the feature-point results
        if len(gt_ptangle) == 0 or len(pred_ptangle) == 0:
            result.update({
                'FeaDfgt2pred': 0.0, 'FeaDfpred2gt': 0.0,
                'FeaAnglegt2pred': 0.0, 'FeaAnglepred2gt': 0.0,
                'FeaDf': 0.0, 'FeaAngle': 0.0
            })
        else:
            dfg2p, dfp2g, fag2p, fap2g = distance_fea(gt_ptangle, pred_ptangle)
            result.update({
                'FeaDfgt2pred': dfg2p, 'FeaDfpred2gt': dfp2g,
                'FeaAnglegt2pred': fag2p, 'FeaAnglepred2gt': fap2g,
                'FeaDf': (dfg2p + dfp2g) / 2.0,
                'FeaAngle': (fag2p + fap2g) / 2.0
            })

        # Save to the cache
        with open(stat_file, "wb") as f:
            pickle.dump(result, f)

        return result
    except Exception as e:
        logger.error(f"Error processing {line}: {str(e)}")
        return None

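# Design note: per-model metrics are cached in a '<mesh>.ply_stat' pickle next
# to each mesh (see stat_file above), so a re-run only recomputes models whose
# cache is missing, unless --regen forces regeneration.
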
def compute_all():
    """Compute the evaluation metrics for all models."""
    try:
        # Per-model results
        results = []

        # Read the model list; strip newlines and skip blank lines, since the
        # names are used to build file paths
        with open(os.path.join(project_dir, 'data', args.name_list), 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
        logger.debug(lines)

        # Process each model
        for line in lines:
            result = load_and_process_single_model(line, args.gt_path, args.pred_path, args)
            if result:
                results.append(result)
                logger.info(result)

        # Compute the mean over all models
        mean_result = {'name': 'mean'}
        for key in results[0].keys():
            if key != 'name':
                mean_result[key] = sum(r[key] for r in results) / len(results)
        results.append(mean_result)

        # Save the results
        df = pd.DataFrame(results)
        df.to_csv(args.csv_name, index=False)

        logger.info(f"Evaluation completed. Results saved to {os.path.abspath(args.csv_name)}")

    except Exception as e:
        logger.error(f"Error in compute_all: {str(e)}")
        raise


if __name__ == '__main__':
    compute_all()