Browse Source

ruff 清理包导入

final
mckay 2 months ago
parent
commit
0d0f6ecfca
  1. 3
      brep2sdf/data/data.py
  2. 7
      brep2sdf/data/pre_process.py
  3. 16
      brep2sdf/data/pre_process_by_mesh.py
  4. 29
      brep2sdf/data/sampler.py
  5. 5
      brep2sdf/data/utils.py
  6. 4
      brep2sdf/deep_sdf/workspace.py
  7. 4
      brep2sdf/evaluation.py
  8. 4
      brep2sdf/networks/decoder.py
  9. 6
      brep2sdf/networks/encoder.py
  10. 1
      brep2sdf/networks/feature_volume.py
  11. 3
      brep2sdf/networks/loss.py
  12. 1
      brep2sdf/networks/network.py
  13. 4
      brep2sdf/networks/octree.py
  14. 8
      brep2sdf/networks/patch_graph.py
  15. 2
      brep2sdf/scripts/convert/brep_to_mesh.py
  16. 8
      brep2sdf/scripts/diagnose.py
  17. 4
      brep2sdf/scripts/process_brep.py
  18. 5
      brep2sdf/scripts/process_furniture.py
  19. 5
      brep2sdf/scripts/read_npz.py
  20. 2
      brep2sdf/scripts/read_pkl.py
  21. 4
      brep2sdf/train.py

3
brep2sdf/data/data.py

@ -1,10 +1,7 @@
import os
import torch import torch
from torch.utils.data import Dataset
import numpy as np import numpy as np
import pickle import pickle
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config

7
brep2sdf/data/pre_process.py

@ -8,12 +8,7 @@ CAD模型处理脚本
import os import os
import pickle # 用于数据序列化 import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -28,7 +23,7 @@ from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFa
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒 from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构
# 导入配置 # 导入配置
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config

16
brep2sdf/data/pre_process_by_mesh.py

@ -8,34 +8,24 @@ CAD模型处理脚本
import os import os
import pickle # 用于数据序列化 import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from scipy.spatial import cKDTree
import tempfile import tempfile
import trimesh import trimesh
from trimesh.proximity import ProximityQuery
# 导入OpenCASCADE相关库 # 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 from OCC.Core.TopExp import TopExp_Explorer # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义 from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
from OCC.Core.BRep import BRep_Tool # B-rep工具 from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分 from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换 from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码 from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.TopoDS import topods # 拓扑数据结构
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
from OCC.Core.StlAPI import StlAPI_Writer from OCC.Core.StlAPI import StlAPI_Writer
from brep2sdf.data.sampler import sample_sdf_points_and_normals from brep2sdf.data.sampler import sample_sdf_points_and_normals
from brep2sdf.data.data import check_data_format from brep2sdf.data.data import check_data_format
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,batch_compute_normals
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
# 导入配置 # 导入配置
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config

29
brep2sdf/data/sampler.py

@ -6,37 +6,16 @@ CAD模型处理脚本
- 空间信息包围盒数据 - 空间信息包围盒数据
""" """
import os
import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
import tempfile
import trimesh import trimesh
from trimesh.proximity import ProximityQuery from trimesh.proximity import ProximityQuery
# 导入OpenCASCADE相关库 # 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
from OCC.Core.StlAPI import StlAPI_Writer
# 导入配置 # 导入配置
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals from brep2sdf.data.utils import batch_compute_normals
config = get_default_config() config = get_default_config()
@ -168,7 +147,7 @@ def sample_sdf_points_and_normals(
# 添加调试信息 # 添加调试信息
if i == 0: # 只打印第一个批次的统计信息 if i == 0: # 只打印第一个批次的统计信息
logger.debug(f"批次统计 (首批次):") logger.debug("批次统计 (首批次):")
logger.debug(f" 法向量范围: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]") logger.debug(f" 法向量范围: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
logger.debug(f" 法向量长度: {np.linalg.norm(normals_batch, axis=1).mean():.4f}") logger.debug(f" 法向量长度: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
logger.debug(f" 距离范围: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]") logger.debug(f" 距离范围: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
@ -196,7 +175,7 @@ def sample_sdf_points_and_normals(
# 验证法向量 # 验证法向量
normal_lengths = np.linalg.norm(sampled_normals, axis=1) normal_lengths = np.linalg.norm(sampled_normals, axis=1)
logger.debug(f"最终法向量统计:") logger.debug("最终法向量统计:")
logger.debug(f" 形状: {sampled_normals.shape}") logger.debug(f" 形状: {sampled_normals.shape}")
logger.debug(f" 长度: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}") logger.debug(f" 长度: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
logger.debug(f" 分量范围: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]") logger.debug(f" 分量范围: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
@ -226,7 +205,7 @@ def sample_sdf_points_and_normals(
# 添加SDF分布验证 # 添加SDF分布验证
final_sdf = combined_data[:, -1] final_sdf = combined_data[:, -1]
logger.debug(f"最终SDF分布验证:") logger.debug("最终SDF分布验证:")
logger.debug(f" 正值点数: {np.sum(final_sdf > 0)}") logger.debug(f" 正值点数: {np.sum(final_sdf > 0)}")
logger.debug(f" 负值点数: {np.sum(final_sdf < 0)}") logger.debug(f" 负值点数: {np.sum(final_sdf < 0)}")
logger.debug(f" 零值点数: {np.sum(np.abs(final_sdf) < 1e-6)}") logger.debug(f" 零值点数: {np.sum(np.abs(final_sdf) < 1e-6)}")

5
brep2sdf/data/utils.py

@ -1,15 +1,14 @@
# 导入OpenCASCADE相关库 # 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义 from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE # 拓扑类型定义
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒 from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构
import numpy as np import numpy as np
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
import trimesh
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger

4
brep2sdf/deep_sdf/workspace.py

@ -28,7 +28,7 @@ def load_experiment_specifications(experiment_directory):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise Exception( raise Exception(
"The experiment directory ({}) does not include specifications file " "The experiment directory ({}) does not include specifications file "
+ '"specs.json"'.format(experiment_directory) + '"specs.json"'
) )
return json.load(open(filename)) return json.load(open(filename))
@ -86,7 +86,7 @@ def load_latent_vectors(experiment_directory, checkpoint):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise Exception( raise Exception(
"The experiment directory ({}) does not include a latent code file" "The experiment directory ({}) does not include a latent code file"
+ " for checkpoint '{}'".format(experiment_directory, checkpoint) + " for checkpoint '{}'".format(experiment_directory, )
) )
data = torch.load(filename) data = torch.load(filename)

4
brep2sdf/evaluation.py

@ -1,16 +1,12 @@
import os import os
import sys
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
# 导入日志系统 # 导入日志系统
from brep2sdf.utils.logger import logger
import numpy as np import numpy as np
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
from scipy.spatial.distance import directed_hausdorff
import trimesh import trimesh
import pandas as pd import pandas as pd
import csv
import math import math
import pickle import pickle

4
brep2sdf/networks/decoder.py

@ -1,11 +1,7 @@
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Dict, Optional, Tuple, Union
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, def __init__(self,

6
brep2sdf/networks/encoder.py

@ -1,13 +1,7 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union
from .octree import OctreeNode from .octree import OctreeNode
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
import numpy as np
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, octree: OctreeNode, feature_dim: int = 32): def __init__(self, octree: OctreeNode, feature_dim: int = 32):

1
brep2sdf/networks/feature_volume.py

@ -2,7 +2,6 @@ from typing import Tuple, List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim
class PatchFeatureVolume(nn.Module): class PatchFeatureVolume(nn.Module):
def __init__(self, bbox:np, resolution=64, feature_dim=64): def __init__(self, bbox:np, resolution=64, feature_dim=64):

3
brep2sdf/networks/loss.py

@ -1,8 +1,5 @@
import torch import torch
import torch.nn as nn
from .network import gradient from .network import gradient
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
class LossManager: class LossManager:

1
brep2sdf/networks/network.py

@ -1,6 +1,5 @@
from typing import Tuple
''' '''
class GridNet: class GridNet:

4
brep2sdf/networks/octree.py

@ -1,11 +1,9 @@
from typing import Tuple, List, cast, Dict, Any from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import numpy as np import numpy as np
from brep2sdf.utils.logger import logger
from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.networks.patch_graph import PatchGraph

8
brep2sdf/networks/patch_graph.py

@ -1,13 +1,7 @@
from typing import Tuple, Optional from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopoDS import TopoDS_Edge, TopoDS_Face, topods_Edge, topods_Face
from OCC.Core.BRep import BRep_Tool
from OCC.Core.GeomLProp import GeomLProp_SLProps
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
class PatchGraph(nn.Module): class PatchGraph(nn.Module):
def __init__(self, num_patches: int, device: torch.device = None): def __init__(self, num_patches: int, device: torch.device = None):

2
brep2sdf/scripts/convert/brep_to_mesh.py

@ -1,7 +1,5 @@
import os import os
import sys
import pickle import pickle
import argparse
from tqdm import tqdm from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
from convert_utils import * from convert_utils import *

8
brep2sdf/scripts/diagnose.py

@ -46,7 +46,7 @@ class ModelDiagnostics:
logger.info("\n查询点编码:") logger.info("\n查询点编码:")
logger.info(f" 输入形状: {batch['query_points'].shape}") logger.info(f" 输入形状: {batch['query_points'].shape}")
logger.info(f" 输出形状: {query_features.shape}") logger.info(f" 输出形状: {query_features.shape}")
logger.info(f" 特征统计:") logger.info(" 特征统计:")
logger.info(f" 均值: {query_features.mean():.4f}") logger.info(f" 均值: {query_features.mean():.4f}")
logger.info(f" 标准差: {query_features.std():.4f}") logger.info(f" 标准差: {query_features.std():.4f}")
logger.info(f" 最大值: {query_features.max():.4f}") logger.info(f" 最大值: {query_features.max():.4f}")
@ -63,7 +63,7 @@ class ModelDiagnostics:
) )
logger.info("\nB-rep特征编码:") logger.info("\nB-rep特征编码:")
logger.info(f" 输出形状: {brep_features.shape}") logger.info(f" 输出形状: {brep_features.shape}")
logger.info(f" 特征统计:") logger.info(" 特征统计:")
logger.info(f" 均值: {brep_features.mean():.4f}") logger.info(f" 均值: {brep_features.mean():.4f}")
logger.info(f" 标准差: {brep_features.std():.4f}") logger.info(f" 标准差: {brep_features.std():.4f}")
@ -71,7 +71,7 @@ class ModelDiagnostics:
global_features = brep_features.mean(dim=1) global_features = brep_features.mean(dim=1)
logger.info("\n全局特征:") logger.info("\n全局特征:")
logger.info(f" 形状: {global_features.shape}") logger.info(f" 形状: {global_features.shape}")
logger.info(f" 统计:") logger.info(" 统计:")
logger.info(f" 均值: {global_features.mean():.4f}") logger.info(f" 均值: {global_features.mean():.4f}")
logger.info(f" 标准差: {global_features.std():.4f}") logger.info(f" 标准差: {global_features.std():.4f}")
@ -79,7 +79,7 @@ class ModelDiagnostics:
sdf = self.model(**batch) sdf = self.model(**batch)
logger.info("\nSDF预测:") logger.info("\nSDF预测:")
logger.info(f" 形状: {sdf.shape}") logger.info(f" 形状: {sdf.shape}")
logger.info(f" 统计:") logger.info(" 统计:")
logger.info(f" 均值: {sdf.mean():.4f}") logger.info(f" 均值: {sdf.mean():.4f}")
logger.info(f" 标准差: {sdf.std():.4f}") logger.info(f" 标准差: {sdf.std():.4f}")
logger.info(f" 最大值: {sdf.max():.4f}") logger.info(f" 最大值: {sdf.max():.4f}")

4
brep2sdf/scripts/process_brep.py

@ -8,11 +8,9 @@ CAD模型处理脚本
import os import os
import pickle # 用于数据序列化 import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np import numpy as np
from tqdm import tqdm # 进度条显示 from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理 from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime from datetime import datetime
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -28,7 +26,7 @@ from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFa
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒 from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构
# 导入配置 # 导入配置
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config

5
brep2sdf/scripts/process_furniture.py

@ -1,15 +1,12 @@
import os import os
import time
import glob import glob
import trimesh import trimesh
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import mesh2sdf import mesh2sdf
import skimage.measure
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
import logging import logging
from typing import Tuple, Optional from typing import Optional
from OCC.Core.STEPControl import STEPControl_Reader from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh

5
brep2sdf/scripts/read_npz.py

@ -2,7 +2,6 @@ import numpy as np
import argparse import argparse
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def view_npz_file(file_path: str, save_plot: bool = False): def view_npz_file(file_path: str, save_plot: bool = False):
"""查看指定的npz文件内容 """查看指定的npz文件内容
@ -16,9 +15,9 @@ def view_npz_file(file_path: str, save_plot: bool = False):
data = np.load(file_path) data = np.load(file_path)
# 打印基本信息 # 打印基本信息
logger.info(f"\n=== NPZ文件内容分析 ===") logger.info("\n=== NPZ文件内容分析 ===")
logger.info(f"文件路径: {file_path}") logger.info(f"文件路径: {file_path}")
logger.info(f"\n包含的数组:") logger.info("\n包含的数组:")
# 分析每个数组 # 分析每个数组
for key in data.files: for key in data.files:

2
brep2sdf/scripts/read_pkl.py

@ -20,7 +20,7 @@ def inspect_data(pkl_file):
print("-" * 30) print("-" * 30)
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
print(f"类型: numpy.ndarray") print("类型: numpy.ndarray")
print(f"形状: {value.shape}") print(f"形状: {value.shape}")
print(f"数据类型: {value.dtype}") print(f"数据类型: {value.dtype}")
if value.size > 0: if value.size > 0:

4
brep2sdf/train.py

@ -1,6 +1,4 @@
import torch import torch
from torch.serialization import add_safe_globals
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch.optim as optim import torch.optim as optim
import time import time
import os import os
@ -8,7 +6,7 @@ import numpy as np
import argparse import argparse
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file, prepare_sdf_data, print_data_distribution, check_tensor from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
from brep2sdf.data.pre_process_by_mesh import process_single_step from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.octree import OctreeNode

Loading…
Cancel
Save