# test_parameters.py
#!/usr/bin/env python

from dagflow.lib import Array
from dagflow.parameters import GaussianParameters
from dagflow.graph import Graph
from dagflow.graphviz import savegraph
from dagflow.exception import CriticalError

from numpy import square, allclose
import pytest

@pytest.mark.parametrize('mode', ('single', 'uncorr', 'cov', 'cov1d'))
def test_variables_00_variable(mode) -> None:
    """Check the value <-> normalized-value round trip of GaussianParameters.

    The parameters are built with one of several uncertainty specifications,
    selected by ``mode``:

    * ``single`` -- one parameter with a sigma;
    * ``uncorr`` -- three parameters with independent sigmas;
    * ``cov``    -- three parameters with a full covariance matrix;
    * ``cov1d``  -- three parameters with a covariance given as its diagonal.

    A ``cor`` branch (sigma plus a correlation matrix) is also implemented
    below, but it is not part of the parametrization.

    The test then checks that:

    1. the initial value equals the input and every normalized value is
       nonzero;
    2. setting the normalized values to zero moves the value to the central
       values;
    3. restoring the original value restores the original normalized values.

    Graph images are written to ``output/``. One is also written when
    ``graph.close()`` raises ``CriticalError``, to help debugging.
    """
    value_in    = [1.1, 1.8, 5.0]
    central_in  = [1.0, 2.0, 3.0]
    sigma_in    = [1.0, 0.5, 2.0]
    corrs_in    = [-0.1, 0.5, -0.9] # correlation coefficients for pairs 01, 02, 12
    variance_in = square(sigma_in)
    zeros_in    = [0.0, 0.0, 0.0]

    if mode=='single':
        value_in = value_in[:1]
        central_in = central_in[:1]
        sigma_in = sigma_in[:1]
        zeros_in = zeros_in[:1]

    with Graph(debug=False, close=False) as graph:
        value   = Array("variable", value_in, mode='store_weak', mark='v')
        central = Array("central",  central_in, mark='v₀')

        if mode in ('single', 'uncorr', 'cor'):
            sigma = Array("sigma", sigma_in, mark='σ')

        if mode in ('single', 'uncorr'):
            gp = GaussianParameters(value, central, sigma=sigma)
        elif mode=='cov':
            # Full covariance: V_ij = rho_ij * sigma_i * sigma_j, V_ii = sigma_i^2
            covariance = Array("covariance", [
                    [variance_in[0],                      corrs_in[0]*sigma_in[0]*sigma_in[1], corrs_in[1]*sigma_in[0]*sigma_in[2]],
                    [corrs_in[0]*sigma_in[0]*sigma_in[1], variance_in[1],                      corrs_in[2]*sigma_in[1]*sigma_in[2]],
                    [corrs_in[1]*sigma_in[0]*sigma_in[2], corrs_in[2]*sigma_in[1]*sigma_in[2], variance_in[2]]
                                ],
                               mark='V')
            gp = GaussianParameters(value, central, covariance=covariance)
        elif mode=='cov1d':
            # Covariance passed as a 1d array of variances (its diagonal).
            covariance = Array("covariance", variance_in, mark='diag(V)')
            gp = GaussianParameters(value, central, covariance=covariance)
        elif mode=='cor':
            correlation = Array("correlation", [
                [1.0,         corrs_in[0], corrs_in[1]],
                [corrs_in[0], 1.0,         corrs_in[2]],
                [corrs_in[1], corrs_in[2], 1.0],
                ], mark='C')
            gp = GaussianParameters(value, central, sigma=sigma, correlation=correlation)
        else:
            raise RuntimeError(f"Invalid mode {mode}")

    try:
        graph.close()
    except CriticalError:
        # Dump the graph as built so far to help diagnose the failure.
        savegraph(graph, f"output/test_variables_00_{mode}.png")
        raise

    # Snapshot BOTH initial outputs. `.data` may expose the node's live buffer,
    # and the `set()` calls below overwrite it. Without a copy, the final
    # normalized-value comparison would compare the buffer with itself and
    # could never fail.
    value_out0 = gp.value.data.copy()
    normvalue_out0 = gp.constraint.normvalue.data.copy()
    assert allclose(value_in, value_out0, atol=0, rtol=0)
    assert all(normvalue_out0!=0)

    # Zero normalized values must map back to the central values.
    gp.constraint.normvalue.set(zeros_in)
    value_out1 = gp.value.data
    normvalue_out1 = gp.constraint.normvalue.data
    assert allclose(central_in, value_out1, atol=0, rtol=0)
    assert allclose(normvalue_out1, 0.0, atol=0, rtol=0)

    # Restoring the original value must restore the original normalized values.
    gp.value.set(value_out0)
    value_out2 = gp.value.data
    normvalue_out2 = gp.constraint.normvalue.data
    assert allclose(value_in, value_out2, atol=0, rtol=0)
    assert allclose(normvalue_out2, normvalue_out0, atol=0, rtol=0)

    savegraph(graph, f"output/test_variables_00_{mode}.png", show=['all'])
    savegraph(graph, f"output/test_variables_00_{mode}.pdf", show=['all'])