import numpy as np

def Ms_u_reshape(u_data, coarse_nelx, coarse_nely, m):
    """Tile per-coarse-cell node patches into a single grid without repeats.

    ``u_data`` is viewed as ``coarse_nelx * coarse_nely`` patches of
    ``(m + 1) x (m + 1)`` points with 2 values each, i.e. it is reshaped to
    ``(coarse_nelx, coarse_nely, m + 1, m + 1, 2)``. The patches are laid
    side by side, and the first row/column of every patch except the first
    is dropped, so a line shared by two neighbouring patches is taken from
    the preceding patch.

    NOTE(review): presumably ``u_data`` holds a 2-component nodal field
    (e.g. displacements) predicted per coarse element -- confirm with the
    producer of the data.

    Args:
        u_data: array whose size equals
            ``coarse_nelx * coarse_nely * (m + 1) ** 2 * 2``.
        coarse_nelx: number of patches along the first axis.
        coarse_nely: number of patches along the second axis.
        m: number of intervals per patch edge (each edge has ``m + 1`` points).

    Returns:
        Array of shape ``(coarse_nelx * m + 1, coarse_nely * m + 1, 2)``.
    """
    points_per_edge = m + 1

    def _kept_indices(n_patches):
        # Positions along one tiled axis, minus the leading position of
        # every patch after the first (those repeat the previous patch's
        # trailing position).
        every = np.arange(n_patches * points_per_edge)
        patch_starts = every[::points_per_edge]
        repeated = np.delete(patch_starts, 0)
        return np.delete(every, repeated)

    patches = u_data.reshape(
        coarse_nelx, coarse_nely, points_per_edge, points_per_edge, 2
    )
    # Bring each patch's local row axis next to its coarse row axis so the
    # two can be merged into one tiled axis (same for columns).
    tiled = patches.swapaxes(1, 2).reshape(
        coarse_nelx * points_per_edge, coarse_nely * points_per_edge, 2
    )

    # Column vector x row vector broadcasts to the full 2-D selection.
    rows = _kept_indices(coarse_nelx).reshape(coarse_nelx * m + 1, 1)
    cols = _kept_indices(coarse_nely).reshape(1, coarse_nely * m + 1)
    return tiled[rows, cols]


if __name__=='__main__':
    # Ad-hoc check: load a saved prediction and a stored array, merge the
    # prediction patches, and print the element-wise difference.
    prediction = np.load('results/pred.npy')
    reference = np.load('datasets/train/180_60/u/mod2.npy')

    print(reference.shape)
    print(prediction.shape)

    # 36 x 12 patches with m = 5 give a (36*5 + 1) x (12*5 + 1) = 181 x 61 grid.
    merged = Ms_u_reshape(prediction, 36, 12, 5)
    residual = merged - reference.reshape(181, 61, 2)
    print(residual)