"""Plotting helpers for a decoder's scalar field.

Provides 3D scatter traces, iso-surface extraction via marching cubes
(optionally exported as a ``.ply`` mesh) and 2D contour cuts.
"""
import os

import numpy as np
import plotly.graph_objs as go
import plotly.offline as offline
import torch
from skimage import measure

import utils.general as utils


def _scatter_marker(colorscale, color):
    """Return the marker style shared by every 3D scatter trace."""
    return dict(
        size=3,
        line=dict(width=2),
        opacity=0.9,
        colorscale=colorscale,
        showscale=True,
        color=color,
    )


def _cube_layout(**extra):
    """Return a 1200x1200 layout whose 3D scene is fixed to [-2, 2]^3.

    Any keyword arguments are forwarded to ``go.Layout`` (e.g. ``title``).
    """
    return go.Layout(
        width=1200,
        height=1200,
        scene=dict(
            xaxis=dict(range=[-2, 2], autorange=False),
            yaxis=dict(range=[-2, 2], autorange=False),
            zaxis=dict(range=[-2, 2], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        **extra
    )


def get_threed_scatter_trace(points, caption=None, colorscale=None, color=None):
    """Build plotly ``Scatter3d`` traces for one or several point clouds.

    Args:
        points: either a single array/tensor indexed as ``points[:, 0..2]``
            (named ``'projection'`` in the legend), or a list of
            ``(cloud, name)`` pairs where each ``cloud`` is indexed the
            same way.
        caption: hover text passed as ``text`` to every trace.
        colorscale: plotly colorscale for the markers.
        color: marker color (value or per-point array).

    Returns:
        A list of ``go.Scatter3d`` traces, one per point cloud.
    """
    if isinstance(points, list):
        named_clouds = [(entry[0], entry[1]) for entry in points]
    else:
        named_clouds = [(points, 'projection')]

    return [
        go.Scatter3d(
            x=cloud[:, 0],
            y=cloud[:, 1],
            z=cloud[:, 2],
            mode='markers',
            name=name,
            marker=_scatter_marker(colorscale, color),
            text=caption,
        )
        for cloud, name in named_clouds
    ]


def plot_threed_scatter(points, path, epoch, in_epoch):
    """Write an HTML 3D scatter plot of ``points``.

    The file is saved as ``<path>/scatter_iteration_<epoch>_<in_epoch>.html``
    and is not opened automatically.
    """
    figure = go.Figure(data=get_threed_scatter_trace(points), layout=_cube_layout())
    filename = '{0}/scatter_iteration_{1}_{2}.html'.format(path, epoch, in_epoch)
    offline.plot(figure, filename=filename, auto_open=False)


def plot_surface(decoder, path, epoch, shapename, resolution, mc_value, is_uniform_grid, verbose,
                 save_html, save_ply, overwrite, points=None, with_points=False, latent=None,
                 connected=False, suffix="all"):
    """Extract the decoder's level set and save it as HTML and/or PLY.

    Output files are named ``<path>/igr_<epoch>_<shapename><suffix>.html``
    and ``.ply``.

    Args:
        decoder: callable evaluated on point tensors.
        path, epoch, shapename, suffix: components of the output file name.
        resolution, mc_value, is_uniform_grid, verbose, save_ply, connected,
        latent: forwarded to ``get_surface_trace``.
        save_html: write the plotly figure to an HTML file.
        overwrite: proceed even if ``<path>/igr_<epoch>_<shapename>`` exists.
        points: optional tensor; its last three columns are used as
            coordinates for the scatter overlay and for the grid bounds.
        with_points: overlay ``points`` (captioned with the decoder value)
            on the surface. Requires ``points``.

    Returns:
        The exported mesh (``None`` if no mesh was built), or ``None`` when
        the call is skipped because the output already exists.

    Raises:
        ValueError: if ``with_points`` is set but ``points`` is ``None``.
    """
    filename = '{0}/igr_{1}_{2}'.format(path, epoch, shapename)

    # NOTE(review): this tests the bare prefix, while the files written below
    # are `filename + suffix + '.html'/'.ply'`, so in practice the check never
    # triggers and `overwrite` has no effect. Kept as-is to avoid changing
    # what callers get back; confirm the intended semantics before tightening.
    if os.path.exists(filename) and not overwrite:
        return None

    if with_points and points is None:
        raise ValueError("plot_surface: with_points=True requires `points`")

    point_traces = []
    if with_points:
        pnts_val = decoder(points)
        print("pnts size: ", pnts_val.shape)
        pnts_val = pnts_val.cpu()
        points = points.cpu()
        caption = ["decoder : {0}".format(val.item()) for val in pnts_val.squeeze()]
        point_traces = get_threed_scatter_trace(points[:, -3:], caption=caption)

    surface = get_surface_trace(points, decoder, latent, resolution, mc_value,
                                is_uniform_grid, verbose, save_ply, connected)
    mesh = surface['mesh_export']

    layout = _cube_layout(title=go.layout.Title(text=shapename))
    figure = go.Figure(data=point_traces + surface["mesh_trace"], layout=layout)

    if save_html:
        offline.plot(figure, filename=filename + suffix + '.html', auto_open=False)
    if mesh is not None:
        mesh.export(filename + suffix + '.ply', 'ply')
    return mesh


def get_surface_trace(points, decoder, latent, resolution, mc_value, is_uniform, verbose, save_ply,
                      connected=False):
    """Evaluate the decoder on a grid and extract the ``mc_value`` level set.

    Args:
        points: optional tensor; the last three columns bound the grid when
            ``is_uniform`` is false. If ``None``, the uniform grid is used.
        decoder: callable evaluated on chunks of grid points; must yield one
            value per point (the result is reshaped to the grid).
        latent: optional tensor expanded to every grid point and concatenated
            in front of the coordinates before calling ``decoder``.
        resolution: grid resolution (see ``get_grid`` / ``get_grid_uniform``).
        mc_value: iso-level passed to marching cubes.
        is_uniform: use the fixed uniform grid instead of the bounding box
            of ``points``.
        verbose: print a rough percentage after each evaluated chunk.
        save_ply: build a ``trimesh.Trimesh`` for export.
        connected: keep only the connected component with the largest area
            (only applies when ``save_ply`` is set).

    Returns:
        ``{"mesh_trace": [go.Mesh3d] or [], "mesh_export": Trimesh or None}``.
        Both are empty/``None`` when ``mc_value`` lies outside the range of
        decoder values on the grid.
    """
    trace = []
    meshexport = None

    # Without points there is no bounding box to fit, so use the uniform grid
    # (the original called get_grid(None, ...), which always raised).
    if is_uniform or points is None:
        grid = get_grid_uniform(resolution)
    else:
        grid = get_grid(points[:, -3:], resolution)

    chunks = torch.split(grid['grid_points'], 100000, dim=0)
    # Divide by the real chunk count: `N // 100000` is 0 for small grids.
    n_chunks = max(len(chunks), 1)

    values_per_chunk = []
    for idx, pnts in enumerate(chunks):
        if verbose:
            print('{0}'.format(idx / n_chunks * 100))
        if latent is not None:
            pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
        values_per_chunk.append(decoder(pnts).detach().cpu().numpy())
    z = np.concatenate(values_per_chunk, axis=0)

    # Marching cubes needs the level to lie within the sampled value range.
    if not (np.min(z) > mc_value or np.max(z) < mc_value):
        z = z.astype(np.float64)
        x_axis, y_axis, z_axis = grid['xyz']

        # `marching_cubes_lewiner` was removed in scikit-image 0.19; newer
        # releases expose the same algorithm as `marching_cubes`.
        marching_cubes = getattr(measure, 'marching_cubes_lewiner', None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes

        # The x-axis step is used for all three axes, as in the original.
        step = x_axis[2] - x_axis[1]
        verts, faces, normals, values = marching_cubes(
            # Samples are laid out (y, x, z) by np.meshgrid; reorder to (x, y, z).
            volume=z.reshape(y_axis.shape[0], x_axis.shape[0], z_axis.shape[0]).transpose([1, 0, 2]),
            level=mc_value,
            spacing=(step, step, step))

        # Shift from volume-local coordinates to the grid origin.
        verts = verts + np.array([x_axis[0], y_axis[0], z_axis[0]])

        if save_ply:
            import trimesh  # optional dependency, only needed for export

            meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values)
            if connected:
                largest_area = 0
                largest = None
                for component in meshexport.split(only_watertight=False):
                    if component.area > largest_area:
                        largest_area = component.area
                        largest = component
                meshexport = largest

        tri_i, tri_j, tri_k = faces.T
        # The trace always shows the full surface, even when `connected`
        # reduced the exported mesh to a single component.
        trace.append(go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                               i=tri_i, j=tri_j, k=tri_k, name='',
                               color='orange', opacity=0.5))

    return {"mesh_trace": trace,
            "mesh_export": meshexport}


def plot_cuts_axis(points, decoder, latent, path, epoch, near_zero, axis, file_name_sep='/'):
    """Save contour plots of the decoder on 50 planes orthogonal to ``axis``.

    Each plane is a 200x200 sampling of [-1, 1]^2; the planes are spread
    between ``min - 0.1`` and ``max + 0.1`` of ``points[:, axis]``. Images are
    written to ``<path><file_name_sep>cutsaxis_<axis>_<epoch>_<index>.png``.

    Args:
        points: tensor; column ``axis`` determines the range of the cuts.
        decoder: callable evaluated on chunks of plane points.
        latent: optional tensor concatenated in front of every point.
        path, epoch, file_name_sep: components of the output file name.
        near_zero: use fine contour levels up to 0.1 (starting at -0.1 when
            the plane has values below -1e-5, else at 0) instead of the
            coarse [-0.8, 0.8] levels.
        axis: 0, 1 or 2, the coordinate held constant on each plane.

    Raises:
        ValueError: if ``axis`` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1 or 2, got {0!r}".format(axis))

    onedim_cut = np.linspace(-1.0, 1.0, 200)
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()
    side = onedim_cut.shape[0]

    min_axis = points[:, axis].min(dim=0)[0].item()
    max_axis = points[:, axis].max(dim=0)[0].item()

    mask = np.zeros(3)
    mask[axis] = 1.0

    # Plane through the origin: zeros on `axis`, (xx, yy) on the other two.
    rows = [xx, yy]
    rows.insert(axis, np.zeros(xx.shape[0]))
    base_plane = np.vstack(rows)

    offsets = np.linspace(min_axis - 0.1, max_axis + 0.1, 50)
    planes = [base_plane + offset * mask.reshape(-1, 1) for offset in offsets]

    for index, pos in enumerate(planes):
        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))

        values_per_chunk = []
        for pnts in torch.split(field_input, 10000, dim=0):
            if latent is not None:
                pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
            values_per_chunk.append(decoder(pnts).detach().cpu().numpy())
        z = np.concatenate(values_per_chunk, axis=0)

        if near_zero:
            start = -0.1 if np.min(z) < -1.0e-5 else 0.0
            contours = dict(start=start, end=0.1, size=0.01)
        else:
            contours = dict(start=-0.8, end=0.8, size=0.15)

        title = 'axis {0} = {1}'.format(axis, pos[axis, 0])
        contour = go.Contour(x=onedim_cut,
                             y=onedim_cut,
                             z=z.reshape(side, side),
                             name=title,
                             autocontour=False,
                             contours=contours)

        layout = go.Layout(width=1200, height=1200,
                           scene=dict(xaxis=dict(range=[-1, 1], autorange=False),
                                      yaxis=dict(range=[-1, 1], autorange=False),
                                      aspectratio=dict(x=1, y=1)),
                           title=dict(text=title))

        # Build the .png name directly so a '.html' inside `path` is untouched.
        filename = '{0}{1}cutsaxis_{2}_{3}_{4}.png'.format(path, file_name_sep, axis, epoch, index)
        go.Figure(data=[contour], layout=layout).write_image(filename)


def get_grid(points, resolution):
    """Build a 3D evaluation grid around the bounding box of ``points``.

    The shortest bounding-box axis (padded by 0.1 on each side) is sampled
    with exactly ``resolution`` points; the other two axes reuse the same
    step and extend until they cover their own padded extent.

    Args:
        points: tensor whose columns are indexed 0..2 as x, y, z.
        resolution: number of samples along the shortest axis.

    Returns:
        dict with ``grid_points`` (float tensor of all grid coordinates,
        passed through ``utils.to_cuda``), ``shortest_axis_length``, ``xyz``
        (the three 1D axis arrays) and ``shortest_axis_index``.

    Raises:
        ValueError: if ``points`` is ``None``.
    """
    if points is None:
        raise ValueError("get_grid requires points; use get_grid_uniform for a fixed grid")

    eps = 0.1
    input_min = torch.min(points, dim=0)[0].squeeze().cpu().numpy()
    input_max = torch.max(points, dim=0)[0].squeeze().cpu().numpy()
    shortest_axis = np.argmin(input_max - input_min)

    short = np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
    length = np.max(short) - np.min(short)
    step = length / (short.shape[0] - 1)

    axes = []
    for dim in range(3):
        if dim == shortest_axis:
            axes.append(short)
        else:
            # One extra step on the upper bound so the padded extent is covered.
            axes.append(np.arange(input_min[dim] - eps, input_max[dim] + step + eps, step))
    x, y, z = axes

    xx, yy, zz = np.meshgrid(x, y, z)
    flat = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
    # Same device helper as get_grid_uniform / plot_cuts_axis rather than a
    # hard-coded .cuda() call.
    grid_points = utils.to_cuda(torch.tensor(flat, dtype=torch.float))
    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}


def get_grid_uniform(resolution):
    """Build a cubic grid of ``resolution``^3 points over [-1.2, 1.2]^3.

    Returns:
        dict with the same keys as ``get_grid``; ``shortest_axis_length`` is
        2.4 and ``shortest_axis_index`` is 0.
    """
    x = np.linspace(-1.2, 1.2, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    flat = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
    grid_points = utils.to_cuda(torch.tensor(flat, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": 2.4,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}