百木园-与人分享,
就是让自己快乐。

使用argparse进行调参

argparse是深度学习项目调参时常用的python标准库,使用argparse后,我们在命令行输入的参数就可以以这种形式python filename.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。,一般使用时可以归纳为以下三个步骤

使用步骤:

  • 创建ArgumentParser()对象
  • 调用add_argument()方法添加参数
  • 使用parse_args()解析参数 在接下来的内容中,我们将以实际操作来学习argparse的使用方法
import argparse

parser = argparse.ArgumentParser() # 创建一个解析对象

parser.add_argument() # 向该对象中添加你要关注的命令行参数和选项

args = parser.parse_args() # 调用parse_args()方法进行解析

常见规则

  • 在命令行中输入python demo.py -h或者python demo.py --help可以查看该python文件参数说明
  • arg字典类似python字典,比如arg字典Namespace(integers=\'5\')可使用arg.参数名来提取这个参数
  • parser.add_argument(\'integers\', type=str, nargs=\'+\',help=\'传入的数字\') nargs是用来说明传入的参数个数,\'+\' 表示传入至少一个参数,\'*\' 表示参数可设置零个或多个,\'?\' 表示参数可设置零个或一个
  • parser.add_argument(\'-n\', \'--name\', type=str, required=True, default=\'\', help=\'名\') required=True表示必须参数, -n表示可以使用短选项使用该参数
  • parser.add_argument(\"--test_action\", default=\'False\', action=\'store_true\')store_true 触发时为真,不触发则为假(test.py,输出为 Falsetest.py --test_action,输出为 True

使用config文件传入超参数

为了使代码更加简洁和模块化,可以将有关超参数的操作写在config.py,然后在train.py或者其他文件导入就可以。具体的config.py可以参考如下内容。

import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
  
    parser.add_argument(\'--workers\', type=int, default=0,  
                        help=\'number of data loading workers, you had better put it \'  
                              \'4 times of your gpu\')  
  
    parser.add_argument(\'--batch_size\', type=int, default=4, help=\'input batch size, default=64\')  
  
    parser.add_argument(\'--niter\', type=int, default=10, help=\'number of epochs to train for, default=10\')  
  
    parser.add_argument(\'--lr\', type=float, default=3e-5, help=\'select the learning rate, default=1e-3\')  
  
    parser.add_argument(\'--seed\', type=int, default=118, help=\"random seed\")  
  
    parser.add_argument(\'--cuda\', action=\'store_true\', default=True, help=\'enables cuda\')  
    parser.add_argument(\'--checkpoint_path\',type=str,default=\'\',  
                        help=\'Path to load a previous trained model if not empty (default empty)\')  
    parser.add_argument(\'--output\',action=\'store_true\',default=True,help=\"shows output\")  
  
    opt = parser.parse_args()  
  
    if opt.output:  
        print(f\'num_workers: {opt.workers}\')  
        print(f\'batch_size: {opt.batch_size}\')  
        print(f\'epochs (niters) : {opt.niter}\')  
        print(f\'learning rate : {opt.lr}\')  
        print(f\'manual_seed: {opt.seed}\')  
        print(f\'cuda enable: {opt.cuda}\')  
        print(f\'checkpoint_path: {opt.checkpoint_path}\')  
  
    return opt  
  
if __name__ == \'__main__\':  
    opt = get_options()
$ python config.py

num_workers: 0
batch_size: 4
epochs (niters) : 10
learning rate : 3e-05
manual_seed: 118
cuda enable: True
checkpoint_path:

随后在train.py等其他文件,我们就可以使用下面的这样的结构来调用参数。

# 导入必要库
...
import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path

# 随机数的设置,保证复现结果
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...


if __name__ == \'__main__\':
  set_seed(manual_seed)
  for epoch in range(niters):
    train(model,lr,batch_size,num_workers,checkpoint_path)
    val(model,lr,batch_size,num_workers,checkpoint_path)

参考:

https://zhuanlan.zhihu.com/p/56922793

(14条消息) python argparse中action的可选参数store_true的作用_元气少女wuqh的博客-CSDN博客

[6.6 使用argparse进行调参 — 深入浅出PyTorch (datawhalechina.github.io)](https://datawhalechina.github.io/thorough-pytorch/第六章/6.6 使用argparse进行调参.html)


来源:https://www.cnblogs.com/qftie/p/16319150.html
本站部分图文来源于网络,如有侵权请联系删除。

未经允许不得转载:百木园 » 使用argparse进行调参

相关推荐

  • 暂无文章