Loading [MathJax]/extensions/TeX/AMSmath.js
Radium Engine  1.5.29
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
graph_as_node.cpp
#include <catch2/catch_test_macros.hpp>

#include <Core/Containers/MakeShared.hpp>
#include <Dataflow/Core/DataflowGraph.hpp>
#include <Dataflow/Core/Functionals/BinaryOpNode.hpp>
#include <Dataflow/Core/Functionals/FilterNode.hpp>
#include <Dataflow/Core/Functionals/FunctionNode.hpp>
#include <Dataflow/Core/Functionals/ReduceNode.hpp>
#include <Dataflow/Core/Functionals/TransformNode.hpp>
#include <Dataflow/Core/Sinks/SinkNode.hpp>
#include <Dataflow/Core/Sinks/Types.hpp>
#include <Dataflow/Core/Sources/Types.hpp>

#include <filesystem>
#include <string>

using namespace Ra::Dataflow::Core;
using namespace Ra::Core;
16
17TEST_CASE( "Dataflow/Core/GraphAsNode/Delta", "[unittests][Dataflow][Core][Graph]" ) {
18 auto port_fatcory = PortFactory::getInstance();
19 port_fatcory->add_port_type<Scalar>();
20
21 auto gAsNode = make_shared<DataflowGraph>( "graphAsNode" );
22
23 // compute delta = b2 - 4ac;
24
25 auto b2 = gAsNode->add_node<Functionals::FunctionNode<Scalar>>( "b2" );
26 b2->set_function( []( const Scalar& b ) { return b * b; } );
27
28 auto fourAC = gAsNode->add_node<Functionals::BinaryOpNode<Scalar>>(
29 "4ac", []( const Scalar& a, const Scalar& c ) { return 4_ra * a * c; } );
30
31 auto b2minus4ac = gAsNode->add_node<Functionals::BinaryOpNode<Scalar>>(
32 "b2-4ac", []( const Scalar& x, const Scalar& y ) { return x - y; } );
33
34 auto forwardA = gAsNode->add_node<Functionals::FunctionNode<Scalar>>( "a" );
35 auto forwardB = gAsNode->add_node<Functionals::FunctionNode<Scalar>>( "b" );
36 auto forwardC = gAsNode->add_node<Functionals::FunctionNode<Scalar>>( "c" );
37
38 b2minus4ac->port_out_result()->set_name( "delta" );
39
40 REQUIRE( !gAsNode->input_node() );
41 REQUIRE( !gAsNode->output_node() );
42
43 gAsNode->add_input_output_nodes();
44 auto inputA = gAsNode->input_node()->add_output_port( forwardA->port_in_data().get() );
45 auto inputB = gAsNode->input_node()->add_output_port( forwardB->port_in_data().get() );
46 auto inputC = gAsNode->input_node()->add_output_port( forwardC->port_in_data().get() );
47 auto output = gAsNode->output_node()->add_input_port( b2minus4ac->port_out_result().get() );
48
49 REQUIRE( gAsNode->input_node() );
50 REQUIRE( gAsNode->output_node() );
51 REQUIRE( gAsNode->input_node()->outputs().size() == 3 );
52 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
53 REQUIRE( gAsNode->inputs().size() == 0 );
54 REQUIRE( gAsNode->outputs().size() == 0 );
55
56 gAsNode->input_node()->input_by_index( inputA )->set_name( "a" );
57 gAsNode->input_node()->input_by_index( inputB )->set_name( "b" );
58 gAsNode->input_node()->input_by_index( inputC )->set_name( "c" );
59 gAsNode->output_node()->output_by_index( output )->set_name( "delta" );
60
61 REQUIRE( gAsNode->add_link( forwardB->port_out_result(), b2->port_in_data() ) );
62 REQUIRE( gAsNode->add_link( forwardA->port_out_result(), fourAC->port_in_a() ) );
63 REQUIRE( gAsNode->add_link( forwardC->port_out_result(), fourAC->port_in_b() ) );
64
65 REQUIRE( gAsNode->add_link( b2->port_out_result(), b2minus4ac->port_in_a() ) );
66 REQUIRE( gAsNode->add_link( fourAC->port_out_result(), b2minus4ac->port_in_b() ) );
67
68 REQUIRE( gAsNode->compile() );
69
70 REQUIRE( gAsNode->inputs().size() == 3 );
71 REQUIRE( gAsNode->outputs().size() == 1 );
72
73 DataflowGraph g { "mainGraph" };
74 auto sourceNodeA = g.add_node<Sources::SingleDataSourceNode<Scalar>>( "sa" );
75 auto sourceNodeB = g.add_node<Sources::SingleDataSourceNode<Scalar>>( "sb" );
76 auto sourceNodeC = g.add_node<Sources::SingleDataSourceNode<Scalar>>( "sc" );
77
78 auto resultNode = g.add_node<Sinks::SinkNode<Scalar>>( "odelta" );
79
80 REQUIRE( g.add_node( gAsNode ) );
81
82 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( inputA ) ) );
83 REQUIRE( g.add_link( sourceNodeB->port_out_to().get(), gAsNode->input_by_index( inputB ) ) );
84 REQUIRE( g.add_link( sourceNodeC->port_out_to().get(), gAsNode->input_by_index( inputC ) ) );
85 REQUIRE( g.add_link( gAsNode->output_by_index( output ), resultNode->port_in_from().get() ) );
86
87 sourceNodeA->set_data( 1 );
88 sourceNodeB->set_data( 2 );
89 sourceNodeC->set_data( 3 );
90
91 REQUIRE( g.compile() );
92 REQUIRE( g.execute() );
93
94 auto& result = resultNode->data_reference();
95 REQUIRE( result == -8 );
96}
97
98using PortIndex = Ra::Dataflow::Core::Node::PortIndex;
99using FunctionNode = Functionals::FunctionNode<Scalar>;
102
103TEST_CASE( "Dataflow/Core/GraphAsNode/Forward", "[unittests][Dataflow][Core][Graph]" ) {
104
105 auto port_fatcory = PortFactory::createInstance();
106 port_fatcory->add_port_type<Scalar>();
107
108 auto gAsNode = make_shared<DataflowGraph>( "graphAsNode" );
109 auto f = gAsNode->add_node<FunctionNode>( "f" );
110
111 REQUIRE( !gAsNode->input_node() );
112 REQUIRE( !gAsNode->output_node() );
113
114 gAsNode->add_input_output_nodes();
115
116 REQUIRE( gAsNode->input_node() );
117 REQUIRE( gAsNode->output_node() );
118
119 REQUIRE( !gAsNode->can_link(
120 gAsNode->input_node(), PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
121 REQUIRE( !gAsNode->add_link(
122 gAsNode->input_node(), PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
123
124 // add link to portIndex = input_node().size(), creates port on input_node()
125 REQUIRE( !gAsNode->can_link( gAsNode->input_node(), PortIndex { 10 }, f, PortIndex { 0 } ) );
126 REQUIRE( !gAsNode->can_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 10 } ) );
127 REQUIRE( gAsNode->can_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 0 } ) );
128
129 REQUIRE( !gAsNode->add_link( gAsNode->input_node(), PortIndex { 10 }, f, PortIndex { 0 } ) );
130 REQUIRE( !gAsNode->add_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 10 } ) );
131 REQUIRE( gAsNode->add_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 0 } ) );
132
133 REQUIRE( !gAsNode->can_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
134 REQUIRE( !gAsNode->can_link( f, PortIndex { 1 }, gAsNode->output_node(), PortIndex { 0 } ) );
135 REQUIRE( gAsNode->can_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
136
137 REQUIRE( !gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
138 REQUIRE( !gAsNode->add_link( f, PortIndex { 1 }, gAsNode->output_node(), PortIndex { 0 } ) );
139 REQUIRE( gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
140
141 REQUIRE( gAsNode->input_node() );
142 REQUIRE( gAsNode->output_node() );
143 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
144 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
145 REQUIRE( gAsNode->inputs().size() == 0 );
146 REQUIRE( gAsNode->outputs().size() == 0 );
147
148 gAsNode->generate_ports();
149
150 REQUIRE( gAsNode->inputs().size() == 1 );
151 REQUIRE( gAsNode->outputs().size() == 1 );
152
153 DataflowGraph g { "mainGraph" };
154
155 auto sourceNodeA = g.add_node<Source>( "s" );
156 auto resultNode = g.add_node<Sink>( "r" );
157
158 REQUIRE( g.add_node( gAsNode ) );
159 REQUIRE( g.can_link( sourceNodeA, PortIndex { 0 }, gAsNode, PortIndex { 0 } ) );
160 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( 0 ) ) );
161 REQUIRE( g.add_link( gAsNode->output_by_index( 0 ), resultNode->port_in_from().get() ) );
162
163 sourceNodeA->set_data( 2 );
164
165 SECTION( "Serialization" ) {
166
168 std::string tmpdir { "tmpDir4Tests" };
169 std::filesystem::create_directories( tmpdir );
170 REQUIRE( g.shouldBeSaved() );
171 g.saveToJson( tmpdir + "/graph_as_node_io.json" );
172 REQUIRE( !g.shouldBeSaved() );
173
174 // Create a new graph and load from the saved graph
175 DataflowGraph g1 { "loaded graph" };
176 REQUIRE( g1.loadFromJson( tmpdir + "/graph_as_node_io.json" ) );
177 {
178 auto g1_sourceNodeA = std::dynamic_pointer_cast<Source>( g1.node( "s" ) );
179
180 REQUIRE( g1_sourceNodeA );
181 auto g1_nodeGraph =
182 std::dynamic_pointer_cast<DataflowGraph>( g1.node( "graphAsNode" ) );
183 REQUIRE( g1_nodeGraph->display_name() == "graphAsNode" );
184 REQUIRE( g1_nodeGraph );
185 auto g1_resultNode = std::dynamic_pointer_cast<Sink>( g1.node( "r" ) );
186 REQUIRE( g1_resultNode );
187
188 REQUIRE( g1_sourceNodeA->data() );
189 REQUIRE( *g1_sourceNodeA->data() == 2 );
190 }
191 }
192
193 SECTION( "Remove unlinked" ) {
194
195 REQUIRE( g.compile() );
196 REQUIRE( g.execute() );
197
198 // first setup
199 gAsNode->remove_unlinked_input_output_ports();
200 REQUIRE( gAsNode->input_node() );
201 REQUIRE( gAsNode->output_node() );
202 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
203 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
204 REQUIRE( gAsNode->inputs().size() == 1 );
205 REQUIRE( gAsNode->outputs().size() == 1 );
206
207 // re link in gAsNode
208 REQUIRE( gAsNode->remove_link( f, PortIndex { 0 } ) );
209 REQUIRE( gAsNode->add_link( gAsNode->input_node(), PortIndex { 1 }, f, PortIndex { 0 } ) );
210 REQUIRE( gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
211 REQUIRE( gAsNode->remove_link( gAsNode->output_node(), PortIndex { 0 } ) );
212
213 // input/output node of gAsNode still linked in g, remove unlink do nothing
214 gAsNode->remove_unlinked_input_output_ports();
215
216 REQUIRE( gAsNode->input_node() );
217 REQUIRE( gAsNode->output_node() );
218 REQUIRE( gAsNode->input_node()->outputs().size() == 2 );
219 REQUIRE( gAsNode->output_node()->inputs().size() == 2 );
220 REQUIRE( gAsNode->inputs().size() == 2 );
221 REQUIRE( gAsNode->outputs().size() == 2 );
222
223 // unlinked in g, remove unlink clean up unused port in input/output nodes
224 REQUIRE( g.remove_link( gAsNode->input_node(), PortIndex { 0 } ) );
225 REQUIRE( g.remove_link( resultNode, PortIndex { 0 } ) );
226
227 gAsNode->remove_unlinked_input_output_ports();
228
229 // now cleaned
230 REQUIRE( gAsNode->input_node() );
231 REQUIRE( gAsNode->output_node() );
232 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
233 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
234 REQUIRE( gAsNode->inputs().size() == 1 );
235 REQUIRE( gAsNode->outputs().size() == 1 );
236
237 // relink
238 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( 0 ) ) );
239 REQUIRE( g.add_link( gAsNode->output_by_index( 0 ), resultNode->port_in_from().get() ) );
240
241 gAsNode->remove_unlinked_input_output_ports();
242
243 // no change
244 REQUIRE( gAsNode->input_node() );
245 REQUIRE( gAsNode->output_node() );
246 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
247 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
248 REQUIRE( gAsNode->inputs().size() == 1 );
249 REQUIRE( gAsNode->outputs().size() == 1 );
250 }
251}
Represents a set of connected nodes that define a Directed Acyclic Computational Graph. Ownership of node...
virtual bool add_node(std::shared_ptr< Node > newNode)
Adds a node to the graph.
Apply a binary operation on its input.
Base class for nodes that will store the result of a computation graph.
Definition SinkNode.hpp:17
Base class for nodes that will give access to some input data to the graph. This class can be used to...
void set_data(T data)
Set the data to be delivered by the node.
This namespace contains everything "low level", related to data, data structures, and computation.
Definition Cage.cpp:5
std::shared_ptr< T > make_shared(Args &&... args)
T dynamic_pointer_cast(T... args)