diff --git a/conftest.py b/conftest.py index 1e7b36a0544b50a29f1af6cfdeef050ee9cc1d22..cb4f3e09a5847ac76d74ad2db570618dcdbf37ab 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,7 @@ -from os import chdir, getcwd, mkdir +from os import chdir, getcwd, mkdir, listdir from os.path import isdir -from pytest import fixture, skip +from pytest import fixture def pytest_sessionstart(session): @@ -11,19 +11,14 @@ def pytest_sessionstart(session): Automatic change path to the `dag-flow/test` and create `test/output` dir """ - path = getcwd() - lastdir = path.split("/")[-1] - if lastdir == "dag-flow": # rootdir - chdir("./test") - elif lastdir in ( - "dagflow", - "example", - "doc", - "docs", - "source", - "sources", - ): # childdir - chdir("../test") + while(path := getcwd()): + if (lastdir := path.split("/")[-1]) == "test": + break + elif ".git" in listdir(path): + chdir("./test") + break + else: + chdir("..") if not isdir("output"): mkdir("output") diff --git a/dagflow/bundles/load_variables.py b/dagflow/bundles/load_variables.py index e6556d78dd393c6ff23da61d97d07f5d9cb7eb39..fb75bf81ef1675d4a7444bd5426af01aa87558c2 100644 --- a/dagflow/bundles/load_variables.py +++ b/dagflow/bundles/load_variables.py @@ -66,7 +66,8 @@ IsVarsCfgDict = Schema({ 'variables': IsValuesDict, 'labels': IsLabelsDict, 'format': IsFormat, - 'state': Or('fixed', 'variable', error='Invalid parameters state: {}') + 'state': Or('fixed', 'variable', error='Invalid parameters state: {}'), + Optional('path', default=''): str }, # error = 'Invalid parameters configuration: {}' ) @@ -142,21 +143,30 @@ def load_variables(acfg): cfg = IsProperVarsCfg.validate(acfg) cfg = DictWrapper(cfg) + path = cfg['path'] + if path: + path = path.split('.') + else: + path = () + + state = cfg['state'] + ret = DictWrapper({'constants': {}, 'free': {}, 'constrained': {}}, sep='.') for key, varcfg in iterate_varcfgs(cfg): skey = '.'.join(key) label = varcfg['label'] label['key'] = skey label.setdefault('text', skey) + varcfg.setdefault(state, True) par = Parameters.from_numbers(**varcfg) if par.is_constrained: - target = ret['constrained'] + target = ('constrained',) + path elif par.is_fixed: - target = ret['constants'] + target = ('constants',) + path else: - target = ret['free'] + target = ('free',) + path - target[key] = par + ret[target+key] = par return ret diff --git a/dagflow/variable.py b/dagflow/variable.py index 0bae37a4d79804cd46c6c9b3fb2f0bd0c5467816..a28f494d04e07c99816fa88b83a16700f1402598 100644 --- a/dagflow/variable.py +++ b/dagflow/variable.py @@ -53,24 +53,27 @@ class Parameters(object): @staticmethod def from_numbers(*, dtype: DTypeLike='d', **kwargs) -> 'Parameters': - sigma = kwargs['sigma'] + sigma = kwargs.pop('sigma') if sigma is not None: - return GaussianParameters.from_numbers(dtype=dtype, **kwargs) + return GaussianParameters.from_numbers(dtype=dtype, sigma=sigma, **kwargs) - label: Dict[str, str] = kwargs.get('label') + del kwargs['central'] + + label: Dict[str, str] = kwargs.pop('label', None) if label is None: label = {'text': 'parameter'} else: label = dict(label) name: str = label.setdefault('name', 'parameter') - value = kwargs['value'] + value = kwargs.pop('value') return Parameters( Array( name, array((value,), dtype=dtype), label = label, - mode='store_weak' - ) + mode='store_weak', + ), + **kwargs ) class GaussianParameters(Parameters): diff --git a/test/variables/test_load_variables.py b/test/variables/test_load_variables.py index 73c491ee76c760fd7854606081b70ce14a4f0583..60b7d4691b5d96ad017b4b2648ff03b58cb21000 100644 --- a/test/variables/test_load_variables.py +++ b/test/variables/test_load_variables.py @@ -21,6 +21,25 @@ cfg1 = { 'var2': 'simple label 2', }, } +cfg1a = { + 'variables': { + 'var1': 1.0, + 'var2': 1.0, + 'sub1': { + 'var3': 2.0 + } + }, + 'format': 'value', + 'state': 'fixed', + 'labels': { + 'var1': { + 'text': 'text label 1', + 'latex': r'\LaTeX label 1', + 'name': 'v1-1' + }, + 'var2': 'simple label 2', + }, + } cfg2 = { 'variables': { @@ -100,10 +119,14 @@ cfg5 = { 'state': 'variable', } +from pprint import pprint def test_load_variables_v01(): - cfgs = (cfg1, cfg2, cfg3, cfg4, cfg5) + cfgs = (cfg1, cfg1a, cfg2, cfg3, cfg4, cfg5) with Graph(close=True) as g: - for cfg in cfgs: - load_variables(cfg) + for i, cfg in enumerate(cfgs): + vars = load_variables(cfg) + print(cfg['state']) + print(i, end=' ') + pprint(vars.object) savegraph(g, 'output/test_load_variables.pdf', show='all')