该项目是《Problem-independent machine learning (PIML)-based topology optimization—A universal approach》的python复现
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

116 lines
5.4 KiB

import argparse
import os
import time
from utils import utils
import torch
class BaseOptions():
    """Options shared by training and test time.

    Defines the common command-line arguments and helpers to parse, print,
    and save them. This class never sets ``self.isTrain`` itself, yet
    :meth:`parse` reads it; presumably a train/test subclass sets it and
    extends :meth:`initialize` — NOTE(review): confirm against those classes.
    """

    def __init__(self):
        """Reset the class; the argument parser has not been built yet."""
        self.initialized = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as 'True'/'false'/'1'/'no' to bool.

        ``type=bool`` is wrong for argparse because ``bool('False')`` is True;
        this converter is used instead.

        Raises:
            argparse.ArgumentTypeError: if the text is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')

    @staticmethod
    def _resolve_device(gpu_id):
        """Map a ``--gpu_id`` value to a torch device string.

        Args:
            gpu_id: GPU id(s) as given on the command line (e.g. '0', '0,1',
                '-1'); ints are accepted too.

        Returns:
            'cpu' when the first id is -1, no id is given, or CUDA is not
            available; otherwise 'cuda:<first id>'. Only the first id of a
            comma-separated list is used, because a single device string
            cannot name several GPUs.
        """
        ids = [part.strip() for part in str(gpu_id).split(',') if part.strip()]
        if not ids or ids[0] == '-1' or not torch.cuda.is_available():
            return 'cpu'
        return f'cuda:{ids[0]}'

    def initialize(self, parser):
        """Add the options common to training and test to ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` to populate.

        Returns:
            The same parser, with the arguments added. Also marks this
            instance as initialized.
        """
        # basic parameters
        parser.add_argument('--dataroot', type=str, default='./datasets', help='root path to datasets')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_id', type=str, default='0', help='gpu ids: e.g. 0, 1, ... . use -1 for CPU')
        # Overwritten in parse() from --gpu_id and CUDA availability.
        parser.add_argument('--device', type=str, default='cuda:0', help='generate device with gpu_id and usable user device')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--ms_ratio', type=int, default=5, help='multiscale ratio')
        parser.add_argument('--results_dir', type=str, default='./results', help='saves results here.')
        # model parameters
        parser.add_argument('--model', type=str, default='ANN', help='chooses which model to use. [ANN | CNN | AutoEncoder]')
        # dataset parameters
        parser.add_argument('--nelx', type=int, default=180, help='num of elements on x-axis')
        parser.add_argument('--nely', type=int, default=60, help='num of elements on y-axis')
        parser.add_argument('--nelz', type=int, default=0, help='num of elements on z-axis')
        parser.add_argument('--dimension', type=int, default=2, help='dimension of dataset models')
        # _str2bool instead of bool so that "--is_standard False" really gives False.
        parser.add_argument('--is_standard', type=self._str2bool, default=True, help='whether need standardization or not')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', type=str, default='', help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        # identify initialization timing
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the parser (once) and parse the command line.

        On first use a new ``ArgumentParser`` is populated by
        :meth:`initialize` and stored as ``self.parser``. Later calls reuse
        that parser instead of failing.

        Returns:
            argparse.Namespace: the options parsed from ``sys.argv``.
        """
        if not self.initialized or getattr(self, 'parser', None) is None:
            self.parser = self.initialize(argparse.ArgumentParser())
        return self.parser.parse_args()

    def print_options(self, opt):
        """Print the options and, when training, save them to disk.

        Every option is printed with its value. Options whose value differs
        from the parser default are annotated with that default. When
        ``opt.isTrain`` is true, a timestamped directory
        ``<checkpoints_dir>/<model>_<mod>_<yymmdd-HHMMSS>`` is created via the
        project's ``utils.mkdir`` and stored as ``opt.expr_dir``. The same
        text is then written to ``<expr_dir>/<phase>_opt.txt``.

        NOTE(review): ``opt.mod`` and ``opt.phase`` are not defined by this
        class; presumably a subclass adds them — confirm.
        """
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = f'\t[default: {default!s}]' if value != default else ''
            lines.append(f'{key!s:>25}: {value!s:<30}{comment}')
        lines.append('----------------- End -------------------')
        message = '\n'.join(lines)
        print(message)

        # save to the disk
        if opt.isTrain:
            stamp = time.strftime('%y%m%d-%H%M%S')
            expr_dir = os.path.join(opt.checkpoints_dir, f'{opt.model}_{opt.mod}_{stamp}')
            opt.expr_dir = expr_dir
            utils.mkdir(expr_dir)
            file_name = os.path.join(expr_dir, f'{opt.phase}_opt.txt')
            with open(file_name, 'wt', encoding='utf-8') as opt_file:
                opt_file.write(message + '\n')

    def parse(self):
        """Parse options, apply the name suffix, resolve the device, print/save.

        Returns:
            argparse.Namespace: the final options, also stored as ``self.opt``.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test; not set in this class

        # Expand opt.suffix as a format string over the options themselves
        # and append it to the experiment name.
        if opt.suffix:
            opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

        # Resolve the device before printing so the saved options record it.
        # gpu_id is a string, so it must not be compared to the int -1.
        opt.device = self._resolve_device(opt.gpu_id)

        self.print_options(opt)
        self.opt = opt
        return self.opt