from multikeydict.nestedmkdict import NestedMKDict
# from multikeydict.flatmkdict import FlatMKDict # To be used later
from gindex import GNIndex

from schema import Schema, Or, Optional, Use, And, Schema, SchemaError
from pathlib import Path

from ..tools.schema import NestedSchema, LoadFileWithExt, LoadYaml

class ParsCfgHasProperFormat(object):
    """Validator checking that parameter values agree with the 'format' entry.

    The object exposes a ``validate`` method, so it can be combined with
    other schemas. Every leaf of ``data['parameters']`` must carry as many
    elements as ``data['format']`` names: a tuple of the matching length or,
    when the format has a single element, a bare (non-tuple) value.
    """

    def validate(self, data: dict) -> dict:
        """Return ``data`` unchanged when all parameter values fit the format.

        Raises:
            SchemaError: for the first parameter whose number of elements
                does not match the number of elements of the format.
        """
        fmt = data['format']
        # A plain string format names exactly one element.
        nelements = 1 if isinstance(fmt, str) else len(fmt)

        parameters = NestedMKDict(data)['parameters']
        for key, subdata in parameters.walkitems():
            if isinstance(subdata, tuple):
                consistent = len(subdata) == nelements
            else:
                # A bare value is acceptable only for a one-element format.
                consistent = nelements == 1
            if consistent:
                continue

            keystr = '.'.join(str(k) for k in key)
            raise SchemaError(
                f'Key "{keystr}" has value "{subdata}" inconsistent with format "{fmt}"'
            )

        return data

# A single numeric value: float or int.
IsNumber = Or(float, int, error='Invalid number "{}", expect int or float')
# A number, a tuple of numbers, or a list of numbers (the list is converted to a tuple).
IsNumberOrTuple = Or(IsNumber, (IsNumber,), And([IsNumber], Use(tuple)), error='Invalid number/tuple {}')
# A label: either a dictionary with a mandatory 'text' and optional extra
# string fields, or a bare string, which is wrapped into {'text': <string>}.
IsLabel = Or({
        'text': str,
        Optional('latex'): str,
        Optional('graph'): str,
        Optional('mark'): str,
        Optional('name'): str
    },
    And(str, Use(lambda s: {'text': s}), error='Invalid string: {}')
)
# Nested dictionaries whose leaves are parameter values / labels.
# NOTE(review): the exact traversal rules (and the meaning of `processdicts`)
# are defined by NestedSchema in ..tools.schema — confirm there.
IsValuesDict = NestedSchema(IsNumberOrTuple)
IsLabelsDict = NestedSchema(IsLabel, processdicts=True)
def IsFormatOk(format):
    """Tell whether `format` is a supported parameter format specification.

    Accepted layouts:
      * the string 'value' or a one-element sequence holding 'value';
      * two elements: 'value' or 'central' followed by a sigma kind;
      * three elements: 'value' and 'central' (in either order, both present)
        followed by a sigma kind.

    The sigma kind is one of 'sigma_absolute', 'sigma_relative',
    'sigma_percent'. Anything else is rejected.
    """
    main_names = ('value', 'central')
    sigma_names = ('sigma_absolute', 'sigma_relative', 'sigma_percent')

    # Anything that is not a tuple/list is only valid if it is the bare 'value'.
    if not isinstance(format, (tuple, list)):
        return format == 'value'

    count = len(format)
    if count == 1:
        return format[0] == 'value'

    if count == 2:
        first, error_name = format
    elif count == 3:
        first, second, error_name = format
        # The middle element must be the other one of value/central.
        if second not in main_names or first == second:
            return False
    else:
        return False

    if error_name not in sigma_names:
        return False

    return first in main_names

# Schema wrapper around IsFormatOk: the 'format' entry must be one of the
# layouts that IsFormatOk accepts.
IsFormat = Schema(IsFormatOk, error='Invalid parameter format "{}".')
# A tuple of strings.
IsStrSeq = (str,)
# A tuple of strings, or a single string that is wrapped into a 1-tuple.
IsStrSeqOrStr = Or(IsStrSeq, And(str, Use(lambda s: (s,))))
# Schema of a parameters configuration dictionary.
IsParsCfgDict = Schema({
    # Nested dictionary of parameter values.
    'parameters': IsValuesDict,
    # Nested dictionary of labels.
    'labels': IsLabelsDict,
    # Layout of each parameter value, see IsFormatOk.
    'format': IsFormat,
    'state': Or('fixed', 'variable', error='Invalid parameters state: {}'),
    # Dot-separated location prefix; empty string by default.
    Optional('path', default=''): str,
    # Tuple of key suffixes; the default is a single empty suffix.
    Optional('replicate', default=((),)): (IsStrSeqOrStr,),
    },
    # error = 'Invalid parameters configuration: {}'
)
# The structural schema followed by the value/format consistency check.
IsProperParsCfgDict = And(IsParsCfgDict, ParsCfgHasProperFormat())
# A dictionary with a 'load' entry (a str, or a Path converted to str) and
# arbitrary other string keys. It is passed through LoadFileWithExt and the
# result must then satisfy IsProperParsCfgDict.
# NOTE(review): presumably LoadFileWithExt reads the YAML file named by 'load'
# and merges it into the dictionary (update=True) — confirm in ..tools.schema.
IsLoadableDict = And(
            {
                'load': Or(str, And(Path, Use(str))),
                Optional(str): object
            },
            Use(LoadFileWithExt(yaml=LoadYaml, key='load', update=True), error='Failed to load {}'),
            IsProperParsCfgDict
        )
def ValidateParsCfg(cfg):
    """Validate a parameters configuration and return the validated result.

    A dictionary containing the 'load' key is validated with IsLoadableDict;
    any other input is validated with IsProperParsCfgDict.
    """
    needs_loading = isinstance(cfg, dict) and 'load' in cfg
    validator = IsLoadableDict if needs_loading else IsProperParsCfgDict
    return validator.validate(cfg)

def process_var_fixed1(vcfg, _, __):
    """Build a parameter description from a single value without uncertainty.

    The value is used both as 'central' and as 'value'; 'sigma' is None.
    The two extra positional arguments are ignored; they keep the signature
    parallel to the other `process_var_*` functions.
    """
    return dict(central=vcfg, value=vcfg, sigma=None)

def process_var_fixed2(vcfg, format, hascentral) -> dict:
    """Map the elements of `vcfg` onto the names listed in `format`.

    Args:
        vcfg: sequence of numbers, one per name in `format`. A bare int or
            float is accepted as well and treated as a one-element sequence,
            so that a one-element format such as ``['value']`` works with an
            unwrapped number.
        format: sequence of element names.
        hascentral: True if 'central' is among the names in `format`.

    Returns:
        A dictionary with one entry per format name. 'value' and 'central'
        are both present afterwards: the missing one is copied from the
        other. 'sigma' is set to None.
    """
    if isinstance(vcfg, (int, float)):
        # zip() below needs an iterable; a bare number used to raise TypeError.
        vcfg = (vcfg,)

    ret = dict(zip(format, vcfg))
    if hascentral:
        ret.setdefault('value', ret['central'])
    else:
        ret.setdefault('central', ret['value'])
    ret['sigma'] = None
    return ret

def process_var_absolute(vcfg, format, hascentral) -> dict:
    """Process a value whose uncertainty is given as 'sigma_absolute'.

    The entries are produced by `process_var_fixed2`; 'sigma' is then set to
    the 'sigma_absolute' entry as is.
    """
    parsed = process_var_fixed2(vcfg, format, hascentral)
    parsed.update(sigma=parsed['sigma_absolute'])
    return parsed

def process_var_relative(vcfg, format, hascentral) -> dict:
    """Process a value whose uncertainty is given as 'sigma_relative'.

    The entries are produced by `process_var_fixed2`; 'sigma' is then set to
    the 'sigma_relative' entry multiplied by the 'central' entry.
    """
    parsed = process_var_fixed2(vcfg, format, hascentral)
    relative = parsed['sigma_relative']
    parsed['sigma'] = relative*parsed['central']
    return parsed

def process_var_percent(vcfg, format, hascentral) -> dict:
    """Process a value whose uncertainty is given as 'sigma_percent'.

    The entries are produced by `process_var_fixed2`; 'sigma' is then set to
    0.01 times the 'sigma_percent' entry times the 'central' entry.
    """
    parsed = process_var_fixed2(vcfg, format, hascentral)
    fraction = 0.01*parsed['sigma_percent']
    parsed['sigma'] = fraction*parsed['central']
    return parsed

def get_format_processor(format):
    """Select the `process_var_*` function matching a format specification.

    A string format selects `process_var_fixed1`. For a sequence, the last
    element decides: if it does not start with 'sigma' the values carry no
    uncertainty (`process_var_fixed2`); otherwise its suffix selects the
    absolute or relative processor, and any other suffix falls back to the
    percent processor.
    """
    if isinstance(format, str):
        return process_var_fixed1

    error_name = format[-1]
    if not error_name.startswith('sigma'):
        return process_var_fixed2

    by_suffix = (
        ('_absolute', process_var_absolute),
        ('_relative', process_var_relative),
    )
    for suffix, processor in by_suffix:
        if error_name.endswith(suffix):
            return processor

    return process_var_percent

def get_label(key: tuple, labelscfg: dict) -> dict:
    """Find the label for `key`, falling back to a template of a shorter key.

    If `labelscfg` has an entry for `key` itself, it is returned as is.
    Otherwise progressively shorter prefixes of `key` are tried (longest
    first, down to the empty tuple). The first entry found is treated as a
    template: each of its values is formatted with a dot-joined index string
    and the resulting dictionary is returned. When nothing matches, an empty
    dictionary is returned.
    """
    try:
        exact = labelscfg[key]
    except KeyError:
        pass
    else:
        return exact

    for dropped in range(1, len(key) + 1):
        prefix = key[:-dropped]
        try:
            template = labelscfg[prefix]
        except KeyError:
            continue

        # NOTE(review): the index is built from key[dropped-1:], which is not
        # the part of the key remaining after `prefix` (that would be
        # key[-dropped:]) — confirm this is the intended index.
        index_text = '.'.join(key[dropped - 1:])
        formatted = {}
        for field, text in template.items():
            formatted[field] = text.format(index_text)
        return formatted

    return {}

def iterate_varcfgs(cfg: NestedMKDict):
    """Yield `(key, description)` for every leaf of `cfg['parameters']`.

    Each raw value is converted with the processor selected for
    `cfg['format']`, and the label found by `get_label` in `cfg['labels']`
    is stored under the 'label' entry of the description.
    """
    raw_values = cfg['parameters']
    labels = cfg['labels']
    fmt = cfg['format']

    has_central = 'central' in fmt
    convert = get_format_processor(fmt)

    for key, raw in raw_values.walkitems():
        description = convert(raw, fmt, has_central)
        description['label'] = get_label(key, labels)
        yield key, description

from dagflow.parameters import Parameters
from dagflow.lib.SumSq import SumSq

def load_parameters(acfg):
    """Validate a parameters configuration and create the parameters.

    The configuration is validated with `ValidateParsCfg`. For every
    parameter yielded by `iterate_varcfgs` and every suffix from the
    'replicate' entry a `Parameters` object is created with
    `Parameters.from_numbers` and stored in the returned `NestedMKDict`:

      * 'parameter_node.<kind>': the `Parameters` object itself;
      * 'parameter.<kind>': entries of its `parameters`;
      * 'parameter.normalized': entries of its `norm_parameters`;
      * 'stat.nuisance_parts': one `SumSq` node per first key element,
        fed with the outputs of the normalized parameters of that group.

    `<kind>` is 'constrained', 'constant' or 'free', chosen from the
    `is_constrained` / `is_fixed` flags of the created object. The
    configured 'path' (split on dots) is inserted after the section name.

    Args:
        acfg: configuration accepted by `ValidateParsCfg`.

    Returns:
        NestedMKDict with the sections described above.
    """
    cfg = NestedMKDict(ValidateParsCfg(acfg))

    pathstr = cfg['path']
    path = tuple(pathstr.split('.')) if pathstr else ()

    state = cfg['state']

    ret = NestedMKDict(
        {
            'parameter': {
                'constant': {},
                'free': {},
                'constrained': {},
                'normalized': {},
                },
            'stat': {
                'nuisance_parts': {},
                'nuisance': {},
                },
            'parameter_node': {
                'constant': {},
                'free': {},
                'constrained': {}
                }
        },
        sep='.'
    )

    subkeys = cfg['replicate']

    # Normalized parameters grouped by the first element of the parameter key.
    normpars = {}
    for key_general, varcfg in iterate_varcfgs(cfg):
        # Pass the configured state ('fixed' or 'variable') as a keyword flag.
        varcfg.setdefault(state, True)

        normpars_i = normpars.setdefault(key_general[0], [])
        for subkey in subkeys:
            key = key_general + subkey
            key_str = '.'.join(key)

            # NOTE(review): this per-replica label is built but never passed
            # to Parameters.from_numbers, which receives varcfg['label'] —
            # confirm whether `label` was meant to be used here.
            label = varcfg['label'].copy()
            label['key'] = key_str
            label.setdefault('text', key_str)

            par = Parameters.from_numbers(**varcfg)
            if par.is_constrained:
                target = ('constrained', path)
            elif par.is_fixed:
                target = ('constant', path)
            else:
                target = ('free', path)

            ret[('parameter_node',)+target+key] = par

            ptarget = ('parameter', target)
            for subpar in par.parameters:
                ret[ptarget+key] = subpar

            ntarget = ('parameter', 'normalized', path)
            for subpar in par.norm_parameters:
                ret[ntarget+key] = subpar

                normpars_i.append(subpar)

    # Build the nuisance terms once, after all parameters are collected.
    # Doing this inside the loop above re-created and re-connected a SumSq
    # node for every group after each parameter.
    for name, group in normpars.items():
        if group:
            ssq = SumSq(f'nuisance for {pathstr}.{name}')
            (n.output for n in group) >> ssq
            ssq.close()
            ret[('stat', 'nuisance_parts', path, name)] = ssq

    return ret