You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
116 lines
5.4 KiB
116 lines
5.4 KiB
import argparse
|
|
import os
|
|
import time
|
|
from utils import utils
|
|
import torch
|
|
|
|
|
|
class BaseOptions():
    """Define the command-line options shared by training and test time.

    Besides declaring the common arguments, this class implements helpers to
    parse them, print them (annotating values that differ from the parser
    defaults), and, when training, save them to disk.

    ``self.isTrain`` is never assigned here but is read by :meth:`parse`, so
    it must be provided elsewhere (e.g. by a subclass).
    """

    def __init__(self):
        """Reset the class; mark the options as not yet initialized."""
        self.initialized = False
        # Parser built lazily by gather_options() and reused on later calls.
        self.parser = None

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

        ``type=bool`` cannot be used with argparse because ``bool('False')``
        is ``True`` (every non-empty string is truthy).

        Raises:
            argparse.ArgumentTypeError: if *value* is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            f'boolean value expected (e.g. True/False), got {value!r}')

    def initialize(self, parser):
        """Register the options common to training and test on *parser*.

        Args:
            parser: an ``argparse.ArgumentParser`` to add the arguments to.

        Returns:
            The same parser, after marking this instance as initialized.
        """
        # basic parameters
        parser.add_argument('--dataroot', type=str, default='./datasets', help='root path to datasets')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_id', type=str, default='0', help='gpu ids: e.g. 0, 1, ... . use -1 for CPU')
        parser.add_argument('--device', type=str, default='cuda:0', help='torch device; recomputed in parse() from --gpu_id and CUDA availability')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--ms_ratio', type=int, default=5, help='multiscale ratio')
        parser.add_argument('--results_dir', type=str, default='./results', help='saves results here.')

        # model parameters
        parser.add_argument('--model', type=str, default='ANN', help='chooses which model to use. [ANN | CNN | AutoEncoder]')

        # dataset parameters
        parser.add_argument('--nelx', type=int, default=180, help='num of elements on x-axis')
        parser.add_argument('--nely', type=int, default=60, help='num of elements on y-axis')
        parser.add_argument('--nelz', type=int, default=0, help='num of elements on z-axis')
        parser.add_argument('--dimension', type=int, default=2, help='dimension of dataset models')
        # type=bool would turn any non-empty string (including 'False') into True.
        parser.add_argument('--is_standard', type=self._str2bool, default=True, help='whether need standardization or not')

        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        # The suffix is expanded with str.format(**vars(opt)), so only option names are valid fields.
        parser.add_argument('--suffix', type=str, default='', help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{nelx}x{nely}')

        # identify initialization timing
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the parser on first use and parse the command line.

        The parser is created and populated via :meth:`initialize` only once
        and stored on ``self.parser``; later calls reuse it instead of failing.

        Returns:
            The ``argparse.Namespace`` produced by ``parse_args()`` (reads ``sys.argv``).
        """
        if not self.initialized or getattr(self, 'parser', None) is None:
            parser = argparse.ArgumentParser()  # customize help formatting with <formatter_class>
            self.parser = self.initialize(parser)
        return self.parser.parse_args()

    def print_options(self, opt):
        """Print the options and, when training, save them to disk.

        Every option is printed with its value. Options whose value differs
        from the parser default are annotated with that default.

        When ``opt.isTrain`` is true, this method:

        * creates the directory
          ``[checkpoints_dir]/<model>_<mod>_<yymmdd-HHMMSS>``;
        * stores that path on ``opt.expr_dir``;
        * writes the same text to ``<phase>_opt.txt`` inside it.

        Args:
            opt: the namespace returned by :meth:`gather_options`.
        """
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = '\t[default: %s]' % str(default) if value != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        message = '\n'.join(lines)
        print(message)

        # save to the disk
        if opt.isTrain:
            # NOTE(review): opt.mod and opt.phase are not declared in this class;
            # presumably added by a training-options subclass — confirm.
            curr = time.strftime('%y%m%d-%H%M%S')
            expr_dir = os.path.join(opt.checkpoints_dir, opt.model + '_' + opt.mod + '_' + str(curr))
            opt.expr_dir = expr_dir

            utils.mkdir(expr_dir)
            file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
            with open(file_name, 'wt', encoding='utf-8') as opt_file:
                opt_file.write(message)
                opt_file.write('\n')

    def parse(self):
        """Parse the options, apply ``--suffix`` to ``opt.name``, and resolve the device.

        Returns:
            The parsed namespace (also stored on ``self.opt``). ``isTrain`` is
            copied from ``self.isTrain``. ``device`` is set to ``'cpu'`` when
            ``--gpu_id`` is -1 or CUDA is unavailable, and to
            ``'cuda:<gpu_id>'`` otherwise.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix: its {fields} are filled from the parsed options
        if opt.suffix:
            opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

        # set device with gpu id; gpu_id is parsed with type=str, so compare with '-1'
        if str(opt.gpu_id).strip() == '-1':
            opt.device = 'cpu'
        else:
            opt.device = f'cuda:{opt.gpu_id}' if torch.cuda.is_available() else 'cpu'

        # Print/save after resolving the device so the record matches what is used.
        self.print_options(opt)

        self.opt = opt
        return self.opt
|
|
|
|
|