diff --git a/subtrees/dagflow/dagflow/bundles/load_parameters.py b/subtrees/dagflow/dagflow/bundles/load_parameters.py
index 6882ff5dd6158dd96da427c59824f5cbc7dfe670..103b60a207dd74935feeb50a2e084589cda8e6fa 100644
--- a/subtrees/dagflow/dagflow/bundles/load_parameters.py
+++ b/subtrees/dagflow/dagflow/bundles/load_parameters.py
@@ -146,14 +146,16 @@ def iterate_varcfgs(cfg: NestedMKDict):
yield key, varcfg
from dagflow.variable import Parameters
+from dagflow.lib.SumSq import SumSq
+from dagflow.lib.Sum import Sum
def load_parameters(acfg):
cfg = ValidateParsCfg(acfg)
cfg = NestedMKDict(cfg)
- path = cfg['path']
- if path:
- path = tuple(path.split('.'))
+ pathstr = cfg['path']
+ if pathstr:
+ path = tuple(pathstr.split('.'))
else:
path = ()
@@ -165,7 +167,11 @@ def load_parameters(acfg):
'constant': {},
'free': {},
'constrained': {},
- 'normalized': {}
+ 'normalized': {},
+ },
+ 'stat': {
+ 'nuisance_parts': {},
+ 'nuisance': {},
},
'parameter_node': {
'constant': {},
@@ -175,6 +181,8 @@ def load_parameters(acfg):
},
sep='.'
)
+
+ normpars = []
for key, varcfg in iterate_varcfgs(cfg):
skey = '.'.join(key)
label = varcfg['label']
@@ -184,20 +192,37 @@ def load_parameters(acfg):
par = Parameters.from_numbers(**varcfg)
if par.is_constrained:
- target = ('constrained',) + path
+ target = ('constrained', path)
elif par.is_fixed:
- target = ('constant',) + path
+ target = ('constant', path)
else:
- target = ('free',) + path
+ target = ('free', path)
ret[('parameter_node',)+target+key] = par
- ptarget = ('parameter',)+target
+ ptarget = ('parameter', target)
for subpar in par.parameters:
ret[ptarget+key] = subpar
- ntarget = ('parameter', 'normalized')+path
+ ntarget = ('parameter', 'normalized', path)
for subpar in par.norm_parameters:
ret[ntarget+key] = subpar
+ normpars.append(subpar)
+
+ if normpars:
+ ssq = SumSq(f'nuisance for {pathstr}')
+ (n.output for n in normpars) >> ssq
+ ssq.close()
+ ret[('stat', 'nuisance_parts', path)] = ssq
+
+ nuisanceall = ret.get('stat.nuisance.all', None)
+ if nuisanceall is None:
+ nuisanceall = Sum('nuisance total')
+ ret['stat.nuisance.all'] = nuisanceall
+ else:
+ nuisanceall.open()
+ ssq >> nuisanceall
+ nuisanceall.close()
+
return ret