From 64d5de2b86e687a72eb2093f5a66443b26e0e0a6 Mon Sep 17 00:00:00 2001
From: Birdylx <29754889+Birdylx@users.noreply.github.com>
Date: Mon, 30 Jan 2023 08:45:53 +0000
Subject: [PATCH] support cfg override

---
 paddle3d/apis/config.py | 38 +++++++++++++++++++++++++++++++-------
 tools/train.py          |  5 ++++-
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/paddle3d/apis/config.py b/paddle3d/apis/config.py
index c996b510..8e9926a9 100644
--- a/paddle3d/apis/config.py
+++ b/paddle3d/apis/config.py
@@ -14,6 +14,8 @@
 
 import codecs
 import os
+import six
+from ast import literal_eval
 from collections.abc import Iterable, Mapping
 from typing import Any, Dict, Generic, Optional
 
@@ -78,10 +80,11 @@ def __init__(self,
         else:
             raise RuntimeError('Config file should in yaml format!')
 
-        self.update(learning_rate=learning_rate,
-                    batch_size=batch_size,
-                    iters=iters,
-                    epochs=epochs)
+        self.update(
+            learning_rate=learning_rate,
+            batch_size=batch_size,
+            iters=iters,
+            epochs=epochs)
 
     def _update_dic(self, dic: Dict, base_dic: Dict):
         '''Update config from dic based base_dic
@@ -120,7 +123,8 @@ def update(self,
                learning_rate: Optional[float] = None,
                batch_size: Optional[int] = None,
                iters: Optional[int] = None,
-               epochs: Optional[int] = None):
+               epochs: Optional[int] = None,
+               opts: Optional[list] = None):
         '''Update config'''
 
         if learning_rate is not None:
@@ -135,6 +139,26 @@ def update(self,
         if epochs is not None:
             self.dic['epochs'] = epochs
 
+        if opts is not None:
+            if len(opts) % 2 != 0 or len(opts) == 0:
+                raise ValueError(
+                    "Command line options config `--opts` format error! It should be even length like: k1 v1 k2 v2 ... Please check it: {}"
+                    .format(opts))
+            for key, value in zip(opts[0::2], opts[1::2]):
+                if isinstance(value, six.string_types):
+                    try:
+                        value = literal_eval(value)
+                    except ValueError:
+                        pass
+                    except SyntaxError:
+                        pass
+                key_list = key.split('.')
+                dic = self.dic
+                for subkey in key_list[:-1]:
+                    dic.setdefault(subkey, dict())
+                    dic = dic[subkey]
+                dic[key_list[-1]] = value
+
     @property
     def batch_size(self) -> int:
         return self.dic.get('batch_size', 1)
@@ -282,8 +306,8 @@ def _load_object(self, obj: Generic, recursive: bool = True) -> Any:
             if recursive:
                 params = {}
                 for key, val in dic.items():
-                    params[key] = self._load_object(obj=val,
-                                                    recursive=recursive)
+                    params[key] = self._load_object(
+                        obj=val, recursive=recursive)
             else:
                 params = dic
             try:
diff --git a/tools/train.py b/tools/train.py
index 676c1119..7075352a 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -57,6 +57,8 @@ def parse_args():
         help='epochs for training',
         type=int,
         default=None)
+    parser.add_argument(
+        '--opts', help='override config options.', default=None, nargs='+')
     parser.add_argument(
         '--keep_checkpoint_max',
         dest='keep_checkpoint_max',
@@ -154,7 +156,8 @@ def main(args):
         learning_rate=args.learning_rate,
         batch_size=args.batch_size,
         iters=args.iters,
-        epochs=args.epochs,
+        epochs=args.epochs,
+        opts=args.opts)
 
     if cfg.train_dataset is None:
         raise RuntimeError(