Browse Source

Refactor evaluation.py to improve model evaluation process and error handling

- Introduced a new function `load_and_process_single_model` to encapsulate the logic for evaluating a single model, enhancing code readability and maintainability.
- Updated `compute_all` to utilize the new function, streamlining the overall evaluation workflow.
- Improved error handling with logging for missing files and exceptions during processing.
- Enhanced caching mechanism for computed results to avoid redundant calculations.
- Added detailed comments and documentation for better understanding of the evaluation process.
main
mckay 2 months ago
parent
commit
aa76cc950b
  1. 154
      code/evaluation/evaluation.py

154
code/evaluation/evaluation.py

@ -129,92 +129,114 @@ def distance_fea(gt_pa, pred_pa):
return dfg2p, dfp2g, fag2p, fap2g return dfg2p, dfp2g, fag2p, fap2g
def compute_all(): def load_and_process_single_model(line, gt_path, pred_mesh_path, args):
gt_path = args.gt_path """处理单个模型的评估
pred_mesh_path = args.pred_path Args:
namelst = args.name_list line (str): 模型名称
output_path = 'eval_results.csv' gt_path (str): 真值路径
pred_mesh_path (str): 预测网格路径
with open(os.path.join(project_dir, 'evaluation', namelst), 'r') as f: args: 参数配置
lines = f.readlines() Returns:
dict: 包含该模型所有评估指标的字典
d = {'name':[], 'CD':[], 'HD':[], 'HDgt2pred':[], 'HDpred2gt':[], 'AngleDiffMean':[], 'AngleDiffStd':[], 'FeaDfgt2pred':[], 'FeaDfpred2gt':[], 'FeaDf':[], 'FeaAnglegt2pred':[], 'FeaAnglepred2gt':[], 'FeaAngle':[]} """
try:
for line in lines:
line = line.strip()[:-4] line = line.strip()[:-4]
print(line) result = {'name': line}
# 加载点云数据
test_xyz = os.path.join(gt_path, line+'_50k.xyz') test_xyz = os.path.join(gt_path, line+'_50k.xyz')
ptnormal = np.loadtxt(test_xyz) ptnormal = np.loadtxt(test_xyz)
meshfile = os.path.join(pred_mesh_path, '{}_50k.ply'.format(line))
# 加载预测网格
meshfile = os.path.join(pred_mesh_path, '{}_50k.ply'.format(line))
if not os.path.exists(meshfile): if not os.path.exists(meshfile):
print('file not exists: ', meshfile) logger.warning(f'File not exists: {meshfile}')
f = open(meshfile + 'noexists', 'w') return None
f.close()
continue # 检查缓存
stat_file = meshfile + "_stat" stat_file = meshfile + "_stat"
if not args.regen and os.path.exists(stat_file) and os.path.getsize(stat_file) > 0: if not args.regen and os.path.exists(stat_file) and os.path.getsize(stat_file) > 0:
#load compuated ones with open(stat_file, 'rb') as f:
f = open(stat_file, 'rb') return pickle.load(f)
cur_dict = pickle.load(f)
for k in cur_dict:
d[k].append(cur_dict[k])
f.close()
continue
d['name'].append(line)
# 计算网格距离指标
mesh = trimesh.load(meshfile) mesh = trimesh.load(meshfile)
cd, hd, adm, ads, hd_pred2gt, hd_gt2pred = distance_p2mesh(
ptnormal[:,:3], ptnormal[:,3:], mesh)
cd, hd, adm, ads, hd_pred2gt, hd_gt2pred = distance_p2mesh(ptnormal[:,:3], ptnormal[:,3:], mesh) result.update({
'CD': cd, 'HD': hd, 'HDpred2gt': hd_pred2gt,
d['CD'].append(cd) 'HDgt2pred': hd_gt2pred, 'AngleDiffMean': adm,
d['HD'].append(hd) 'AngleDiffStd': ads
d['HDpred2gt'].append(hd_pred2gt) })
d['HDgt2pred'].append(hd_gt2pred)
d['AngleDiffMean'].append(adm)
d['AngleDiffStd'].append(ads)
# 计算特征点指标
gt_ptangle = np.loadtxt(os.path.join(gt_path, line + '_detectfea4e-3.ptangle')) gt_ptangle = np.loadtxt(os.path.join(gt_path, line + '_detectfea4e-3.ptangle'))
pred_ptangle_path = meshfile[:-4]+'_4e-3.ptangle' pred_ptangle_path = meshfile[:-4]+'_4e-3.ptangle'
if not os.path.exists(pred_ptangle_path) or args.regen: if not os.path.exists(pred_ptangle_path) or args.regen:
os.system('./evaluation/MeshFeatureSample/build/SimpleSample -i {} -o {} -s 4e-3'.format(meshfile, pred_ptangle_path)) os.system('./evaluation/MeshFeatureSample/build/SimpleSample -i {} -o {} -s 4e-3'.format(
meshfile, pred_ptangle_path))
pred_ptangle = np.loadtxt(pred_ptangle_path).reshape(-1,4) pred_ptangle = np.loadtxt(pred_ptangle_path).reshape(-1,4)
#for smooth case: if gt fea is empty, or pred fea is empty, then return 0 # 处理特征点结果
if len(gt_ptangle) == 0 or len(pred_ptangle) == 0: if len(gt_ptangle) == 0 or len(pred_ptangle) == 0:
d['FeaDfgt2pred'].append(0.0) result.update({
d['FeaDfpred2gt'].append(0.0) 'FeaDfgt2pred': 0.0, 'FeaDfpred2gt': 0.0,
d['FeaAnglegt2pred'].append(0.0) 'FeaAnglegt2pred': 0.0, 'FeaAnglepred2gt': 0.0,
d['FeaAnglepred2gt'].append(0.0) 'FeaDf': 0.0, 'FeaAngle': 0.0
d['FeaDf'].append(0.0) })
d['FeaAngle'].append(0.0)
else: else:
dfg2p, dfp2g, fag2p, fap2g = distance_fea(gt_ptangle, pred_ptangle) dfg2p, dfp2g, fag2p, fap2g = distance_fea(gt_ptangle, pred_ptangle)
d['FeaDfgt2pred'].append(dfg2p) result.update({
d['FeaDfpred2gt'].append(dfp2g) 'FeaDfgt2pred': dfg2p, 'FeaDfpred2gt': dfp2g,
d['FeaAnglegt2pred'].append(fag2p) 'FeaAnglegt2pred': fag2p, 'FeaAnglepred2gt': fap2g,
d['FeaAnglepred2gt'].append(fap2g) 'FeaDf': (dfg2p + dfp2g) / 2.0,
d['FeaDf'].append((dfg2p + dfp2g) / 2.0) 'FeaAngle': (fag2p + fap2g) / 2.0
d['FeaAngle'].append((fag2p + fap2g) / 2.0) })
cur_d = {} # 保存缓存
for k in d: with open(stat_file, "wb") as f:
cur_d[k] = d[k][-1] pickle.dump(result, f)
return result
except Exception as e:
logger.error(f"Error processing {line}: {str(e)}")
return None
f = open(stat_file,"wb") def compute_all():
pickle.dump(cur_d, f) """计算所有模型的评估指标"""
f.close() try:
# 初始化结果字典
results = []
d['name'].append('mean')
for key in d: # 读取模型列表
if key != 'name': with open(os.path.join(project_dir, 'evaluation', args.name_list), 'r') as f:
d[key].append(sum(d[key])/len(d[key])) lines = f.readlines()
df = pd.DataFrame(d, columns=['name', 'CD', 'HD', 'HDpred2gt', 'HDgt2pred', 'AngleDiffMean', 'AngleDiffStd','FeaDfgt2pred', 'FeaDfpred2gt', 'FeaDf', 'FeaAnglegt2pred', 'FeaAnglepred2gt', 'FeaAngle']) # 处理每个模型
for line in lines:
df.to_csv(output_path, index = False, header=True) result = load_and_process_single_model(line, args.gt_path, args.pred_path, args)
if result:
results.append(result)
# 计算平均值
mean_result = {'name': 'mean'}
for key in results[0].keys():
if key != 'name':
mean_result[key] = sum(r[key] for r in results) / len(results)
results.append(mean_result)
# 保存结果
df = pd.DataFrame(results)
df.to_csv('eval_results.csv', index=False)
logger.info(f"Evaluation completed. Results saved to {os.path.abspath('eval_results.csv')}")
except Exception as e:
logger.error(f"Error in compute_all: {str(e)}")
raise
if __name__ == '__main__': if __name__ == '__main__':
compute_all() compute_all()
Loading…
Cancel
Save