import argparse
import os
import time

from utils import utils
import torch


def _str2bool(value):
    """Convert a command-line string such as 'true'/'false'/'1'/'0' to a bool.

    ``type=bool`` must not be used with argparse: argparse calls ``bool(s)`` on
    the raw string, and every non-empty string (including ``'False'``) is truthy.

    Args:
        value: the raw command-line string (or an already-converted bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse turns this into a clean usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and
    saving the options. Subclasses are expected to extend ``initialize`` with
    their own options and to set ``self.isTrain`` (read by ``parse``).
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized."""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test.

        Args:
            parser: an ``argparse.ArgumentParser`` to register options on.

        Returns:
            The same parser, with the options added.
        """
        # basic parameters
        parser.add_argument('--dataroot', type=str, default='./datasets', help='root path to datasets')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_id', type=str, default='0', help='gpu ids: e.g. 0, 1, ... . use -1 for CPU')
        # NOTE: parse() always overwrites opt.device from --gpu_id, so this value is effectively ignored.
        parser.add_argument('--device', type=str, default='cuda:0', help='generate device with gpu_id and usable user device')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--ms_ratio', type=int, default=5, help='multiscale ratio')
        parser.add_argument('--results_dir', type=str, default='./results', help='saves results here.')
        # model parameters
        parser.add_argument('--model', type=str, default='ANN', help='chooses which model to use. [ANN | CNN | AutoEncoder]')
        # dataset parameters
        parser.add_argument('--nelx', type=int, default=180, help='num of elements on x-axis')
        parser.add_argument('--nely', type=int, default=60, help='num of elements on y-axis')
        parser.add_argument('--nelz', type=int, default=0, help='num of elements on z-axis')
        parser.add_argument('--dimension', type=int, default=2, help='dimension of dataset models')
        # _str2bool instead of type=bool so that "--is_standard False" actually yields False.
        parser.add_argument('--is_standard', type=_str2bool, default=True, help='whether need standardization or not')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', type=str, default='', help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        # mark that the options have been registered
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the parser (only once) and parse the command line.

        The parser is cached on ``self.parser`` so repeated calls reuse it
        instead of failing or re-registering options.

        Returns:
            The ``argparse.Namespace`` produced by ``parser.parse_args()``.
        """
        parser = getattr(self, 'parser', None)
        if not self.initialized or parser is None:
            # show each option's default value in --help output
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)  # get the basic options

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options.

        Prints every option, and for options that differ from the parser
        default also prints the default. When ``opt.isTrain`` is set, creates
        ``[checkpoints_dir]/<model>_<mod>_<timestamp>``, stores it on
        ``opt.expr_dir`` and writes the same text to ``<phase>_opt.txt`` there.

        Args:
            opt: parsed options namespace. When ``opt.isTrain`` is truthy it
                must also provide ``mod`` and ``phase`` (presumably defined by
                a training subclass - not registered in this class).
        """
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = f'\t[default: {default}]' if value != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        message = '\n'.join(lines)
        print(message)

        # save to the disk
        if opt.isTrain:
            curr = time.strftime('%y%m%d-%H%M%S')
            expr_dir = os.path.join(opt.checkpoints_dir, opt.model + '_' + opt.mod + '_' + str(curr))
            opt.expr_dir = expr_dir
            utils.mkdir(expr_dir)
            file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
            with open(file_name, 'wt', encoding='utf-8') as opt_file:
                opt_file.write(message)
                opt_file.write('\n')

    def parse(self):
        """Parse options, apply the name suffix, print/save them, and pick the device.

        Returns:
            The parsed options namespace (also stored on ``self.opt``).
            ``opt.device`` is ``'cpu'`` when ``--gpu_id`` is -1 or CUDA is
            unavailable, otherwise ``'cuda:<gpu_id>'``.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test; set by the subclass

        # process opt.suffix
        if opt.suffix:
            opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # set device with gpu id; --gpu_id is parsed as a string, so compare
        # against '-1' (comparing with the int -1 would never match).
        gpu_id = str(opt.gpu_id).strip()
        if gpu_id == '-1' or not torch.cuda.is_available():
            opt.device = 'cpu'
        else:
            opt.device = f'cuda:{gpu_id}'

        self.opt = opt
        return self.opt