1#include <catch2/catch_test_macros.hpp>
3#include <Core/Containers/MakeShared.hpp>
4#include <Dataflow/Core/DataflowGraph.hpp>
5#include <Dataflow/Core/Functionals/BinaryOpNode.hpp>
6#include <Dataflow/Core/Functionals/FilterNode.hpp>
7#include <Dataflow/Core/Functionals/FunctionNode.hpp>
8#include <Dataflow/Core/Functionals/ReduceNode.hpp>
9#include <Dataflow/Core/Functionals/TransformNode.hpp>
10#include <Dataflow/Core/Sinks/SinkNode.hpp>
11#include <Dataflow/Core/Sinks/Types.hpp>
12#include <Dataflow/Core/Sources/Types.hpp>
14using namespace Ra::Dataflow::Core;
// Builds delta = b^2 - 4ac as a sub-graph `gAsNode` (node "b2" = b * b,
// node "4ac" = 4 * a * c, node "b2-4ac" = x - y), exposes a, b, c as its
// inputs and "delta" as its output through the sub-graph's input/output
// nodes, then adds the sub-graph as a single node of an outer graph `g` and
// checks the computed value.
// NOTE(review): `gAsNode`, `fourAC`, `b2minus4ac`, `g`, `sourceNodeA/B/C` and
// `resultNode` are used here but their declarations are not part of this
// listing (the numeric line prefixes skip them) — confirm against the full file.
17TEST_CASE(
"Dataflow/Core/GraphAsNode/Delta",
"[unittests][Dataflow][Core][Graph]" ) {
// Register Scalar as a port type with the port factory singleton.
// NOTE(review): `port_fatcory` is a misspelling of `port_factory`.
18 auto port_fatcory = PortFactory::getInstance();
19 port_fatcory->add_port_type<Scalar>();
// "b2": squares its single input (b * b).
25 auto b2 = gAsNode->add_node<Functionals::FunctionNode<Scalar>>(
"b2" );
26 b2->set_function( [](
const Scalar& b ) {
return b * b; } );
// "4ac" node (`fourAC`): 4 * a * c computed from its two inputs.
29 "4ac", [](
const Scalar& a,
const Scalar& c ) {
return 4_ra * a * c; } );
// "b2-4ac" node (`b2minus4ac`): difference of its two inputs.
32 "b2-4ac", [](
const Scalar& x,
const Scalar& y ) {
return x - y; } );
// Function nodes "a", "b", "c". No function is set on them in the visible
// code (presumably they forward their input unchanged — confirm). Their
// data inputs become the sub-graph's inputs below.
34 auto forwardA = gAsNode->add_node<Functionals::FunctionNode<Scalar>>(
"a" );
35 auto forwardB = gAsNode->add_node<Functionals::FunctionNode<Scalar>>(
"b" );
36 auto forwardC = gAsNode->add_node<Functionals::FunctionNode<Scalar>>(
"c" );
// Name the final result port "delta".
38 b2minus4ac->port_out_result()->set_name(
"delta" );
// The input/output nodes do not exist before add_input_output_nodes().
40 REQUIRE( !gAsNode->input_node() );
41 REQUIRE( !gAsNode->output_node() );
// Create the input/output nodes, then expose the a/b/c data inputs and the
// delta result on them. The returned port indices are reused below.
43 gAsNode->add_input_output_nodes();
44 auto inputA = gAsNode->input_node()->add_output_port( forwardA->port_in_data().get() );
45 auto inputB = gAsNode->input_node()->add_output_port( forwardB->port_in_data().get() );
46 auto inputC = gAsNode->input_node()->add_output_port( forwardC->port_in_data().get() );
47 auto output = gAsNode->output_node()->add_input_port( b2minus4ac->port_out_result().get() );
// The ports now exist on the input/output nodes, but gAsNode itself exposes
// no ports until it is compiled (checked after compile() below).
49 REQUIRE( gAsNode->input_node() );
50 REQUIRE( gAsNode->output_node() );
51 REQUIRE( gAsNode->input_node()->outputs().size() == 3 );
52 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
53 REQUIRE( gAsNode->inputs().size() == 0 );
54 REQUIRE( gAsNode->outputs().size() == 0 );
// Give the exposed ports user-facing names.
56 gAsNode->input_node()->input_by_index( inputA )->set_name(
"a" );
57 gAsNode->input_node()->input_by_index( inputB )->set_name(
"b" );
58 gAsNode->input_node()->input_by_index( inputC )->set_name(
"c" );
59 gAsNode->output_node()->output_by_index( output )->set_name(
"delta" );
// Internal wiring: b -> b2, a -> 4ac.a, c -> 4ac.b,
// b2 -> (b2-4ac).a, 4ac -> (b2-4ac).b.
61 REQUIRE( gAsNode->add_link( forwardB->port_out_result(), b2->port_in_data() ) );
62 REQUIRE( gAsNode->add_link( forwardA->port_out_result(), fourAC->port_in_a() ) );
63 REQUIRE( gAsNode->add_link( forwardC->port_out_result(), fourAC->port_in_b() ) );
65 REQUIRE( gAsNode->add_link( b2->port_out_result(), b2minus4ac->port_in_a() ) );
66 REQUIRE( gAsNode->add_link( fourAC->port_out_result(), b2minus4ac->port_in_b() ) );
// After compiling, gAsNode exposes 3 inputs and 1 output as its own ports.
68 REQUIRE( gAsNode->compile() );
70 REQUIRE( gAsNode->inputs().size() == 3 );
71 REQUIRE( gAsNode->outputs().size() == 1 );
// Embed gAsNode in the outer graph. Connect the sources and the sink through
// the port indices returned by add_output_port / add_input_port above.
80 REQUIRE( g.add_node( gAsNode ) );
82 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( inputA ) ) );
83 REQUIRE( g.add_link( sourceNodeB->port_out_to().get(), gAsNode->input_by_index( inputB ) ) );
84 REQUIRE( g.add_link( sourceNodeC->port_out_to().get(), gAsNode->input_by_index( inputC ) ) );
85 REQUIRE( g.add_link( gAsNode->output_by_index( output ), resultNode->port_in_from().get() ) );
// Only a's value is set in the visible lines. The values of b and c are
// presumably set in lines not shown here — confirm.
87 sourceNodeA->set_data( 1 );
91 REQUIRE( g.compile() );
92 REQUIRE( g.execute() );
// The sink holds delta = b^2 - 4ac for the configured inputs; expected -8.
94 auto& result = resultNode->data_reference();
95 REQUIRE( result == -8 );
98using PortIndex = Ra::Dataflow::Core::Node::PortIndex;
99using FunctionNode = Functionals::FunctionNode<Scalar>;
// Builds a sub-graph `gAsNode` with a single inner function node `f` and
// exercises several behaviors:
//   - linking rules between the sub-graph's input node, `f` and its output node;
//   - generate_ports();
//   - use of the sub-graph inside an outer graph `g`;
//   - in two sections: a JSON save/load round trip and
//     remove_unlinked_input_output_ports().
// NOTE(review): `gAsNode`, `g`, `sourceNodeA`, `Sink`, `tmpdir`, `g1` and the
// `g1_*` lookups are not declared in this listing (the numeric line prefixes
// skip those lines), and the SECTION / TEST_CASE closing braces are not
// visible — confirm against the full file.
103TEST_CASE(
"Dataflow/Core/GraphAsNode/Forward",
"[unittests][Dataflow][Core][Graph]" ) {
// Register Scalar as a port type.
// NOTE(review): this case calls PortFactory::createInstance() while the
// Delta case calls getInstance() — confirm which one is intended.
// `port_fatcory` is also a misspelling of `port_factory`.
105 auto port_fatcory = PortFactory::createInstance();
106 port_fatcory->add_port_type<Scalar>();
109 auto f = gAsNode->add_node<FunctionNode>(
"f" );
// The input/output nodes only exist after add_input_output_nodes().
111 REQUIRE( !gAsNode->input_node() );
112 REQUIRE( !gAsNode->output_node() );
114 gAsNode->add_input_output_nodes();
116 REQUIRE( gAsNode->input_node() );
117 REQUIRE( gAsNode->output_node() );
// A direct link from the input node to the output node is rejected.
119 REQUIRE( !gAsNode->can_link(
120 gAsNode->input_node(), PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
121 REQUIRE( !gAsNode->add_link(
122 gAsNode->input_node(), PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
// Input node -> f: port index 10 is rejected on either side, while index 0
// is accepted. The checks further down expect the input node to then have
// exactly one output, so linking appears to create that port — confirm.
125 REQUIRE( !gAsNode->can_link( gAsNode->input_node(), PortIndex { 10 }, f, PortIndex { 0 } ) );
126 REQUIRE( !gAsNode->can_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 10 } ) );
127 REQUIRE( gAsNode->can_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 0 } ) );
129 REQUIRE( !gAsNode->add_link( gAsNode->input_node(), PortIndex { 10 }, f, PortIndex { 0 } ) );
130 REQUIRE( !gAsNode->add_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 10 } ) );
131 REQUIRE( gAsNode->add_link( gAsNode->input_node(), PortIndex { 0 }, f, PortIndex { 0 } ) );
// f -> output node: only output 0 to input 0 is accepted; index 1 on
// either side is rejected here.
133 REQUIRE( !gAsNode->can_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
134 REQUIRE( !gAsNode->can_link( f, PortIndex { 1 }, gAsNode->output_node(), PortIndex { 0 } ) );
135 REQUIRE( gAsNode->can_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
137 REQUIRE( !gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
138 REQUIRE( !gAsNode->add_link( f, PortIndex { 1 }, gAsNode->output_node(), PortIndex { 0 } ) );
139 REQUIRE( gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 0 } ) );
// One port exists on each of the input/output nodes, but gAsNode itself
// exposes none yet.
141 REQUIRE( gAsNode->input_node() );
142 REQUIRE( gAsNode->output_node() );
143 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
144 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
145 REQUIRE( gAsNode->inputs().size() == 0 );
146 REQUIRE( gAsNode->outputs().size() == 0 );
// generate_ports() makes gAsNode expose one input and one output
// (here without calling compile()).
148 gAsNode->generate_ports();
150 REQUIRE( gAsNode->inputs().size() == 1 );
151 REQUIRE( gAsNode->outputs().size() == 1 );
// Outer graph: sourceNodeA -> gAsNode -> resultNode ("r").
156 auto resultNode = g.add_node<
Sink>(
"r" );
158 REQUIRE( g.add_node( gAsNode ) );
159 REQUIRE( g.can_link( sourceNodeA, PortIndex { 0 }, gAsNode, PortIndex { 0 } ) );
160 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( 0 ) ) );
161 REQUIRE( g.add_link( gAsNode->output_by_index( 0 ), resultNode->port_in_from().get() ) );
163 sourceNodeA->set_data( 2 );
// Save `g` to JSON (which clears shouldBeSaved()), reload it into `g1`, and
// check that the nodes and the source's data (2) survive the round trip.
// NOTE(review): <filesystem> is not among the visible includes — confirm it
// is included, possibly transitively.
165 SECTION(
"Serialization" ) {
169 std::filesystem::create_directories( tmpdir );
170 REQUIRE( g.shouldBeSaved() );
171 g.saveToJson( tmpdir +
"/graph_as_node_io.json" );
172 REQUIRE( !g.shouldBeSaved() );
176 REQUIRE( g1.loadFromJson( tmpdir +
"/graph_as_node_io.json" ) );
180 REQUIRE( g1_sourceNodeA );
// NOTE(review): g1_nodeGraph is dereferenced on the next line before the
// REQUIRE( g1_nodeGraph ) null check that follows. The null check should
// come first, so that a failed lookup is reported instead of crashing.
183 REQUIRE( g1_nodeGraph->display_name() ==
"graphAsNode" );
184 REQUIRE( g1_nodeGraph );
186 REQUIRE( g1_resultNode );
188 REQUIRE( g1_sourceNodeA->data() );
189 REQUIRE( *g1_sourceNodeA->data() == 2 );
// Exercise remove_unlinked_input_output_ports() as links are moved and
// removed, checking the port counts after each cleanup.
193 SECTION(
"Remove unlinked" ) {
195 REQUIRE( g.compile() );
196 REQUIRE( g.execute() );
// While every exposed port is linked, the cleanup keeps one port per side.
199 gAsNode->remove_unlinked_input_output_ports();
200 REQUIRE( gAsNode->input_node() );
201 REQUIRE( gAsNode->output_node() );
202 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
203 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
204 REQUIRE( gAsNode->inputs().size() == 1 );
205 REQUIRE( gAsNode->outputs().size() == 1 );
// Re-route f through port index 1 on both the input and output nodes, and
// unlink output-node port 0. The test expects 2 ports per side after this
// cleanup, i.e. the old index-0 ports are kept here.
208 REQUIRE( gAsNode->remove_link( f, PortIndex { 0 } ) );
209 REQUIRE( gAsNode->add_link( gAsNode->input_node(), PortIndex { 1 }, f, PortIndex { 0 } ) );
210 REQUIRE( gAsNode->add_link( f, PortIndex { 0 }, gAsNode->output_node(), PortIndex { 1 } ) );
211 REQUIRE( gAsNode->remove_link( gAsNode->output_node(), PortIndex { 0 } ) );
214 gAsNode->remove_unlinked_input_output_ports();
216 REQUIRE( gAsNode->input_node() );
217 REQUIRE( gAsNode->output_node() );
218 REQUIRE( gAsNode->input_node()->outputs().size() == 2 );
219 REQUIRE( gAsNode->output_node()->inputs().size() == 2 );
220 REQUIRE( gAsNode->inputs().size() == 2 );
221 REQUIRE( gAsNode->outputs().size() == 2 );
// After also removing the outer-graph links, the test expects the cleanup
// to leave one port per side.
// NOTE(review): the first call removes a link on `gAsNode->input_node()`
// through the outer graph `g`. Since sourceNodeA -> gAsNode input 0 is
// re-created below, this presumably unlinks that outer edge — confirm.
224 REQUIRE( g.remove_link( gAsNode->input_node(), PortIndex { 0 } ) );
225 REQUIRE( g.remove_link( resultNode, PortIndex { 0 } ) );
227 gAsNode->remove_unlinked_input_output_ports();
230 REQUIRE( gAsNode->input_node() );
231 REQUIRE( gAsNode->output_node() );
232 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
233 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
234 REQUIRE( gAsNode->inputs().size() == 1 );
235 REQUIRE( gAsNode->outputs().size() == 1 );
// Re-link the outer graph to index 0 on both sides. Another cleanup still
// leaves one port per side.
238 REQUIRE( g.add_link( sourceNodeA->port_out_to().get(), gAsNode->input_by_index( 0 ) ) );
239 REQUIRE( g.add_link( gAsNode->output_by_index( 0 ), resultNode->port_in_from().get() ) );
241 gAsNode->remove_unlinked_input_output_ports();
244 REQUIRE( gAsNode->input_node() );
245 REQUIRE( gAsNode->output_node() );
246 REQUIRE( gAsNode->input_node()->outputs().size() == 1 );
247 REQUIRE( gAsNode->output_node()->inputs().size() == 1 );
248 REQUIRE( gAsNode->inputs().size() == 1 );
249 REQUIRE( gAsNode->outputs().size() == 1 );
Represents a set of connected nodes that define a Directed Acyclic Computational Graph. Ownership of node...
virtual bool add_node(std::shared_ptr< Node > newNode)
Adds a node to the graph.
Applies a binary operation to its inputs.
Base class for nodes that will store the result of a computation graph.
Base class for nodes that will give access to some input data to the graph. This class can be used to...
void set_data(T data)
Set the data to be delivered by the node.
This namespace contains everything "low level" related to data, data structures, and computation.
std::shared_ptr< T > make_shared(Args &&... args)
T dynamic_pointer_cast(T... args)