Loading [MathJax]/extensions/TeX/AMSmath.js
Radium Engine  1.5.28
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
graph.cpp
#include <catch2/catch_test_macros.hpp>

#include <Dataflow/Core/DataflowGraph.hpp>
#include <Dataflow/Core/Functionals/FunctionNode.hpp>
#include <Dataflow/Core/Functionals/ReduceNode.hpp>
#include <Dataflow/Core/Functionals/TransformNode.hpp>
#include <Dataflow/Core/Functionals/Types.hpp>
#include <Dataflow/Core/Node.hpp>
#include <Dataflow/Core/NodeFactory.hpp>
#include <Dataflow/Core/PortIn.hpp>
#include <Dataflow/Core/PortOut.hpp>
#include <Dataflow/Core/Sinks/Types.hpp>
#include <Dataflow/Core/Sources/FunctionSource.hpp>
#include <Dataflow/Core/Sources/SingleDataSourceNode.hpp>
#include <Dataflow/Core/Sources/Types.hpp>

#include <nlohmann/json.hpp>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
23
24using namespace Ra::Dataflow::Core;
25
26void inspectGraph( const DataflowGraph& g ) {
27 // Nodes of the graph
28 const auto& nodes = g.nodes();
29 std::cout << "Nodes of the graph " << g.instance_name() << " (" << nodes.size() << ") :\n";
30 for ( const auto& n : nodes ) {
31 std::cout << "\t\"" << n->instance_name() << "\" of type \"" << n->model_name() << "\"\n";
32 // Inspect input, output and interfaces of the node
33 std::cout << "\t\tInput ports :\n";
34 for ( const auto& p : n->inputs() ) {
35 std::cout << "\t\t\t\"" << p->name() << "\" with type " << p->port_typename();
36 if ( p->is_linked() ) {
37 std::cout << " linked to " << p->link()->node()->display_name() << " "
38 << p->link()->name();
39 }
40 std::cout << "\n";
41 }
42 std::cout << "\t\tOutput ports :\n";
43 for ( const auto& p : n->outputs() ) {
44 std::cout << "\t\t\t\"" << p->name() << "\" with type " << p->port_typename();
45 std::cout << "\n";
46 }
47 }
48
49 // Nodes by level after the compilation
50 if ( g.is_compiled() ) {
51 auto& cn = g.nodes_by_level();
52 std::cout << "Nodes of the graph, sorted by level after compiling the graph :\n";
53 for ( size_t i = 0; i < cn.size(); ++i ) {
54 std::cout << "\tLevel " << i << " :\n";
55 for ( const auto n : cn[i] ) {
56 std::cout << "\t\t\"" << n->instance_name() << "\"\n";
57 }
58 }
59 }
60}
61using PortIndex = Ra::Dataflow::Core::Node::PortIndex;
62
// Checks DataflowGraph JSON loading (loadFromJson / fromJson) on malformed and valid
// descriptions, then exercises the node and link editing API on the loaded graph:
// add_node, node, add_link, can_link, remove_link, link protection, compile and
// clear_nodes.
63TEST_CASE( "Dataflow/Core/Graph/Json", "[unittests][Dataflow][Core][Graph]" ) {
64 DataflowGraph g( "Test Graph" );
// A file that does not contain valid JSON must be rejected.
65 SECTION( "not a json" ) {
66 auto result = g.loadFromJson( "data/Dataflow/NotAJsonFile.json" );
67 REQUIRE( !result );
68 }
// An empty JSON value is accepted.
69 SECTION( "loading empty graph" ) {
70 nlohmann::json emptyJson = {};
71 auto result = g.fromJson( emptyJson );
72 REQUIRE( result );
73 }
// A graph description without an "instance" entry must be rejected.
74 SECTION( "missing instance" ) {
75 nlohmann::json noId = { { "model", { "name", "Core DataflowGraph" } } };
76 auto result = g.fromJson( noId );
77 REQUIRE( !result );
78 }
// A graph description without a "model" entry must be rejected.
79 SECTION( "missing model" ) {
80 nlohmann::json noModel = { { "instance", "No model in this node" } };
81 auto result = g.fromJson( noModel );
82 REQUIRE( !result );
83 }
// A model without any "graph" data is accepted (per the SECTION name: an empty graph).
84 SECTION( "missing instance data -> loads an empty graph" ) {
85 nlohmann::json noGraph = { { "instance", "Missing instance data for model" },
86 { "model", { "name", "Core DataflowGraph" } } };
87 auto result = g.fromJson( noGraph );
88 REQUIRE( result );
89 }
// NOTE(review): this SECTION name is reused by the next two SECTIONs, which test
// different errors (node model without "name", node without "instance"). Distinct
// names would make failure reports unambiguous — TODO rename.
// This one: a node whose model name "NotANode" is presumably not a registered type.
90 SECTION( "trying to instance an unknown node type" ) {
91 nlohmann::json NotANode = {
92 { "instance", "graph with unknown node" },
93 { "model",
94 { { "name", "Core DataflowGraph" },
95 { "graph",
96 { { "nodes",
97 { { { "instance", "NotANode" },
98 { "model", { { "name", "NotANode" } } } } } } } } } } };
99 auto result = g.fromJson( NotANode );
100 REQUIRE( !result );
101 }
// A node whose "model" object has no "name" entry must be rejected.
102 SECTION( "trying to instance an unknown node type" ) {
103 nlohmann::json NoModelName = {
104 { "instance", "graph with missing node model information" },
105 { "model",
106 { { "name", "Core DataflowGraph" },
107 { "graph",
108 { { "nodes",
109 { { { "instance", "Unknown model" },
110 { "model", { { "extra", "NotaTypeName" } } } } } } } } } } };
111 auto result = g.fromJson( NoModelName );
112 REQUIRE( !result );
113 }
// A node description without an "instance" entry must be rejected.
114 SECTION( "trying to instance an unknown node type" ) {
115 nlohmann::json noInstanceIdentification = {
116 { "instance", "graph with missing node model information" },
117 { "model",
118 { { "name", "Core DataflowGraph" },
119 { "graph",
120 { { "nodes", { { { "model", { { "name", "Source<Scalar>" } } } } } } } } } } };
121 auto result = g.fromJson( noInstanceIdentification );
122 REQUIRE( !result );
123 }
// NOTE(review): besides the connection from the unknown node "wrongId", this
// description declares two nodes with the same instance name "Source" (hence the
// variable name). The test does not show which of the two errors makes fromJson fail.
124 SECTION( "errors in the connection description" ) {
125 nlohmann::json reusingNodeIdentification = {
126 { "instance", "graph with wrong connection" },
127 { "model",
128 { { "name", "Core DataflowGraph" },
129 { "graph",
130 { { "nodes",
131 { { { "instance", "Source" }, { "model", { { "name", "Source<Scalar>" } } } },
132 { { "instance", "Source" }, { "model", { { "name", "Sink<int>" } } } } } },
133 { "connections", { { { "out_node", "wrongId" } } } } } } } } };
134 auto result = g.fromJson( reusingNodeIdentification );
135 REQUIRE( !result );
136 }
// Nodes without "instance" entries, plus a connection from the unknown node "wrongId".
137 SECTION( "wrong connection 0" ) {
138 nlohmann::json reusingNodeIdentification = {
139 { "instance", "graph with wrong connection" },
140 { "model",
141 { { "name", "Core DataflowGraph" },
142 { "graph",
143 { { "nodes",
144 { { { "model", { { "name", "Source<Scalar>" } } } },
145 { { "model", { { "name", "Sink<int>" } } } } } },
146 { "connections", { { { "out_node", "wrongId" } } } } } } } } };
147 auto result = g.fromJson( reusingNodeIdentification );
148 REQUIRE( !result );
149 }
// Valid nodes, but the connection's out_node "wrongId" names no declared node.
150 SECTION( "wrong connection 1" ) {
151 nlohmann::json wrongConnection = {
152 { "instance", "graph with wrong connection" },
153 { "model",
154 { { "name", "Core DataflowGraph" },
155 { "graph",
156 { { "nodes",
157 { { { "instance", "SourceScalar" },
158 { "model", { { "name", "Source<Scalar>" } } } },
159 { { "instance", "SinkInt" }, { "model", { { "name", "Sink<int>" } } } } } },
160 { "connections", { { { "out_node", "wrongId" } } } } } } } } };
161 auto result = g.fromJson( wrongConnection );
162 REQUIRE( !result );
163 }
// Connection from an existing node through out_index 2, with no destination node.
// Presumably index 2 is also not a valid output of Source<Scalar> — TODO confirm.
164 SECTION( "wrong connection 2" ) {
165 nlohmann::json wrongConnection = {
166 { "instance", "Test Graph Inline" },
167 { "model",
168 { { "name", "Core DataflowGraph" },
169 { "graph",
170 { { "nodes",
171 { { { "instance", "SourceScalar" },
172 { "model", { { "name", "Source<Scalar>" } } } },
173 { { "instance", "SinkInt" }, { "model", { { "name", "Sink<int>" } } } } } },
174 { "connections",
175 { { { "out_node", "SourceScalar" }, { "out_index", 2 } } } } } } } } };
176 auto result = g.fromJson( wrongConnection );
177 REQUIRE( !result );
178 }
// Connection to in_node "Sink", which is not declared (nodes are "SourceScalar" and "SinkInt").
179 SECTION( "wrong connection 3" ) {
180 nlohmann::json wrongConnection = {
181 { "instance", "Test Graph Inline" },
182 { "model",
183 { { "name", "Core DataflowGraph" },
184 { "graph",
185 { { "nodes",
186 { { { "instance", "SourceScalar" },
187 { "model", { { "name", "Source<Scalar>" } } } },
188 { { "instance", "SinkInt" }, { "model", { { "name", "Sink<int>" } } } } } },
189 { "connections",
190 { { { "out_node", "SourceScalar" },
191 { "out_index", 0 },
192 { "in_node", "Sink" },
193 { "in_port", "from" } } } } } } } } };
194 auto result = g.fromJson( wrongConnection );
195 REQUIRE( !result );
196 }
// Connection from a Source<Scalar> to a Sink<int>. Presumably rejected because the
// port types differ — TODO confirm.
197 SECTION( "wrong connection 4" ) {
198 nlohmann::json wrongConnection = {
199 { "instance", "Test Graph Inline" },
200 { "model",
201 { { "name", "Core DataflowGraph" },
202 { "graph",
203 { { "nodes",
204 { { { "instance", "SourceScalar" },
205 { "model", { { "name", "Source<Scalar>" } } } },
206 { { "instance", "SinkInt" }, { "model", { { "name", "Sink<int>" } } } } } },
207 { "connections",
208 { { { "out_node", "SourceScalar" },
209 { "out_index", 0 },
210 { "in_node", "SinkInt" },
211 { "in_port", "from" } } } } } } } } };
212 auto result = g.fromJson( wrongConnection );
213 REQUIRE( !result );
214 }
// Valid graph. The connection addresses the output port by name ("out_port") and the
// input port by position ("in_index"). The loaded graph is then edited through the API.
215 SECTION( "correct graph" ) {
216 nlohmann::json goodSimpleGraph = {
217 { "instance", "Test Graph Inline" },
218 { "model",
219 { { "name", "Core DataflowGraph" },
220 { "graph",
221 { { "nodes",
222 { { { "instance", "SourceScalar" },
223 { "model", { { "name", "Source<Scalar>" } } } },
224 { { "instance", "SinkScalar" },
225 { "model", { { "name", "Sink<Scalar>" } } } } } },
226 { "connections",
227 { { { "out_node", "SourceScalar" },
228 { "out_port", "to" },
229 { "in_node", "SinkScalar" },
230 { "in_index", 0 } } } } } } } } };
231
232 REQUIRE( g.fromJson( goodSimpleGraph ) );
233
234 // trying to add a duplicated node
// NOTE(review): the initializer of duplicatedNodeName (source line 236) is missing from
// this listing — TODO restore it. The second call adds another node named "SourceScalar",
// which already exists, and must fail.
235 auto duplicatedNodeName =
237 REQUIRE( !g.add_node( duplicatedNodeName ) );
238 REQUIRE( !g.add_node<Sources::SingleDataSourceNode<Scalar>>( "SourceScalar" ) );
239
240 // get unknown node
241 auto sinkScalarNode = g.node( "Sink" );
242 REQUIRE( sinkScalarNode == nullptr );
243 // get known node
244 sinkScalarNode = g.node( "SinkScalar" );
245 REQUIRE( sinkScalarNode != nullptr );
246
247 auto sourceScalarNode = g.node( "SourceScalar" );
248 REQUIRE( sourceScalarNode != nullptr );
249
250 auto sourceIntNode = std::make_shared<Sources::IntSource>( "SourceInt" );
251 auto sinkIntNode = std::make_shared<Sinks::IntSink>( "SinkInt" );
252 // neither node is in the graph yet, so unlinking and linking them must fail
253 REQUIRE( !g.remove_link( sinkIntNode, "from" ) );
254
255 REQUIRE( !g.add_link( sourceIntNode, "to", sinkIntNode, "from" ) );
256 REQUIRE( !g.can_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 0 } ) );
257 REQUIRE( !g.add_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 0 } ) );
258
// only the source is added: linking still fails because the sink is not in the graph
259 REQUIRE( g.add_node( sourceIntNode ) );
260 REQUIRE( !g.add_link( sourceIntNode, "to", sinkIntNode, "from" ) );
261 REQUIRE( !g.can_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 0 } ) );
262 REQUIRE( !g.add_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 0 } ) );
263
264 REQUIRE( g.add_node( sinkIntNode ) );
265
266 // unknown output port on the source node (name "out", index 10)
267 // unknown input port on the sink node (index 10, name "in")
268 REQUIRE( !g.add_link( sourceIntNode, "out", sinkIntNode, "from" ) );
269 REQUIRE( !g.add_link( sourceIntNode, PortIndex { 10 }, sinkIntNode, PortIndex { 0 } ) );
270 REQUIRE( !g.add_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 10 } ) );
271 REQUIRE( !g.can_link( sourceIntNode, PortIndex { 10 }, sinkIntNode, PortIndex { 0 } ) );
272 REQUIRE( !g.can_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 10 } ) );
273 REQUIRE( !g.add_link( sourceIntNode, "to", sinkIntNode, "in" ) );
274
275 // link OK
276 REQUIRE( g.add_link( sourceIntNode, "to", sinkIntNode, "from" ) );
277
278 // the sink's "from" input port is already linked
279 REQUIRE( !g.add_link( sourceIntNode, "to", sinkIntNode, "from" ) );
280
281 // type mismatch: int output linked to the Scalar sink's input
282 REQUIRE( !g.add_link( sourceIntNode, "to", sinkScalarNode, "from" ) );
283
284 // protect the graph to prevent link removal
// NOTE(review): the statement enabling the protection (source line 285) is missing
// from this listing — TODO restore it.
286 REQUIRE( g.nodesAndLinksProtection() );
287 // unable to remove links from protected graph ...
288 REQUIRE( !g.remove_link( sinkIntNode, "from" ) );
// NOTE(review): the statement disabling the protection (source line 289) is missing
// from this listing — TODO restore it.
290 REQUIRE( !g.nodesAndLinksProtection() );
291 // remove link OK
292
293 REQUIRE( g.remove_link( sinkIntNode, "from" ) );
294 // "in" is not an input port of the sink; port 0 exists but its link was just removed
295 REQUIRE( !g.remove_link( sinkIntNode, "in" ) );
296 REQUIRE( !g.remove_link( sinkIntNode, PortIndex { 0 } ) );
297 REQUIRE( g.add_link( sourceIntNode, PortIndex { 0 }, sinkIntNode, PortIndex { 0 } ) );
298 REQUIRE( !g.remove_link( sinkIntNode, PortIndex { 10 } ) );
299 REQUIRE( g.remove_link( sinkIntNode, PortIndex { 0 } ) );
300
301 // compile the graph
302 REQUIRE( g.compile() );
303 REQUIRE( g.is_compiled() );
304 inspectGraph( g );
305
306 // clear the graph
307 g.clear_nodes();
308
309 // after clear_nodes(), none of the nodes can be found anymore
310 auto nullNode = g.node( "SourceInt" );
311 REQUIRE( nullNode == nullptr );
312 nullNode = g.node( "SinkInt" );
313 REQUIRE( nullNode == nullptr );
314 // ... including the nodes loaded from JSON
315 nullNode = g.node( "SourceScalar" );
316 REQUIRE( nullNode == nullptr );
317 nullNode = g.node( "SinkScalar" );
318 REQUIRE( nullNode == nullptr );
319 }
320 // destroy everything
321 g.destroy();
322}
323
324TEST_CASE( "Dataflow/Core/Graph/Node failed execution", "[unittests][Dataflow][Core][Graph]" ) {
325 DataflowGraph g( "Test Graph" );
326 auto sourceIntNode = g.add_node<Sources::IntSource>( "SourceInt" );
327 auto sinkIntNode = g.add_node<Sinks::IntSink>( "SinkInt" );
328 class FailFunction : public Functionals::TransformInt
329 {
330 public:
331 explicit FailFunction( const std::string& instanceName ) : FunctionNode( instanceName ) {}
332 bool execute() { return false; }
333 };
334 auto failNode = g.add_node<FailFunction>( "FailNode" );
335
336 REQUIRE( g.add_link( sourceIntNode, "to", failNode, "data" ) );
337 REQUIRE( g.add_link( failNode, "result", sinkIntNode, "from" ) );
338 REQUIRE( g.compile() );
339 REQUIRE( !g.execute() );
340}
341
// After registering extra node types in the default node factory, checks that invalid
// JSON graph files are rejected and that data/Dataflow/ExampleGraph.json loads. Then
// exercises compile(), needs_recompile() and remove_node(), including removal while
// nodes and links are protected.
// NOTE(review): source lines 348-353 and 405 are missing from this listing. The type
// aliases used by the registrations below (ScalarVectorSource, ...) and the statement
// enabling node protection presumably live there — TODO restore them.
342TEST_CASE( "Dataflow/Core/Graph/Inspection of a graph", "[unittests][Dataflow][Core][Graph]" ) {
343 auto coreFactory = NodeFactoriesManager::default_factory();
344
345 using namespace Ra::Dataflow::Core;
346
347 // add some nodes to factory
354 REGISTER_TYPE_TO_FACTORY( coreFactory, ScalarVectorSource, Sources );
355 REGISTER_TYPE_TO_FACTORY( coreFactory, ScalarFilterSource, Sources );
356 REGISTER_TYPE_TO_FACTORY( coreFactory, ScalarFunctionSource, Sources );
357 REGISTER_TYPE_TO_FACTORY( coreFactory, ScalarPredicateSource, Sources );
358 REGISTER_TYPE_TO_FACTORY( coreFactory, ReduceNode, Functionals );
359 REGISTER_TYPE_TO_FACTORY( coreFactory, TransformNode, Functionals );
360
361 std::cout << "Loading graph data/Dataflow/ExampleGraph.json\n";
362
// Files that are not JSON, or that do not describe a loadable graph, must be rejected.
363 REQUIRE( !DataflowGraph::loadGraphFromJsonFile( "data/Dataflow/NotAJsonFile.json" ) );
364 REQUIRE( !DataflowGraph::loadGraphFromJsonFile( "data/Dataflow/InvalidGraph.json" ) );
365 REQUIRE( !DataflowGraph::loadGraphFromJsonFile( "data/Dataflow/UnknownTypeGraph.json" ) );
366 REQUIRE( !DataflowGraph::loadGraphFromJsonFile( "data/Dataflow/Node.json" ) );
367
368 auto g = DataflowGraph::loadGraphFromJsonFile( "data/Dataflow/ExampleGraph.json" );
369 REQUIRE( g );
370 // nodes() and node_count() must report the same number of nodes
371 const auto& nodes = g->nodes();
372 REQUIRE( nodes.size() == g->node_count() );
373
374 REQUIRE( g->compile() );
375 REQUIRE( g->is_compiled() );
376 // Prints the graph content
377 inspectGraph( *g );
// needs_recompile() marks the graph as no longer compiled
378 g->needs_recompile();
379 REQUIRE( !g->is_compiled() );
380
381 // removing the boolean sink from the graph
382 auto n = g->node( "validation value" );
383 auto useCount = n.use_count();
384 REQUIRE( n->instance_name() == "validation value" );
385
// remove_node() must release the graph's reference: the local shared pointer stays
// valid, with exactly one owner fewer
386 REQUIRE( g->remove_node( n ) );
387 REQUIRE( n );
388 REQUIRE( n.use_count() == useCount - 1 );
389
390 REQUIRE( g->compile() );
391
392 // Simplified graph after compilation
393 auto& cn = g->nodes_by_level();
394 // the source "Validator" is no longer in level 0 as it is not reachable from a sink in the
395 // graph (cn[0] assumes the compiled graph has at least one level).
396 auto found = std::find_if( cn[0].begin(), cn[0].end(), []( const auto& nn ) {
397 return nn->instance_name() == "Validator";
398 } );
399 REQUIRE( found == cn[0].end() );
400
401 // removing the source "Validator"
402 n = g->node( "Validator" );
403 REQUIRE( n->instance_name() == "Validator" );
404 // protect the graph to prevent node removal
// (the statement enabling the protection, source line 405, is missing from this listing)
406 REQUIRE( !g->remove_node( n ) );
407 g->setNodesAndLinksProtection( false );
408 REQUIRE( g->remove_node( n ) );
409
410 std::cout << "####### Graph after sink and source removal\n";
411 inspectGraph( *g );
412}
413
415using namespace Ra::Dataflow::Core;
// Builds a test graph computing r = f( a, b ):
//   - two source nodes (source_a, source_b) feed the "a" and "b" inputs of a node
//     named "operator", built from the binary operator f;
//   - the operator's "result" goes to the sink node.
// Returns { graph, "from" input port of node "a", "from" input port of node "b",
// "data" output port of node "r" } (see the return statement). The caller sets the
// operands through the first two ports and reads the result through the third.
// The graph is allocated with `new`; callers call destroy() and then delete it.
// NOTE(review): source lines 417, 420-421, 424, 429 and 434 are missing from this
// listing. They presumably hold the return type, the `f` parameter and the
// declarations of source_a, source_b and sink — TODO restore them.
416template <typename DataType_a, typename DataType_b = DataType_a, typename DataType_r = DataType_a>
418createGraph(
419 const std::string& name,
422 auto g = new DataflowGraph { name };
423
425 g->add_node( source_a );
426 auto a = g->input_node_port( "a", "from" );
427 REQUIRE( a->node() == source_a.get() );
428
430 g->add_node( source_b );
431 auto b = g->input_node_port( "b", "from" );
432 REQUIRE( b->node() == source_b.get() );
433
435 g->add_node( sink );
436 auto r = g->output_node_port( "r", "data" );
437 REQUIRE( r->node() == sink.get() );
438
// The operator is given to the node at construction time.
439 auto op = std::make_shared<TestNode>( "operator", f );
440 // op->setOperator( f );
441 g->add_node( op );
442
// The operator's "b" input is deliberately left unlinked, so compilation must fail.
443 REQUIRE( g->add_link( source_a, "to", op, "a" ) );
444 REQUIRE( g->add_link( op, "result", sink, "from" ) );
445 REQUIRE( !g->compile() );
446 // this will not execute the graph as it does not compile
447 REQUIRE( !g->execute() );
448 REQUIRE( !g->is_compiled() );
449 // add missing link
450 REQUIRE( g->add_link( source_b, "to", op, "b" ) );
451
452 return { g, a, b, r };
453}
454
// Exercises binary operator nodes built by createGraph on Scalar, Vector3 and
// VectorArray operands. This includes mixed VectorArray/Scalar operands and changing
// the operator at runtime. It then builds a transform/reduce graph that computes the
// mean of a random vector and the mean of its doubled values, and checks their ratio.
// NOTE(review): several source lines are missing from this listing (458, 490, 518-519,
// 560, 565-566, 629-631, 667, 672, 675, 677, 685, 703, 722, 787, 799). The TestNode /
// DataType* aliases and the declarations of nodeS, nodeD, nodeT, validator,
// DoubleFunction and ReduceOperator presumably live there — TODO restore them.
455TEST_CASE( "Dataflow/Core/Nodes", "[unittests][Dataflow][Core][Nodes]" ) {
456 SECTION( "Operations on Scalar" ) {
457 using DataType = Scalar;
459 typename TestNode::BinaryOperator add = []( typename TestNode::Arg1_type a,
460 typename TestNode::Arg2_type b ) ->
461 typename TestNode::Res_type { return a + b; };
462
463 auto [g, a, b, r] = createGraph<DataType>( "test scalar binary op", add );
464
465 DataType x { 1_ra };
466
467 a->set_default_value( x );
468 REQUIRE( a->data<DataType>() == x );
469
470 DataType y { 2_ra };
471 b->set_default_value( y );
472 REQUIRE( b->data<DataType>() == y );
473
474 // As graph was modified since last compilation, this will recompile the graph
475 g->execute();
476
477 auto& z = r->data<DataType>();
478 REQUIRE( z == x + y );
479 // reading the result as another type (int) must throw.
480 REQUIRE_THROWS( r->data<int>() );
481
482 std::cout << x << " + " << y << " == " << z << "\n";
483
484 g->destroy();
485 delete g;
486 }
487
488 SECTION( "Operations on Vectors" ) {
489 using DataType = Ra::Core::Vector3;
491 typename TestNode::BinaryOperator add = []( typename TestNode::Arg1_type a,
492 typename TestNode::Arg2_type b ) ->
493 typename TestNode::Res_type { return a + b; };
494
495 auto [g, a, b, r] = createGraph<DataType>( "test Vector3 binary op", add );
496
497 DataType x { 1_ra, 2_ra, 3_ra };
498 a->set_default_value( x );
499 REQUIRE( a->data<DataType>() == x );
500
501 DataType y { 3_ra, 2_ra, 1_ra };
502 b->set_default_value( y );
503 REQUIRE( b->data<DataType>() == y );
504
505 g->execute();
506
507 auto& z = r->data<DataType>();
508 REQUIRE( z == x + y );
509
510 std::cout << "[" << x.transpose() << "] + [" << y.transpose() << "] == [" << z.transpose()
511 << "]\n";
512
513 g->destroy();
514 delete g;
515 }
516
// Element-wise addition of two arrays of 2D vectors.
517 SECTION( "Operations on VectorArrays" ) {
520 typename TestNode::BinaryOperator add = []( typename TestNode::Arg1_type a,
521 typename TestNode::Arg2_type b ) ->
522 typename TestNode::Res_type { return a + b; };
523
// NOTE(review): this graph name duplicates the one used by "Operations on Vectors".
524 auto [g, a, b, r] = createGraph<DataType>( "test Vector3 binary op", add );
525
526 DataType x { { 1_ra, 2_ra }, { 3_ra, 4_ra } };
527 a->set_default_value( x );
528 REQUIRE( a->data<DataType>() == x );
529
530 DataType y { { 5_ra, 6_ra }, { 7_ra, 8_ra } };
531 b->set_default_value( y );
532 REQUIRE( b->data<DataType>() == y );
533
534 g->execute();
535
536 auto& z = r->data<DataType>();
537 for ( size_t i = 0; i < z.size(); i++ ) {
538 REQUIRE( z[i] == x[i] + y[i] );
539 }
540
541 std::cout << "{ ";
542 for ( const auto& t : x ) {
543 std::cout << "[" << t.transpose() << "] ";
544 }
545 std::cout << "} + { ";
546 for ( const auto& t : y ) {
547 std::cout << "[" << t.transpose() << "] ";
548 }
549 std::cout << "} = { ";
550 for ( const auto& t : z ) {
551 std::cout << "[" << t.transpose() << "] ";
552 }
553 std::cout << "}\n";
554
555 g->destroy();
556 delete g;
557 }
558
// Each vector of the array is multiplied (then divided) by a scalar.
559 SECTION( "Operations between VectorArray and Scalar" ) {
561 using DataType_b = Scalar;
562 // How to deduce this type? Eigen generates an error due to aligned allocation
563 // using DataType_r = Ra::Core::VectorArray< decltype( std::declval<Ra::Core::Vector2>() *
564 // std::declval<Scalar>() ) >;
567 typename TestNode::BinaryOperator op = []( typename TestNode::Arg1_type a,
568 typename TestNode::Arg2_type b ) ->
569 typename TestNode::Res_type { return a * b; };
570 auto [g, a, b, r] = createGraph<DataType_a, DataType_b, DataType_r>(
571 "test Vector2 x Scalar binary op", op );
572
573 DataType_a x { { 1_ra, 2_ra }, { 3_ra, 4_ra } };
574 a->set_default_value( x );
575 REQUIRE( a->data<DataType_a>() == x );
576
577 DataType_b y { 5_ra };
578 b->set_default_value( y );
579 REQUIRE( b->data<DataType_b>() == y );
580
581 g->execute();
582
583 auto& z = r->data<DataType_r>();
584 for ( size_t i = 0; i < z.size(); i++ ) {
585 REQUIRE( z[i] == x[i] * y );
586 }
587
588 std::cout << "{ ";
589 for ( const auto& t : x ) {
590 std::cout << "[" << t.transpose() << "] ";
591 }
592 std::cout << "} * " << y << " = { ";
593 for ( const auto& t : z ) {
594 std::cout << "[" << t.transpose() << "] ";
595 }
596 std::cout << "}\n";
597
598 // replace the operator by a division on the existing "operator" node, then re-execute
599 auto opNode = std::dynamic_pointer_cast<TestNode>( g->node( "operator" ) );
600 REQUIRE( opNode != nullptr );
601 if ( opNode ) {
602 typename TestNode::BinaryOperator f = []( typename TestNode::Arg1_type arg1,
603 typename TestNode::Arg2_type arg2 ) ->
604 typename TestNode::Res_type { return arg1 / arg2; };
605 opNode->set_operator( f );
606 }
607 g->execute();
608
// z is still the reference obtained before this second execute(). Presumably the port
// keeps the same storage, so z now holds the new result — TODO confirm.
609 for ( size_t i = 0; i < z.size(); i++ ) {
610 REQUIRE( z[i] == x[i] / y );
611 }
612
613 std::cout << "{ ";
614 for ( const auto& t : x ) {
615 std::cout << "[" << t.transpose() << "] ";
616 }
617 std::cout << "} / " << y << " = { ";
618 for ( const auto& t : z ) {
619 std::cout << "[" << t.transpose() << "] ";
620 }
621 std::cout << "}\n";
622 g->destroy();
623 delete g;
624 }
625
// A scalar multiplies each vector of the array (scalar as left operand).
626 SECTION( "Operations between Scalar and VectorArray" ) {
627 using namespace Ra::Dataflow::Core;
628 using DataType_a = Scalar;
632 typename TestNode::BinaryOperator op = []( typename TestNode::Arg1_type a,
633 typename TestNode::Arg2_type b ) ->
634 typename TestNode::Res_type { return a * b; };
// NOTE(review): this graph name duplicates the one used by the previous SECTION.
635 auto [g, a, b, r] = createGraph<DataType_a, DataType_b, DataType_r>(
636 "test Vector2 x Scalar binary op", op );
637
638 DataType_a x { 4_ra };
639 a->set_default_value( x );
640 REQUIRE( a->data<DataType_a>() == x );
641
642 DataType_b y { { 1_ra, 2_ra }, { 3_ra, 4_ra } };
643 b->set_default_value( y );
644 REQUIRE( b->data<DataType_b>() == y );
645
646 g->execute();
647
648 auto& z = r->data<DataType_r>();
649 for ( size_t i = 0; i < z.size(); i++ ) {
650 REQUIRE( z[i] == x * y[i] );
651 }
652
653 std::cout << x << " * { ";
654 for ( const auto& t : y ) {
655 std::cout << "[" << t.transpose() << "] ";
656 }
657 std::cout << "} = { ";
658 for ( const auto& t : z ) {
659 std::cout << "[" << t.transpose() << "] ";
660 }
661 std::cout << "}\n";
662 g->destroy();
663 delete g;
664 }
665
// Computes the mean of a random vector ("mean" -> sink "r") and the mean of the same
// vector doubled by a transform node ("double mean" -> sink "rd"). A validator node
// applies the predicate 2 * mean == double mean and writes the result to sink "test".
666 SECTION( "Transform/reduce/filter/test" ) {
668 auto g = new DataflowGraph( "Complex graph" );
669 using VectorType = Ra::Core::VectorArray<Scalar>;
670
671 // Source of a vector of Scalar : random vector
673
674 // Source of an operator on scalars : f(x) = 2*x
676 DoubleFunction doubleMe = []( const Scalar& x ) -> Scalar { return 2_ra * x; };
678 nodeD->set_data( doubleMe );
679
680 // Source of a Scalar : mean neutral element 0_ra
681 auto nodeN = std::make_shared<Sources::ScalarSource>( "n" );
682 nodeN->set_data( 0_ra );
683
684 // Source of a reduction operator : compute the mean using Welford online algo
// Running mean: m_k = m_{k-1} + ( x_k - m_{k-1} ) / k, where k is counted by the member n.
686 struct MeanOperator {
687 size_t n { 0 };
688 Scalar operator()( const Scalar& m, const Scalar& x ) {
689 return m + ( ( x - m ) / ( ++n ) );
690 }
691 };
692 auto nodeM = std::make_shared<ReduceOperator>( "m" );
// NOTE(review): in the visible code `m` is never handed to nodeM (no nodeM->set_data( m )).
// The operator reaches node "m" through inputR->set_default_value( m1 ) below.
// TODO confirm, then either pass `m` to nodeM or drop it.
693 ReduceOperator::function_type m = MeanOperator();
694
695 // Reduce node : will compute the mean
696 using MeanCalculator = Functionals::ReduceNode<VectorType>;
697 auto meanCalculator = std::make_shared<MeanCalculator>( "mean" );
698
699 // Sink for the mean
700 auto nodeR = std::make_shared<Sinks::ScalarSink>( "r" );
701
702 // Transform operator, will double the vectors' values
704
705 // Will compute the mean on the doubled vector
706 auto doubleMeanCalculator = std::make_shared<MeanCalculator>( "double mean" );
707
708 // Sink for the double mean
709 auto nodeRD = std::make_shared<Sinks::ScalarSink>( "rd" );
710
711 // Source for a comparison functor , eg f(x, y) -> 2*x == y
712 auto nodePred = std::make_shared<Sources::ScalarBinaryPredicateSource>( "predicate" );
713 Sources::ScalarBinaryPredicateSource::function_type predicate =
714 []( const Scalar& a, const Scalar& b ) -> bool { return 2_ra * a == b; };
715 nodePred->set_data( predicate );
716
717 // Boolean sink for the validation result
718 auto sinkB = std::make_shared<Sinks::BooleanSink>( "test" );
719
720 // Node for comparing the results of the computation graph
721 auto validator =
723
724 REQUIRE( g->add_node( nodeS ) );
725 REQUIRE( g->add_node( nodeD ) );
726 REQUIRE( g->add_node( nodeN ) );
727 REQUIRE( g->add_node( nodeM ) );
728 REQUIRE( g->add_node( nodeR ) );
729 REQUIRE( g->add_node( meanCalculator ) );
730 REQUIRE( g->add_node( doubleMeanCalculator ) );
731 REQUIRE( g->add_node( nodeT ) );
732 REQUIRE( g->add_node( nodeRD ) );
733
// mean = reduce( s, m, n ); double mean = reduce( transform( s, d ), m ); both reductions
// share the same operator source "m"
734 REQUIRE( g->add_link( nodeS, "to", meanCalculator, "data" ) );
735 REQUIRE( g->add_link( nodeM, "to", meanCalculator, "op" ) );
736 REQUIRE( g->add_link( nodeN, "to", meanCalculator, "init" ) );
737 REQUIRE( g->add_link( meanCalculator, "result", nodeR, "from" ) );
738 REQUIRE( g->add_link( nodeS, "to", nodeT, "data" ) );
739 REQUIRE( g->add_link( nodeD, "to", nodeT, "op" ) );
740 REQUIRE( g->add_link( nodeT, "result", doubleMeanCalculator, "data" ) );
741 REQUIRE( g->add_link( doubleMeanCalculator, "result", nodeRD, "from" ) );
742 REQUIRE( g->add_link( nodeM, "to", doubleMeanCalculator, "op" ) );
743
744 REQUIRE( g->add_node( nodePred ) );
745 REQUIRE( g->add_node( sinkB ) );
746 REQUIRE( g->add_node( validator ) );
747 REQUIRE( g->add_link( meanCalculator, "result", validator, "a" ) );
748 REQUIRE( g->add_link( doubleMeanCalculator, "result", validator, "b" ) );
749 REQUIRE( g->add_link( nodePred, "to", validator, "op" ) );
750 REQUIRE( g->add_link( validator, "result", sinkB, "from" ) );
751
752 auto input = g->input_node_port( "s", "from" );
753 auto output = g->output_node_port( "r", "data" );
754 auto outputD = g->output_node_port( "rd", "data" );
755 auto outputB = g->output_node_port( "test", "data" );
756 auto inputR = g->input_node_port( "m", "from" );
757 REQUIRE( inputR );
758
759 // Inspect the graph interface : inputs and outputs port
760
761 REQUIRE( g->compile() );
762
763 // Set input data
764 VectorType test;
765
766 test.reserve( 10 );
767 std::mt19937 gen( 0 );
768 std::uniform_real_distribution<> dis( 0.0, 1.0 );
769 // Fill the vector with random numbers between 0 and 1
// NOTE(review): this fills capacity() elements, which reserve( 10 ) only guarantees
// to be >= 10 (not exactly 10).
770 for ( size_t n = 0; n < test.capacity(); ++n ) {
771 test.push_back( dis( gen ) );
772 }
773 input->set_default_value( test );
774
775 // Give node "m" its reduction operator through its "from" input port (the original note said a fresh functor is not needed here since the source keeps its own copy; see the NOTE on `m` above)
776 ReduceOperator::function_type m1 = MeanOperator();
777 inputR->set_default_value( m1 );
778
779 g->execute();
780
781 auto& result = output->data<Scalar>();
782 auto& resultD = outputD->data<Scalar>();
783 auto& resultB = outputB->data<bool>();
784
785 std::cout << "Computed mean ( ref ): " << result << "\n";
786 std::cout << "Computed mean ( tra ): " << resultD << "\n";
788 std::cout << "Ratio ( expected 2 ): " << resultD / result << " -- validator --> "
789 << resultB << "\n";
790
791 std::cout << '\n';
792
// Exact comparison is intended: multiplying by 2 is exact in binary floating point
// (barring overflow/underflow). So every step of the running mean on the doubled
// values is exactly twice the corresponding step on the original values.
793 REQUIRE( resultD / result == 2_ra );
794 REQUIRE( resultB );
795 // uncomment this if you want to edit the generated graph with GraphEditor
796 // g->saveToJson( "Transform-reduce.json" );
797 g->destroy();
798 delete g;
800 }
801}
T boolalpha(T... args)
This class implements ContainerIntrospectionInterface for AlignedStdVector.
Represents a set of connected nodes that define a Directed Acyclic Computational Graph. Ownership of node...
bool add_link(const std::shared_ptr< Node > &nodeFrom, const std::string &nodeFromOutputName, const std::shared_ptr< Node > &nodeTo, const std::string &nodeToInputName)
Connects two nodes of the graph.
bool loadFromJson(const std::string &jsonFilePath)
Loads nodes and links from a JSON file.
void setNodesAndLinksProtection(bool on)
protect nodes and links from deletion.
virtual bool add_node(std::shared_ptr< Node > newNode)
Adds a node to the graph.
Node::PortBaseInRawPtr input_node_port(const std::string &nodeName, const std::string &portName)
Gets an input port from a node of the graph.
size_t node_count() const
Gets the number of nodes.
const std::vector< std::shared_ptr< Node > > & nodes() const
Get the vector of all the nodes on the graph.
virtual void clear_nodes()
Deletes all nodes from the graph.
bool compile() override
Compile the graph to check its validity and simplify it.
void destroy() override
Delete the node's content.
bool remove_link(std::shared_ptr< Node > node, const std::string &nodeInputName)
Removes the link connected to a node's input port.
std::shared_ptr< Node > node(const std::string &instanceNameNode) const
bool is_compiled() const
Test if the graph is compiled.
Node::PortBaseOutRawPtr output_node_port(const std::string &nodeName, const std::string &portName)
Gets an output port from a node of the graph.
virtual bool remove_node(std::shared_ptr< Node > node)
Removes a node from the graph.
const std::vector< std::vector< Node * > > & nodes_by_level() const
Gets the nodes ordered by level (after compilation)
bool can_link(const std::shared_ptr< Node > &nodeFrom, Node::PortIndex portOutIdx, const std::shared_ptr< Node > &nodeTo, Node::PortIndex portInIdx) const
bool nodesAndLinksProtection() const
Gets the status of the protection of nodes and links against deletion.
static std::shared_ptr< DataflowGraph > loadGraphFromJsonFile(const std::string &filename)
Load a graph from the given file.
bool execute() override
Executes the node.
Apply a binary operation on its input.
Reduce an iterable collection using a given operator.
Transform an iterable collection.
const std::string & instance_name() const
Gets the instance name of the node.
Definition Node.hpp:447
bool fromJson(const nlohmann::json &data)
Deserializes the content of the node.
Definition Node.cpp:15
Node that delivers a std::function<R( Args... )>
Base class for nodes that will give access to some input data to the graph. This class can be used to...
T find_if(T... args)
T make_shared(T... args)
Quaternion add(const Quaternion &q1, const Quaternion &q2)
Returns the sum of two quaternions.
auto default_factory() -> NodeFactorySet::mapped_type
Gets the "default" factory for nodes exported by the Core dataflow library.
T dynamic_pointer_cast(T... args)
T size(T... args)