You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

85 lines
2.2 KiB

# coding=utf-8
import torch
from models.codec_us import *
#from const import *
from p2voxel import *
import numpy
# Loaded models keyed by (checkpoint path, device string), so repeated getDH
# calls do not rebuild the network and re-read the checkpoint every time.
_MODEL_CACHE = {}


def _load_model(ckpt_file, device):
    """Build convNN3d(1, 18, 8) on *device*, load weights from *ckpt_file*, cache it.

    Only checkpoint entries whose keys also appear in the model's own
    state dict are copied; extra keys in the checkpoint are ignored.
    """
    key = (ckpt_file, str(device))
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    model = convNN3d(1, 18, 8).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
    # which is exactly the fallback branch getDH selects when CUDA is absent.
    checkpoint = torch.load(ckpt_file, map_location=device)
    model_keys = model.state_dict().keys()  # hoisted out of the filter
    state_dict = {
        k: v for k, v in checkpoint['model_state_dict'].items() if k in model_keys
    }
    model.load_state_dict(state_dict)
    # Inference only: use eval mode so layers such as dropout/batch-norm (if
    # convNN3d has any — not visible here) behave deterministically.
    model.eval()
    _MODEL_CACHE[key] = model
    return model


def getDH(x, ckpt_file="./model_epoch3000.pth"):
    """Predict the 6x6 ``DH`` matrix for one parameterized voxel structure.

    The parameters are turned into a voxel grid via ``p2voxel``. The grid is
    run through the ``convNN3d`` network and the network output is passed to
    ``disp2DH``. What DH physically represents (presumably a homogenized
    constitutive matrix) is defined by ``disp2DH`` — not visible here.

    Args:
        x: Sequence of parameter values. ``x[:7]`` are passed to ``setPara``
            as the "in" parameter values and ``x[7:]`` as the "out" values.
        ckpt_file: Path to the checkpoint holding ``'model_state_dict'``.
            Defaults to the path originally hard-coded here, resolved
            relative to the current working directory.

    Returns:
        numpy.ndarray of shape (6, 6).
    """
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _load_model(ckpt_file, device)

    # Generate the voxel grid with the p2voxel code.
    # NOTE(review): 387 - 1 looks like a 1-based type id converted to 0-based;
    # confirm against p2voxel.
    mtype = 387 - 1
    # Names are unused; the call is kept in case showPara has side effects.
    in_parameter_names, out_parameter_names = showPara(mtype)
    in_parameter_values = x[:7]
    out_parameter_values = x[7:]
    parameters = setPara(mtype, in_parameter_values, out_parameter_values)
    voxel = p2voxel(mtype, parameters, resolution=39)
    # Prepend two singleton dims (batch, channel) for the 3-D conv net.
    voxel = torch.from_numpy(voxel).unsqueeze(0).unsqueeze(0).to(device)

    # Forward pass only — no autograd graph needed.
    with torch.no_grad():
        U = model(voxel)
    batch_size = voxel.shape[0]
    ref18 = U.contiguous().view(batch_size, 18, -1)
    # For each c in 0..5 gather channels c, c+6, c+12, interleave them per
    # spatial location, and flatten to one column; then stack the 6 columns.
    maps = [
        ref18[:, c::6].permute((0, 2, 1)).contiguous().view(batch_size, -1, 1)
        for c in range(6)
    ]
    output_map = torch.cat(maps, 2)
    output_U = output_map[:, edofMat, :]
    DHs = disp2DH(voxel, output_U, D00, intB2, h)
    return DHs.cpu().detach().numpy().reshape(6, 6)
def test():
    """Smoke test: call getDH with ten parameters, each equal to 0.1.

    The resulting 6x6 array is kept in a local for debugger inspection
    and is neither printed nor returned.
    """
    params = [0.1 for _ in range(10)]
    dh_matrix = getDH(params)
#test()