You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.2 KiB
85 lines
2.2 KiB
3 years ago
|
# coding=utf-8
|
||
|
import torch
|
||
|
from models.codec_us import *
|
||
|
#from const import *
|
||
|
|
||
|
from p2voxel import *
|
||
|
import numpy
|
||
|
|
||
|
# Loaded models, keyed by (checkpoint path, device string). Repeated getDH
# calls therefore reuse one network instead of rebuilding it and re-reading the
# checkpoint from disk each time. NOTE: a checkpoint file that changes on disk
# after the first load is not picked up until the process restarts.
_MODEL_CACHE = {}

# Default type index used by getDH (originally hard-coded as ``387 - 1``;
# presumably converting a 1-based id to a 0-based index — TODO confirm).
_DEFAULT_MTYPE = 387 - 1


def _select_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_model(ckpt_file, device):
    """Build ``convNN3d(1, 18, 8)`` on *device* and load weights from *ckpt_file*.

    Only entries of ``checkpoint['model_state_dict']`` whose keys exist in the
    model's own state dict are kept. ``load_state_dict`` is still strict, so
    every model key must be present in the checkpoint. The model is cached per
    ``(ckpt_file, device)`` and returned in eval mode.
    """
    cache_key = (ckpt_file, str(device))
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    model = convNN3d(1, 18, 8).to(device)
    # map_location makes a checkpoint saved on a GPU loadable on a CPU-only
    # machine; without it the CPU fallback in _select_device is useless.
    checkpoint = torch.load(ckpt_file, map_location=device)
    # Build the key set once: state_dict() constructs a fresh dict per call.
    model_keys = set(model.state_dict())
    state_dict = {
        k: v for k, v in checkpoint['model_state_dict'].items() if k in model_keys
    }
    model.load_state_dict(state_dict)
    # Inference only: switch any dropout / batch-norm layers convNN3d may
    # contain to their deterministic evaluation behaviour.
    model.eval()
    _MODEL_CACHE[cache_key] = model
    return model


def _build_voxel(x, mtype, resolution, device):
    """Generate the voxel for design vector *x* via p2voxel as a 5-D tensor.

    ``x[:7]`` are the "in" parameter values and ``x[7:]`` the "out" parameter
    values given to ``setPara``. The numpy voxel returned by ``p2voxel`` gets
    two leading singleton dims (batch, channel) and is moved to *device*.
    """
    # The returned parameter names are unused; the call is kept in case
    # showPara has side effects (NOTE(review): confirm and drop if it has none).
    showPara(mtype)
    parameters = setPara(mtype, x[:7], x[7:])
    voxel = p2voxel(mtype, parameters, resolution=resolution)
    return torch.from_numpy(voxel).unsqueeze(0).unsqueeze(0).to(device)


def _voxel_to_dh(model, voxel):
    """Run *model* on *voxel* and convert its output to DH with ``disp2DH``."""
    batch_size = voxel.shape[0]
    U = model(voxel)
    ref18 = U.contiguous().view(batch_size, 18, -1)
    # Column c gathers channels c, c+6, c+12 (node-major after the permute),
    # giving a (batch, 3 * nodes, 6) map; presumably 6 load cases x 3
    # displacement components — TODO confirm against the model's training code.
    columns = [
        ref18[:, c::6].permute((0, 2, 1)).contiguous().view(batch_size, -1, 1)
        for c in range(6)
    ]
    output_map = torch.cat(columns, 2)
    output_U = output_map[:, edofMat, :]
    return disp2DH(voxel, output_U, D00, intB2, h)


def getDH(x, mtype=_DEFAULT_MTYPE, ckpt_file="./model_epoch3000.pth", resolution=39):
    """Predict DH for design vector *x* using the trained convNN3d checkpoint.

    Pipeline: generate the voxel with p2voxel, run the network, then compute
    DH with ``disp2DH``.

    Args:
        x: flat sequence of design values. ``x[:7]`` are passed to
            ``setPara`` as the "in" parameter values and ``x[7:]`` as the
            "out" parameter values.
        mtype: type index passed to ``showPara``/``setPara``/``p2voxel``
            (default ``387 - 1``, the previously hard-coded value).
        ckpt_file: path of a checkpoint containing ``'model_state_dict'``.
        resolution: resolution passed to ``p2voxel``.

    Returns:
        numpy array of shape (6, 6): the ``disp2DH`` output, reshaped.
    """
    torch.backends.cudnn.benchmark = True
    device = _select_device()
    model = _load_model(ckpt_file, device)
    voxel = _build_voxel(x, mtype, resolution, device)
    # No autograd graph is needed: the result is converted straight to numpy.
    with torch.no_grad():
        DHs = _voxel_to_dh(model, voxel)
    return DHs.detach().cpu().numpy().reshape(6, 6)
|
||
|
|
||
|
def test():
    """Smoke-test getDH with every design value set to 0.1.

    Ten values are passed: getDH uses the first 7 as "in" parameter values and
    the remaining 3 as "out" parameter values. Checks that the result is a
    6x6 array of finite numbers. Returns None, so pytest does not warn about a
    returned value if it collects this function.
    """
    default = 0.1
    x = [default] * 10
    DH = getDH(x)
    assert DH.shape == (6, 6)
    assert numpy.isfinite(DH).all(), "getDH returned non-finite entries"
|
||
|
|
||
|
|
||
|
|
||
|
#test()
|