Skip to content
Snippets Groups Projects
Commit 659bd04b authored by Maxim Gonchar's avatar Maxim Gonchar
Browse files

Squashed 'dagflow/' changes from 5932df6..9890644

9890644 feat: load_variables, enable loading the files
7ace3d4 feat: more schema modules

git-subtree-dir: dagflow
git-subtree-split: 9890644e7c02c187bbb5f614433a22b93468851b
parent 857db5ee
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ from dictwrapper.dictwrapper import DictWrapper
from schema import Schema, Or, Optional, Use, And, Schema, SchemaError
from ..tools.schema import NestedSchema
from ..tools.schema import NestedSchema, LoadFileWithExt, LoadYaml
class ParsCfgHasProperFormat(object):
def validate(self, data: dict) -> dict:
......@@ -67,8 +67,12 @@ IsVarsCfgDict = Schema({
'labels': IsLabelsDict,
'format': IsFormat
})
IsProperVarsCfg = And(IsVarsCfgDict, ParsCfgHasProperFormat())
IsProperVarsCfgDict = And(IsVarsCfgDict, ParsCfgHasProperFormat())
IsLoadableDict = And(
{'load': 'str'},
Use(LoadFileWithExt(yaml=LoadYaml))
)
IsProperVarsCfg = Or(IsProperVarsCfgDict, And(IsLoadableDict, IsProperVarsCfgDict))
def process_var_fixed1(vcfg, _, __):
return {'central': vcfg, 'value': vcfg, 'sigma': None}
......@@ -131,7 +135,7 @@ def iterate_varcfgs(cfg: DictWrapper):
from dagflow.variable import Parameters
def load_variables(acfg):
cfg = IsProperVarsCfg.validate(acfg)
cfg = IsProperVarsCfgDict.validate(acfg)
cfg = DictWrapper(cfg)
ret = DictWrapper({}, sep='.')
......
......@@ -2,8 +2,36 @@ from typing import Any, Union
from schema import Schema, Schema, SchemaError
from contextlib import suppress
from dictwrapper.dictwrapper import DictWrapper
from os import access, R_OK
from typing import Callable
def IsReadable(filename: str):
"""Returns True if the file is readable"""
return access(filename, R_OK)
def IsFilewithExt(*exts: str):
"""Returns a function that retunts True if the file extension is consistent"""
def checkfilename(filename: str):
return any(filename.endswith(f'.{ext}' for ext in exts))
return checkfilename
def LoadFileWithExt(*, key: str=None,**kwargs: Callable):
"""Returns a function that retunts True if the file extension is consistent"""
def checkfilename(filename: Union[str, dict]):
if key is not None:
filename = filename[key]
for ext, loader in kwargs.items():
if filename.endswith(f'.{ext}'):
return loader(filename)
return False
return checkfilename
from yaml import load, Loader
def LoadYaml(fname: str):
return load(fname, Loader)
from dictwrapper.dictwrapper import DictWrapper
class NestedSchema(object):
__slots__ = ('_schema', '_processdicts')
_schema: Union[Schema,object]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment