Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from typing import Optional, Union
class SimpleFormatter():
    """
    Builds names of the form ``base`` or ``base + numfmt.format(num)``.

    The numeric suffix is appended only for positive numbers, so the
    element number 0 is named with the bare base string.
    """
    _base: str
    _numfmt: str

    def __init__(self, base: str, numfmt: str = '_{:02d}'):
        """
        Parameters:
            base: the name prefix, returned as is for non-positive numbers.
            numfmt: ``str.format`` template for the numeric suffix.
        """
        self._base = base
        self._numfmt = numfmt

    @staticmethod
    def from_string(
        string: Union[str, "SimpleFormatter"]
    ) -> Union[str, "SimpleFormatter"]:
        """
        Convert a user supplied format specification to an object
        providing ``format(num)``.

        Returns:
            - the argument itself if it is already a ``SimpleFormatter``;
            - the argument itself if it is a string containing ``'{'``
              (such a string is used as a ``str.format`` template directly);
            - ``SimpleFormatter(string)`` otherwise.
        """
        # A ready formatter has no `in` support: the membership test below
        # would raise TypeError, so pass such objects through untouched.
        if isinstance(string, SimpleFormatter):
            return string
        if '{' in string:
            return string
        return SimpleFormatter(string)

    def format(self, num: int) -> str:
        """Return the base name, suffixed with the formatted `num` if num>0."""
        if num > 0:
            return self._base + self._numfmt.format(num)
        return self._base
class MissingInputHandler:
    """
    Base handler defining what happens when an output
    is connected to a missing input with >>/<<.

    The base implementation only stores the node it is attached to;
    calling it does nothing and returns None.
    """

    # Class-level fallback: the node is None until one is assigned.
    _node = None

    def __init__(self, node=None):
        """Attach the handler to `node` (assigned via the `node` property)."""
        self.node = node

    @property
    def node(self):
        """The node this handler is attached to (None if unset)."""
        return self._node

    @node.setter
    def node(self, node):
        self._node = node

    def __call__(self, idx=None, scope=None):
        """No-op: both arguments are accepted and ignored."""
        return None
class MissingInputFail(MissingInputHandler):
    """Default missing input handler: issues an exception"""

    def __init__(self, node=None):
        """Attach the handler to `node`; no extra state is kept."""
        super().__init__(node)

    def __call__(self, idx=None, scope=None):
        """
        Always fail: this handler never creates inputs.

        Raises:
            RuntimeError: unconditionally; `idx` and `scope` are ignored.
        """
        message = (
            "Unable to iterate inputs further. No additional inputs may be created"
        )
        raise RuntimeError(message)
class MissingInputAdd(MissingInputHandler):
    """
    Adds an input for each output in >> operator.

    New inputs are created via ``node._add_input`` and named with
    `input_fmt`. The output related settings (`output_fmt`, `output_kws`)
    are only stored here; this class itself never reads them.
    """

    # Class-level defaults, used unless overridden per instance in __init__.
    input_fmt: Union[str, SimpleFormatter] = SimpleFormatter("input", "_{:02d}")
    input_kws: dict
    output_fmt: Union[str, SimpleFormatter] = SimpleFormatter("output", "_{:02d}")
    output_kws: dict

    def __init__(
        self,
        node=None,
        *,
        input_fmt: Optional[Union[str, SimpleFormatter]] = None,
        input_kws: Optional[dict] = None,
        output_fmt: Optional[Union[str, SimpleFormatter]] = None,
        output_kws: Optional[dict] = None,
    ):
        """
        Parameters:
            node: the node to attach the handler to.
            input_fmt: naming format for new inputs: a ``SimpleFormatter``,
                a ``str.format`` template (string with ``'{'``) or a plain
                base name. None keeps the class default.
            input_kws: keyword arguments passed to ``node._add_input``
                on every call (empty dict if None).
            output_fmt: same as `input_fmt`, but for outputs.
            output_kws: keyword arguments stored for output creation
                (empty dict if None).
        """
        super().__init__(node)
        self.input_kws = {} if input_kws is None else input_kws
        self.output_kws = {} if output_kws is None else output_kws
        if input_fmt is not None:
            self.input_fmt = self._as_formatter(input_fmt)
        if output_fmt is not None:
            self.output_fmt = self._as_formatter(output_fmt)

    @staticmethod
    def _as_formatter(fmt):
        """Return an object with ``format(num)`` for a str/SimpleFormatter spec."""
        # The signature allows ready SimpleFormatter instances, while
        # SimpleFormatter.from_string does a substring test on its argument,
        # which raises TypeError for a formatter object: use it directly.
        if isinstance(fmt, SimpleFormatter):
            return fmt
        return SimpleFormatter.from_string(fmt)

    def __call__(self, idx=None, scope=None, **kwargs):
        """
        Create a new input of the node and return the result
        of ``node._add_input``.

        Parameters:
            idx: number used to format the input name; if None,
                the current number of the node inputs is used.
            scope: accepted for signature compatibility, not used here.
            **kwargs: extra arguments for ``node._add_input``;
                they take precedence over `input_kws`.
        """
        kwargs_combined = dict(self.input_kws, **kwargs)
        num = idx if idx is not None else len(self.node.inputs)
        return self.node._add_input(
            self.input_fmt.format(num),
            **kwargs_combined,
        )
class MissingInputAddPair(MissingInputAdd):
    """
    Adds an input for each output in >> operator.
    Adds an output for each new input, i.e. every call
    creates an output/input pair.
    """

    def __init__(self, node=None, **kwargs):
        """Forward all the arguments to the parent constructor."""
        super().__init__(node, **kwargs)

    def __call__(self, idx=None, scope=None):
        """
        Create a new output, then a new input with the output
        passed as `child_output`; return the result of the parent call.
        """
        # The new output is numbered by the count of the existing outputs.
        n_outputs = len(self.node.outputs)
        output_name = self.output_fmt.format(n_outputs)
        new_output = self.node._add_output(output_name, **self.output_kws)
        return super().__call__(idx, child_output=new_output, scope=scope)
class MissingInputAddOne(MissingInputAdd):
    """
    Adds an input for each output in >> operator.
    Adds only one output if needed: an output is created
    only when the node has no outputs yet.
    """

    add_child_output = False

    def __init__(self, node=None, *, add_child_output: bool = False, **kwargs):
        """
        Parameters:
            node: the node to attach the handler to.
            add_child_output: if True, the output is passed
                as `child_output` when the input is created.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(node, **kwargs)
        self.add_child_output = add_child_output

    def __call__(self, idx=None, scope=None):
        """
        Ensure the node has an output and create a new input;
        return the result of the parent call.
        """
        n_outputs = len(self.node.outputs)
        if n_outputs != 0:
            # Reuse the latest existing output.
            output = self.node.outputs[-1]
        else:
            output = self.node._add_output(
                self.output_fmt.format(n_outputs), **self.output_kws
            )

        # The output is passed on only if explicitly requested.
        extra = {"child_output": output} if self.add_child_output else {}
        return super().__call__(idx, scope=scope, **extra)
class MissingInputAddEach(MissingInputAdd):
    """
    Adds an output for each block (for each >> operation).

    Calls sharing the same non-zero `scope` reuse the latest output;
    any other call creates a new output and remembers its `scope`.
    """

    add_child_output = False
    # Scope of the call that created the latest output; 0 means "none yet".
    scope = 0

    def __init__(self, node=None, *, add_child_output=False, **kwargs):
        """
        Parameters:
            node: the node to attach the handler to.
            add_child_output: if True, the output is passed
                as `child_output` when the input is created.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(node, **kwargs)
        self.add_child_output = add_child_output

    def __call__(self, idx=None, scope=None):
        """
        Pick or create the output for the current scope and create
        a new input; return the result of the parent call.
        """
        same_block = (scope == self.scope) and (self.scope != 0)
        if same_block:
            output = self.node.outputs[-1]
        else:
            output_name = self.output_fmt.format(len(self.node.outputs))
            output = self.node._add_output(output_name, **self.output_kws)
            # Remember the scope to reuse the output for the rest of the block.
            self.scope = scope

        if not self.add_child_output:
            return super().__call__(idx, scope=scope)
        return super().__call__(idx, child_output=output, scope=scope)