#!/usr/bin/env python
from dagflow.lib import Array
from dagflow.parameters import GaussianParameters
from dagflow.graph import Graph
from dagflow.graphviz import savegraph
from dagflow.exception import CriticalError
from numpy import square, allclose
import pytest
# NOTE(review): the 'cor' (sigma + correlation) branch below is handled but
# 'cor' is not in the parametrize list, so it is never exercised; add it once
# GaussianParameters correlation support is confirmed.
@pytest.mark.parametrize('mode', ('single', 'uncorr', 'cov', 'cov1d'))
def test_variables_00_variable(mode) -> None:
    """Check the value <-> normalized-value round trip of GaussianParameters.

    A GaussianParameters node is built from a value array, a central array
    and an uncertainty description selected by ``mode``:

    - ``'single'``: one parameter with sigma,
    - ``'uncorr'``: three parameters with sigma,
    - ``'cov'``: three parameters with a full covariance matrix,
    - ``'cov1d'``: three parameters with a diagonal covariance (variances),
    - ``'cor'``: three parameters with sigma and a correlation matrix.

    The test then checks that:

    1. the initial value is passed through unchanged and every normalized
       value is nonzero;
    2. setting the normalized value to zero moves the value to the central
       value (and the normalized value reads back as zero);
    3. setting the original value again restores both the value and the
       original normalized value.

    On a CriticalError while closing the graph, a PNG of the graph is saved
    and the error is re-raised; on success PNG and PDF images are saved.
    Output goes to ``output/``.
    """
    value_in = [1.1, 1.8, 5.0]
    central_in = [1.0, 2.0, 3.0]
    sigma_in = [1.0, 0.5, 2.0]
    corrs_in = [-0.1, 0.5, -0.9]  # 01, 02, 12
    variance_in = square(sigma_in)
    zeros_in = [0.0, 0.0, 0.0]

    if mode == 'single':
        # Only the inputs used by the single-parameter sigma branch are trimmed.
        value_in = value_in[:1]
        central_in = central_in[:1]
        sigma_in = sigma_in[:1]
        zeros_in = zeros_in[:1]

    with Graph(debug=False, close=False) as graph:
        value = Array("variable", value_in, mode='store_weak', mark='v')
        central = Array("central", central_in, mark='v₀')

        if mode in ('single', 'uncorr', 'cor'):
            sigma = Array("sigma", sigma_in, mark='σ')

        if mode in ('single', 'uncorr'):
            gp = GaussianParameters(value, central, sigma=sigma)
        elif mode == 'cov':
            # Covariance V_ij = rho_ij * sigma_i * sigma_j, variances on the diagonal.
            covariance = Array(
                "covariance",
                [
                    [variance_in[0], corrs_in[0]*sigma_in[0]*sigma_in[1], corrs_in[1]*sigma_in[0]*sigma_in[2]],
                    [corrs_in[0]*sigma_in[0]*sigma_in[1], variance_in[1], corrs_in[2]*sigma_in[1]*sigma_in[2]],
                    [corrs_in[1]*sigma_in[0]*sigma_in[2], corrs_in[2]*sigma_in[1]*sigma_in[2], variance_in[2]],
                ],
                mark='V',
            )
            gp = GaussianParameters(value, central, covariance=covariance)
        elif mode == 'cov1d':
            # A 1d covariance input holds only the variances (diagonal of V).
            covariance = Array("covariance", variance_in, mark='diag(V)')
            gp = GaussianParameters(value, central, covariance=covariance)
        elif mode == 'cor':
            correlation = Array(
                "correlation",
                [
                    [1.0, corrs_in[0], corrs_in[1]],
                    [corrs_in[0], 1.0, corrs_in[2]],
                    [corrs_in[1], corrs_in[2], 1.0],
                ],
                mark='C',
            )
            gp = GaussianParameters(value, central, sigma=sigma, correlation=correlation)
        else:
            raise RuntimeError(f"Invalid mode {mode}")

    try:
        graph.close()
    except CriticalError:
        # Dump the (partially built) graph for debugging, then propagate.
        savegraph(graph, f"output/test_variables_00_{mode}.png")
        raise

    # `.data` appears to expose the node's live buffer (hence the copy of the
    # value), so the normalized value must be copied too: otherwise the final
    # comparison against normvalue_out0 would compare the buffer with itself
    # after the `set` calls below and pass trivially.
    value_out0 = gp.value.data.copy()
    normvalue_out0 = gp.constraint.normvalue.data.copy()
    assert allclose(value_in, value_out0, atol=0, rtol=0)
    assert all(normvalue_out0 != 0)

    # A zero normalized value must map the value onto the central value.
    gp.constraint.normvalue.set(zeros_in)
    value_out1 = gp.value.data
    normvalue_out1 = gp.constraint.normvalue.data
    assert allclose(central_in, value_out1, atol=0, rtol=0)
    assert allclose(normvalue_out1, 0.0, atol=0, rtol=0)

    # Restoring the original value must restore the original normalized value.
    gp.value.set(value_out0)
    value_out2 = gp.value.data
    normvalue_out2 = gp.constraint.normvalue.data
    assert allclose(value_in, value_out2, atol=0, rtol=0)
    assert allclose(normvalue_out2, normvalue_out0, atol=0, rtol=0)

    savegraph(graph, f"output/test_variables_00_{mode}.png", show=['all'])
    savegraph(graph, f"output/test_variables_00_{mode}.pdf", show=['all'])