You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
290 lines
12 KiB
290 lines
12 KiB
import plotly.graph_objs as go
|
|
import plotly.offline as offline
|
|
import torch
|
|
import numpy as np
|
|
from skimage import measure
|
|
import os
|
|
import utils.general as utils
|
|
|
|
def get_threed_scatter_trace(points, caption=None, colorscale=None, color=None):
    """Build plotly 3D scatter trace(s) for one or more point sets.

    Args:
        points: either a single array-like indexable as ``points[:, 0..2]``
            (plotted as one trace named ``'projection'``), or a list of
            ``(pts, name)`` pairs, each plotted as its own trace named ``name``.
        caption: hover text passed as ``text`` to every trace.
        colorscale: plotly colorscale for the markers.
        color: marker color(s).

    Returns:
        A list of ``go.Scatter3d`` traces.
    """

    def _make_trace(pts, name):
        # Every trace shares the same marker styling; only data and name vary.
        return go.Scatter3d(
            x=pts[:, 0],
            y=pts[:, 1],
            z=pts[:, 2],
            mode='markers',
            name=name,
            marker=dict(
                size=3,
                line=dict(width=2),
                opacity=0.9,
                colorscale=colorscale,
                showscale=True,
                color=color,
            ),
            text=caption,
        )

    if isinstance(points, list):
        return [_make_trace(p[0], p[1]) for p in points]
    return [_make_trace(points, 'projection')]
|
|
|
|
def plot_threed_scatter(points, path, epoch, in_epoch):
    """Save an interactive 3D scatter plot of ``points`` as an HTML file.

    The file is written to ``<path>/scatter_iteration_<epoch>_<in_epoch>.html``
    without opening a browser. All three axes are pinned to [-2, 2] with an
    equal aspect ratio.
    """
    scene = {name: dict(range=[-2, 2], autorange=False)
             for name in ('xaxis', 'yaxis', 'zaxis')}
    scene['aspectratio'] = dict(x=1, y=1, z=1)

    figure = go.Figure(data=get_threed_scatter_trace(points),
                       layout=go.Layout(width=1200, height=1200, scene=scene))
    out_name = '{0}/scatter_iteration_{1}_{2}.html'.format(path, epoch, in_epoch)
    offline.plot(figure, filename=out_name, auto_open=False)
|
|
|
|
def plot_surface(decoder, path, epoch, shapename, resolution, mc_value, is_uniform_grid, verbose, save_html, save_ply, overwrite, points=None, with_points=False, latent=None, connected=False, suffix="all"):
    """Extract the ``mc_value`` level set of ``decoder`` and plot/export it.

    Output files share the stem ``<path>/igr_<epoch>_<shapename><suffix>``.
    ``.html`` is written when ``save_html`` is set. ``.ply`` is written
    whenever a mesh was exported; ``get_surface_trace`` only builds one when
    ``save_ply`` is set.

    Args:
        decoder: callable evaluated on torch point tensors.
        path: output directory.
        epoch: used only in the output file name.
        shapename: used in the output file name and as the plot title.
        resolution, mc_value, is_uniform_grid, verbose, save_ply, connected:
            forwarded to ``get_surface_trace``.
        save_html: write an interactive plotly HTML file.
        overwrite: regenerate even when every requested output file exists.
        points: optional point tensor. It is forwarded to
            ``get_surface_trace``; when ``with_points`` is True, its last
            three columns are also plotted.
        with_points: also plot ``points``, captioned with decoder values.
        latent: optional latent code forwarded to ``get_surface_trace``.
        suffix: appended to the file stem before the extension.

    Returns:
        The exported mesh, or None. None means either that no mesh was
        exported, or that the work was skipped because all requested output
        files already exist and ``overwrite`` is False.
    """
    filename = '{0}/igr_{1}_{2}'.format(path, epoch, shapename)
    html_file = filename + suffix + '.html'
    ply_file = filename + suffix + '.ply'

    # The old check tested ``filename`` itself, a path this function never
    # writes, so ``overwrite=False`` never skipped any work. Test the files
    # this call would actually produce instead.
    targets = [f for f, wanted in ((html_file, save_html), (ply_file, save_ply)) if wanted]
    if targets and not overwrite and all(os.path.exists(f) for f in targets):
        return None

    trace_pnts = []
    if with_points:
        pnts_val = decoder(points)
        if verbose:
            # Debug output; previously printed unconditionally.
            print("pnts size: ", pnts_val.shape)
        pnts_val = pnts_val.cpu()
        points = points.cpu()
        caption = ["decoder : {0}".format(val.item()) for val in pnts_val.squeeze()]
        trace_pnts = get_threed_scatter_trace(points[:, -3:], caption=caption)

    surface = get_surface_trace(points, decoder, latent, resolution, mc_value, is_uniform_grid, verbose, save_ply, connected)
    trace_surface = surface["mesh_trace"]

    layout = go.Layout(title=go.layout.Title(text=shapename), width=1200, height=1200,
                       scene=dict(xaxis=dict(range=[-2, 2], autorange=False),
                                  yaxis=dict(range=[-2, 2], autorange=False),
                                  zaxis=dict(range=[-2, 2], autorange=False),
                                  aspectratio=dict(x=1, y=1, z=1)))
    # trace_pnts is empty unless with_points, so one call covers both cases.
    fig1 = go.Figure(data=trace_pnts + trace_surface, layout=layout)

    if save_html:
        offline.plot(fig1, filename=html_file, auto_open=False)
    mesh = surface['mesh_export']
    if mesh is not None:
        mesh.export(ply_file, 'ply')
    return mesh
|
|
|
|
def get_surface_trace(points, decoder, latent, resolution, mc_value, is_uniform, verbose, save_ply, connected=False):
    """Evaluate ``decoder`` on a 3D grid and extract its ``mc_value`` level set.

    Args:
        points: optional point tensor. When ``is_uniform`` is False, its last
            three columns bound the grid built by ``get_grid``.
        decoder: callable mapping a batch of (optionally latent-prefixed)
            grid points to one value per point.
        latent: optional tensor. It is expanded to the batch size and
            concatenated in front of every batch of grid coordinates
            (presumably shape (1, d) or (d,); confirm against callers).
        resolution: grid resolution forwarded to the grid builder.
        mc_value: iso-level passed to marching cubes.
        is_uniform: use ``get_grid_uniform`` instead of ``get_grid``.
        verbose: print a progress percentage before each evaluated chunk.
        save_ply: also build a ``trimesh.Trimesh`` for export.
        connected: keep only the largest-area connected component of that mesh.

    Returns:
        A dict with two entries:
        - ``"mesh_trace"``: a list holding one ``go.Mesh3d``, or an empty
          list when the decoder values never reach ``mc_value``.
        - ``"mesh_export"``: the trimesh mesh, or None.
    """
    chunk_size = 100000
    trace = []
    meshexport = None

    if is_uniform:
        grid = get_grid_uniform(resolution)
    elif points is not None:
        grid = get_grid(points[:, -3:], resolution)
    else:
        grid = get_grid(None, resolution)

    grid_points = grid['grid_points']
    # Ceil division. The old progress formula divided by the number of *full*
    # chunks, which raised ZeroDivisionError for grids below 100000 points.
    n_chunks = max(1, (grid_points.shape[0] + chunk_size - 1) // chunk_size)
    z = []
    for i, pnts in enumerate(torch.split(grid_points, chunk_size, dim=0)):
        if verbose:
            print('{0}'.format(i / n_chunks * 100))
        if latent is not None:
            pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
        z.append(decoder(pnts).detach().cpu().numpy())
    z = np.concatenate(z, axis=0)

    # Only run marching cubes when mc_value lies within the sampled range.
    if not (np.min(z) > mc_value or np.max(z) < mc_value):
        import trimesh

        z = z.astype(np.float64)
        x_axis, y_axis, z_axis = grid['xyz']
        # The grid was built with np.meshgrid(x, y, z) and default 'xy'
        # indexing, so values are laid out as (y, x, z); transpose to (x, y, z).
        volume = z.reshape(y_axis.shape[0], x_axis.shape[0], z_axis.shape[0]).transpose([1, 0, 2])
        # Both grid builders use one step on all three axes.
        step = x_axis[2] - x_axis[1]
        spacing = (step, step, step)

        if hasattr(measure, 'marching_cubes_lewiner'):
            verts, faces, normals, values = measure.marching_cubes_lewiner(
                volume=volume, level=mc_value, spacing=spacing)
        else:
            # scikit-image >= 0.19 removed marching_cubes_lewiner; the same
            # algorithm is available as marching_cubes(method='lewiner').
            verts, faces, normals, values = measure.marching_cubes(
                volume=volume, level=mc_value, spacing=spacing, method='lewiner')

        # Marching cubes returns coordinates relative to the volume origin;
        # shift them to the first grid sample.
        verts = verts + np.array([x_axis[0], y_axis[0], z_axis[0]])
        if save_ply:
            # ``normals`` are per-vertex. They were previously passed
            # positionally, which lands in Trimesh's ``face_normals`` slot.
            meshexport = trimesh.Trimesh(verts, faces, vertex_normals=normals, vertex_colors=values)
            if connected:
                components = meshexport.split(only_watertight=False)
                largest = max(components, key=lambda comp: comp.area, default=None)
                # As before: with no positive-area component, export nothing.
                meshexport = largest if largest is not None and largest.area > 0 else None

        trace.append(go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                               i=faces[:, 0], j=faces[:, 1], k=faces[:, 2], name='',
                               color='orange', opacity=0.5))

    # The plot trace and the export mesh describe the same surface.
    return {"mesh_trace": trace,
            "mesh_export": meshexport}
|
|
|
|
def plot_cuts_axis(points, decoder, latent, path, epoch, near_zero, axis, file_name_sep='/'):
    """Save contour plots of ``decoder`` on 50 planes orthogonal to ``axis``.

    Each plane samples a 200 x 200 grid over [-1, 1] in the two remaining
    coordinates. The planes are spaced evenly between the min and max of
    ``points[:, axis]``, padded by 0.1 on both ends. One PNG per plane is
    written to ``<path><file_name_sep>cutsaxis_<axis>_<epoch>_<index>.png``.

    Args:
        points: torch tensor whose column ``axis`` sets the cut range.
        decoder: callable evaluated on batches of up to 10000 points.
        latent: optional tensor, expanded and concatenated in front of each
            batch of coordinates.
        path, epoch, file_name_sep: used to build the output file names.
        near_zero: if True, draw fine levels in [start, 0.1] with step 0.01,
            where start is -0.1 when the slice has values below -1e-5 and 0
            otherwise. If False, draw levels in [-0.8, 0.8] with step 0.15.
        axis: 0, 1 or 2, the coordinate held fixed on each plane.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    # Without this check an invalid axis surfaced as an UnboundLocalError.
    if axis not in (0, 1, 2):
        raise ValueError('axis must be 0, 1 or 2, got {0!r}'.format(axis))

    onedim_cut = np.linspace(-1.0, 1.0, 200)
    n_side = onedim_cut.shape[0]
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()

    min_axis = points[:, axis].min(dim=0)[0].item()
    max_axis = points[:, axis].max(dim=0)[0].item()

    # Plane through the origin: the free coordinates are (xx, yy) in that
    # order, with a row of zeros inserted at position ``axis``.
    plane_rows = [xx, yy]
    plane_rows.insert(axis, np.zeros(xx.shape[0]))
    base_plane = np.vstack(plane_rows)
    mask = np.zeros(3)
    mask[axis] = 1.0
    offsets = np.linspace(min_axis - 0.1, max_axis + 0.1, 50)
    position_cut = [base_plane + i * mask.reshape(-1, 1) for i in offsets]

    for index, pos in enumerate(position_cut):
        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))
        z = []
        for pnts in torch.split(field_input, 10000, dim=0):
            if latent is not None:
                pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
            z.append(decoder(pnts).detach().cpu().numpy())
        z = np.concatenate(z, axis=0)

        if near_zero:
            start = -0.1 if np.min(z) < -1.0e-5 else 0.0
            contours = dict(start=start, end=0.1, size=0.01)
        else:
            contours = dict(start=-0.8, end=0.8, size=0.15)

        title = 'axis {0} = {1}'.format(axis, pos[axis, 0])
        trace1 = go.Contour(x=onedim_cut,
                            y=onedim_cut,
                            z=z.reshape(n_side, n_side),
                            name=title,
                            autocontour=False,
                            contours=contours)

        # NOTE(review): ``scene`` configures 3D axes; plotly likely ignores it
        # for this 2D contour. Kept unchanged to preserve the output images.
        layout = go.Layout(width=1200, height=1200,
                           scene=dict(xaxis=dict(range=[-1, 1], autorange=False),
                                      yaxis=dict(range=[-1, 1], autorange=False),
                                      aspectratio=dict(x=1, y=1)),
                           title=dict(text=title))

        # Build the .png name directly. The old str.replace('.html', '.png')
        # also rewrote any '.html' substring inside ``path``.
        png_name = '{0}{1}cutsaxis_{2}_{3}_{4}.png'.format(path, file_name_sep, axis, epoch, index)
        fig1 = go.Figure(data=[trace1], layout=layout)
        fig1.write_image(png_name)
|
|
|
|
def get_grid(points, resolution):
    """Build a regular 3D grid covering the bounding box of ``points``.

    The box is padded by ``eps = 0.1`` on every side. The shortest box axis
    is sampled with exactly ``resolution`` points (``np.linspace``). The
    other two axes reuse the same step via ``np.arange``, so all three axes
    share one spacing and may hold more than ``resolution`` samples.

    Args:
        points: (N, 3) torch tensor, or None. When None, the box
            [-1.1, 1.1]^3 is used, giving a padded extent of [-1.2, 1.2] to
            match the range used by ``get_grid_uniform``.
        resolution: number of samples along the shortest axis; must be >= 2.

    Returns:
        A dict with four entries:
        - ``grid_points``: (M, 3) float tensor, on the GPU when CUDA is
          available.
        - ``shortest_axis_length``: padded length of the shortest axis.
        - ``xyz``: the per-axis 1D sample arrays ``[x, y, z]``.
        - ``shortest_axis_index``: index of the shortest axis.

    Raises:
        ValueError: if ``resolution`` is below 2 (the step would divide by 0).
    """
    if resolution < 2:
        raise ValueError('resolution must be at least 2, got {0}'.format(resolution))

    eps = 0.1
    if points is None:
        # No points to bound the grid (get_surface_trace calls get_grid(None, ...)):
        # fall back to a fixed box instead of crashing in torch.min(None).
        input_min = np.full(3, -1.1)
        input_max = np.full(3, 1.1)
    else:
        # detach() so tensors that require grad can be converted to numpy.
        input_min = torch.min(points, dim=0)[0].squeeze().detach().cpu().numpy()
        input_max = torch.max(points, dim=0)[0].squeeze().detach().cpu().numpy()

    bounding_box = input_max - input_min
    shortest_axis = np.argmin(bounding_box)

    # The shortest axis gets exactly ``resolution`` samples; the other axes
    # reuse its step so the grid spacing is identical in x, y and z.
    shortest = np.linspace(input_min[shortest_axis] - eps,
                           input_max[shortest_axis] + eps, resolution)
    length = np.max(shortest) - np.min(shortest)
    step = length / (shortest.shape[0] - 1)
    x, y, z = [shortest if a == shortest_axis
               else np.arange(input_min[a] - eps, input_max[a] + step + eps, step)
               for a in range(3)]

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
    # An unconditional .cuda() fails on CPU-only machines.
    if torch.cuda.is_available():
        grid_points = grid_points.cuda()

    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}
|
|
|
|
def get_grid_uniform(resolution, bound=1.2):
    """Build a cubic grid with ``resolution`` samples per axis over [-bound, bound]^3.

    Args:
        resolution: number of samples along each axis.
        bound: half-width of the cube. The default reproduces the previously
            hard-coded [-1.2, 1.2] range.

    Returns:
        A dict with four entries:
        - ``grid_points``: a (``resolution``**3, 3) float tensor, passed
          through ``utils.to_cuda``.
        - ``shortest_axis_length``: ``2 * bound``.
        - ``xyz``: the same 1D sample array for x, y and z.
        - ``shortest_axis_index``: always 0, since all axes are equal.
    """
    x = np.linspace(-bound, bound, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": 2 * bound,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}
|