You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
4.0 KiB
119 lines
4.0 KiB
"""
|
|
Sampling in the spatial domain:
    - collocation points
    - boundary points

No doubt many people will explore how different sampling grids can be used
to train fully-connected networks.
|
|
|
|
"""
|
|
import numpy as np
|
|
import torch
|
|
from .lhs import lhs
|
|
|
|
|
|
class SampleSpatial2d:
    """Sampler for collocation and boundary points of a uniform 2-D grid.

    Every point is a ``(y, x)`` pair:

    * ``h`` -- height, the y axis (``ngrid_h`` nodes)
    * ``w`` -- width, the x axis (``ngrid_w`` nodes)

    Integer node indices in ``[0, ngrid_h - 1] x [0, ngrid_w - 1]`` are
    divided by ``refactor`` so that returned points lie in ``[0, 1]``.

    Attributes:
        ngrid_h: number of nodes along y (coerced to ``int``).
        ngrid_w: number of nodes along x (coerced to ``int``).
        n_grids: total number of nodes, ``ngrid_h * ngrid_w``.
        refactor: float32 tensor of shape ``(1, 2)`` holding the divisors
            ``(ngrid_h - 1, ngrid_w - 1)``; multiply sampled points by it to
            recover node indices.
        coordinates: float32 tensor ``(n_grids, 2)`` of all node indices in
            row-major order (y varies slowest).
        coordinates_no_boundary: same as ``coordinates`` but without the
            outermost ring of nodes.
    """

    def __init__(self, ngrid_h, ngrid_w):
        """Build the node tables for a grid of ``ngrid_h`` x ``ngrid_w`` nodes."""
        self.ngrid_h = int(ngrid_h)
        self.ngrid_w = int(ngrid_w)
        self.n_grids = self.ngrid_h * self.ngrid_w
        # Use the int-coerced sizes so the divisor always agrees with the
        # grid that is actually built.  The divisor is clamped to 1 so that a
        # single-node axis maps to 0 instead of producing 0 / 0 = NaN.
        self.refactor = torch.tensor(
            [[max(self.ngrid_h - 1, 1), max(self.ngrid_w - 1, 1)]],
            dtype=torch.float32,
        )
        self.coordinates = self._coordinates()
        self.coordinates_no_boundary = self._coordinates_no_boundary()

    def _grid_points(self, interior):
        """Return node indices as a float32 ``(n, 2)`` tensor of (y, x) rows.

        Args:
            interior: when true, drop the first and last row and column.
        """
        # indexing='ij' yields (row, column) arrays directly, so no argument
        # swapping is needed to obtain the (y, x) ordering.
        rows, cols = np.meshgrid(
            np.arange(self.ngrid_h), np.arange(self.ngrid_w), indexing='ij'
        )
        if interior:
            rows, cols = rows[1:-1, 1:-1], cols[1:-1, 1:-1]
        points = np.stack((rows.ravel(), cols.ravel()), 1)
        return torch.as_tensor(points, dtype=torch.float32)

    def _coordinates(self):
        """Return the indices of every grid node, boundary included."""
        return self._grid_points(interior=False)

    def _coordinates_no_boundary(self):
        """Return the indices of the interior grid nodes only."""
        return self._grid_points(interior=True)

    def _sample2d(self, on_grid, n_samples=None, no_boundary=False):
        """Sample 2-D points in the unit square.

        Args:
            on_grid: if true, pick grid nodes; otherwise call the module's
                ``lhs`` helper (presumably Latin hypercube sampling -- it is
                defined outside this class).
            n_samples: number of points.  ``None`` means every available
                node when ``on_grid`` is true and ``n_grids`` otherwise.
            no_boundary: only used when ``on_grid`` is true; restricts the
                candidates to interior nodes.

        Returns:
            Float tensor of shape ``(n, 2)`` with (y, x) rows.  On the grid,
            a request for fewer nodes than available returns a random subset
            without replacement; a request for all of them returns the nodes
            in row-major order; a request for more prints a notice and
            returns all of them.
        """
        if not on_grid:
            if n_samples is None:
                n_samples = self.n_grids
            return torch.FloatTensor(lhs(2, n_samples))

        nodes = self.coordinates_no_boundary if no_boundary else self.coordinates
        points = nodes / self.refactor
        n_points = len(points)
        if n_samples is None:
            n_samples = n_points

        if n_samples < n_points:
            # Permute over the rows that actually exist; permuting over
            # n_grids indexes past the end of the interior-only table.
            points = points[torch.randperm(n_points)[:n_samples]]
        elif n_samples > n_points:
            print('n_samples is greater than grid size, set n_samples '
                  'equals to grid size')
        return points

    def _sample1d(self, horizontal, on_grid, n_samples=None):
        """Sample 1-D positions in ``[0, 1]`` along one edge.

        If ``on_grid`` is on, ``n_samples`` is ignored when it is larger than
        the number of nodes on that edge.

        Args:
            horizontal: true selects ``ngrid_h`` nodes (used by the left and
                right edges, which run along y); false selects ``ngrid_w``
                nodes (top and bottom edges, which run along x).
            on_grid: if true, pick grid nodes in random order; otherwise draw
                uniform random positions with ``torch.rand``.
            n_samples: number of positions; ``None`` means one per node.

        Returns:
            1-D float tensor of positions.
        """
        ngrid = self.ngrid_h if horizontal else self.ngrid_w
        if n_samples is None:
            n_samples = ngrid
        if not on_grid:
            return torch.rand(n_samples)

        # Same clamp as ``refactor``: a single-node edge sits at position 0.
        points = torch.arange(float(ngrid)) / max(ngrid - 1, 1)
        if n_samples <= len(points):
            points = points[torch.randperm(ngrid)[:n_samples]]
        else:
            print('n_samples is greater than grid size, set n_samples '
                  'equals to grid size')
        return points

    def left(self, on_grid=True, n_samples=None):
        """Sample the edge ``x = 0``; returns ``(n, 2)`` rows of (y, 0)."""
        points = self._sample1d(horizontal=True, on_grid=on_grid, n_samples=n_samples)
        return torch.stack((points, torch.zeros_like(points)), 1)

    def right(self, on_grid=True, n_samples=None):
        """Sample the edge ``x = 1``; returns ``(n, 2)`` rows of (y, 1)."""
        points = self._sample1d(horizontal=True, on_grid=on_grid, n_samples=n_samples)
        return torch.stack((points, torch.ones_like(points)), 1)

    def top(self, on_grid=True, n_samples=None):
        """Sample the edge ``y = 0``; returns ``(n, 2)`` rows of (0, x)."""
        points = self._sample1d(horizontal=False, on_grid=on_grid, n_samples=n_samples)
        return torch.stack((torch.zeros_like(points), points), 1)

    def bottom(self, on_grid=True, n_samples=None):
        """Sample the edge ``y = 1``; returns ``(n, 2)`` rows of (1, x)."""
        points = self._sample1d(horizontal=False, on_grid=on_grid, n_samples=n_samples)
        return torch.stack((torch.ones_like(points), points), 1)

    def colloc(self, on_grid=True, n_samples=None, no_boundary=False):
        """Sample collocation points; see ``_sample2d`` for the arguments."""
        return self._sample2d(on_grid, n_samples, no_boundary)
|
|
|
|
|
|
def _demo():
    """Print right-edge samples of a 10 x 10 grid and the tensor's shape."""
    height = width = 10
    sampler = SampleSpatial2d(height, width)

    # Other calls worth trying here:
    #   sampler.colloc(on_grid=False, n_samples=99, no_boundary=False)
    edge = sampler.right(n_samples=12, on_grid=True)

    # Multiplying by ``refactor`` turns the [0, 1] positions back into
    # integer node indices, which are easier to read.
    print(edge * sampler.refactor)
    print(edge.shape)


if __name__ == '__main__':
    _demo()
|
|
|
|
|
|
|