Loading [MathJax]/extensions/TeX/AMSsymbols.js
Radium Engine  1.5.28
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
customnodes.cpp
1
4#include <catch2/catch_test_macros.hpp>
5
6#include <string>
7
8#include <iostream>
9
10#include <Core/Utils/StdFilesystem.hpp>
11
12#include <Dataflow/Core/DataflowGraph.hpp>
13#include <Dataflow/Core/Functionals/Types.hpp>
14#include <Dataflow/Core/Sinks/Types.hpp>
15#include <Dataflow/Core/Sources/Types.hpp>
16
17using namespace Ra::Dataflow::Core;
18
19namespace Customs {
20using CustomStringSource = Sources::SingleDataSourceNode<std::string>;
21using CustomStringSink = Sinks::SinkNode<std::string>;
22
24
/**
 * \brief Test node that selects, by name, a predicate on values of type T.
 *
 * The predicate is picked from a fixed table using the string on the "name" input
 * port: "true" and "false" keep or reject every value, while "<" and ">" compare the
 * value with the "threshold" input port. The selected predicate is exposed on the
 * "f" output port. execute() also forwards the selected name through m_nameOut;
 * buildgraph() links an output named "name" of node `fs` (presumably this node),
 * but that port's declaration is not visible in this listing.
 */
35template <class T>
36class FilterSelector final : public Node
37{
38 public:
/// Type of the predicates produced by this node.
39 using function_type = std::function<bool( const T& )>;
40
/// Creates a node named \p name whose type name is node_typename().
41 explicit FilterSelector( const std::string& name ) : FilterSelector( name, node_typename() ) {}
42
43 bool execute() override {
44 // since init with default value, always has_data
45 REQUIRE( m_portName->has_data() );
46 m_nameOut->set_data( &m_portName->data() );
47 m_currentFunction = m_functions.at( m_portName->data() );
48 return true;
49 }
50
51 protected:
52 bool fromJsonInternal( const nlohmann::json& data ) override {
53 // no need to have fallback, since json is ok.
54 REQUIRE( data.contains( "operator" ) );
55 m_portName->set_default_value( "true" );
56 REQUIRE( data.contains( "threshold" ) );
57 m_portThreshold->set_default_value( data["threshold"] );
58 return true;
59 }
60
61 void toJsonInternal( nlohmann::json& data ) const override {
62 data["operator"] = m_portName->data();
63 data["threshold"] = m_portThreshold->data();
64 }
65
66 public:
67 static const std::string& node_typename() {
68 static std::string demangledTypeName =
69 std::string { "FilterSelector<" } + Ra::Core::Utils::simplifiedDemangledType<T>() + ">";
70 return demangledTypeName;
71 }
72
73 private:
/// Constructor used by the public one: forwards the instance and type names to Node.
74 FilterSelector( const std::string& instanceName, const std::string& typeName ) :
75 Node( instanceName, typeName ) {}
// Operator table: maps each accepted operator name to its predicate.
// NOTE(review): the head of this declaration (source lines 76-77) is missing from this
// listing; from its uses (m_functions.at(...), m_functions["true"]) it is a string-keyed
// map named m_functions — TODO restore.
// The "<" and ">" predicates capture `this` and read the threshold port each time they
// are called, so they must not be invoked after this node has been destroyed.
78 { "true", []( const T& ) { return true; } },
79 { "false", []( const T& ) { return false; } },
80 { "<", [this]( const T& v ) { return v < this->m_portThreshold->data(); } },
81 { ">", [this]( const T& v ) { return v > this->m_portThreshold->data(); } } };
82
// Predicate currently exposed on the "f" output; starts as the "true" predicate.
// Its initializer reads m_functions, so it must stay declared after that table
// (members are initialized in declaration order).
83 function_type m_currentFunction = m_functions["true"];
84
// Output port "f" is created from &m_currentFunction, presumably so downstream nodes
// see the predicate assigned in execute() — confirm against Node::add_output.
// NOTE(review): "operatour" is a typo for "operator" (private member, safe to rename).
// NOTE(review): source lines 85 and 88-89 are missing from this listing; m_nameOut,
// used in execute(), is presumably declared there — TODO confirm.
86 PortOutPtr<function_type> m_operatourOut {
87 add_output<function_type>( &m_currentFunction, "f" ) };
// Input ports: operator name (default "true") and comparison threshold (default T {}).
90 PortInPtr<std::string> m_portName { add_input<std::string>( "name", "true" ) };
91 PortInPtr<T> m_portThreshold { add_input<T>( "threshold", T {} ) };
92};
94} // namespace Customs
95
// NOTE(review): the declarations completing each `template <typename DataType>` line
// below (source lines 97, 99, 101 and 103) are missing from this listing. Judging from
// buildgraph(), they presumably declare the alias templates FilterCollectionType,
// CollectionInputType and CollectionOutputType, plus one more — TODO restore.
96template <typename DataType>
98template <typename DataType>
100template <typename DataType>
102template <typename DataType>
104
105// Reusable function to create a graph
// Builds a DataflowGraph named `name` and returns it as a raw owning pointer; the
// serialization test releases it with g->destroy() followed by delete.
// Wiring established by the links below:
//   ds."to" -> fl."data"        fl."result" -> rs."from"
//   ss."to" -> fs."name"        ts."to"     -> fs."threshold"
//   fs."f"  -> fl."predicate"   fs."name"   -> nm."from"
// NOTE(review): the lines creating ds, rs, ts, ss, nm, fs and fl (source lines 110,
// 113, 116, 119, 122, 125 and 128) are missing from this listing — TODO restore.
// NOTE(review): if any REQUIRE below fails, g leaks; returning std::unique_ptr would
// avoid that.
106template <typename DataType>
107DataflowGraph* buildgraph( const std::string& name ) {
108 auto g = new DataflowGraph( name );
109
111 REQUIRE( g->add_node( ds ) );
112
114 REQUIRE( g->add_node( rs ) );
115
117 REQUIRE( g->add_node( ts ) );
118
120 REQUIRE( g->add_node( ss ) );
121
123 REQUIRE( g->add_node( nm ) );
124
126 REQUIRE( g->add_node( fs ) );
127
129 REQUIRE( g->add_node( fl ) );
130
// Register the collection node types on the default factory, presumably so that a
// graph saved with saveToJson() can be recreated by loadFromJson() — TODO confirm.
// NOTE(review): buildgraph() runs more than once in this file, so these registrations
// are repeated; the test below shows a second register_node_creator of the same type
// returns false — confirm REGISTER_TYPE_TO_FACTORY tolerates that.
131 auto coreFactory = NodeFactoriesManager::default_factory();
132
133 REGISTER_TYPE_TO_FACTORY( coreFactory, FilterCollectionType<DataType>, Functionals );
134 REGISTER_TYPE_TO_FACTORY( coreFactory, CollectionInputType<DataType>, Functionals );
135 REGISTER_TYPE_TO_FACTORY( coreFactory, CollectionOutputType<DataType>, Functionals );
136
137 REQUIRE( g->add_link( ds, "to", fl, "data" ) );
138 REQUIRE( g->add_link( fl, "result", rs, "from" ) );
139 REQUIRE( g->add_link( ss, "to", fs, "name" ) );
140 REQUIRE( g->add_link( ts, "to", fs, "threshold" ) );
141 REQUIRE( g->add_link( fs, "f", fl, "predicate" ) );
142 REQUIRE( g->add_link( fs, "name", nm, "from" ) );
143 return g;
144}
145
146// test sections
// Exercises FilterSelector inside a graph built by buildgraph<Scalar>(), then checks
// JSON save/load of that graph with and without a custom node factory registered.
147TEST_CASE( "Dataflow/Core/Custom nodes", "[unittests][Dataflow][Core][Custom nodes]" ) {
148 SECTION( "Build graph with custom nodes" ) {
149 // build a graph
150 auto g = buildgraph<Scalar>( "testCustomNodes" );
151
152 // get input and output of the graph
// NOTE(review): the initializers of inputCollection, inputOpName and inputThreshold
// (source lines 154, 157 and 160) are missing from this listing — TODO restore.
153 auto inputCollection =
155 REQUIRE( inputCollection != nullptr );
156 auto inputOpName =
158 REQUIRE( inputOpName != nullptr );
159 auto inputThreshold =
161 REQUIRE( inputThreshold != nullptr );
162
163 auto filteredCollection = g->node( "rs" );
164 REQUIRE( filteredCollection != nullptr );
165 auto generatedOperator = g->node( "nm" );
166 REQUIRE( generatedOperator != nullptr );
167
168 // parameterize the graph
// NOTE(review): source line 169, presumably declaring CollectionType, is missing here.
170 CollectionType testVector;
171 testVector.reserve( 10 );
172 std::mt19937 gen( 0 );
173 std::uniform_real_distribution<Scalar> dis( 0.0_ra, 1.0_ra );
174 // Fill the vector with random numbers between 0 and 1
175 for ( size_t n = 0; n < testVector.capacity(); ++n ) {
176 testVector.push_back( dis( gen ) );
177 }
178
179 inputCollection->set_data( testVector );
180 inputThreshold->set_data( .5_ra );
181 inputOpName->set_data( "true" );
182
183 // execute the graph that filters out nothing
184 REQUIRE( g->execute() );
185
186 // Getters are usable only after successful compilation/execution of the graph
187 // Get results as references (no need to get them again later if the graph does
188 // not change)
189 auto& vres = filteredCollection->input_by_name( "from" ).second->data<CollectionType>();
190 auto& vop = generatedOperator->input_by_name( "from" ).second->data<std::string>();
191
192 REQUIRE( vop == "true" );
193 REQUIRE( vres.size() == testVector.size() );
194
195 // change operator to filter out everything
196 inputOpName->set_data( "false" );
197
198 REQUIRE( g->execute() );
199 REQUIRE( vop == "false" );
200 REQUIRE( vres.size() == 0 );
201
202 // Change operator to keep elements less than threshold
203 inputOpName->set_data( "<" );
204
205 REQUIRE( g->execute() );
206
// Every kept element is below the threshold iff the maximum is.
// NOTE(review): dereferencing max_element of an empty vres is undefined behavior;
// add REQUIRE( !vres.empty() ) before this check.
207 REQUIRE( *( std::max_element( vres.begin(), vres.end() ) ) < *inputThreshold->data() );
208
209 // Change operator to keep elements greater than threshold
210 inputOpName->set_data( ">" );
211 REQUIRE( g->execute() );
// NOTE(review): checking the maximum does not prove that every kept element exceeds
// the threshold; use std::min_element here (and guard against an empty vres).
212 REQUIRE( *( std::max_element( vres.begin(), vres.end() ) ) > *inputThreshold->data() );
// NOTE(review): g is never released in this section (the serialization section uses
// g->destroy(); delete g;), so the graph leaks.
213 }
214 SECTION( "Serialization of a custom graph" ) {
215 // Create and fill the factory for the custom nodes
216 auto customFactory = NodeFactoriesManager::create_factory( "CustomNodesUnitTests" );
217
218 // add node creators to the factory
219
220 REQUIRE( customFactory->register_node_creator<Customs::CustomStringSource>(
221 Customs::CustomStringSource::node_typename() + "_", "Custom" ) );
222 REQUIRE( customFactory->register_node_creator<Customs::CustomStringSink>(
223 Customs::CustomStringSink::node_typename() + "_", "Custom" ) );
// NOTE(review): the continuation lines of the next two statements (source lines 225
// and 228) are missing from this listing — TODO restore.
224 REQUIRE( customFactory->register_node_creator<Customs::FilterSelector<Scalar>>(
226 // The same node can't be registered twice in the same factory
227 REQUIRE( !customFactory->register_node_creator<Customs::FilterSelector<Scalar>>(
229
230 nlohmann::json emptyData;
231 auto customSource = customFactory->create_node(
232 Customs::CustomStringSource::node_typename(), emptyData, nullptr );
233 REQUIRE( customSource );
234
235 // build a graph
236 auto g = buildgraph<Scalar>( "testCustomNodes" );
237
238 std::string tmpdir { "customGraphExport/" };
239 std::filesystem::create_directories( tmpdir );
240
241 // save the graph with factory
242 g->saveToJson( tmpdir + "customGraph.json" );
243
244 g->destroy();
245 delete g;
246 g = new DataflowGraph( "" );
247
248 REQUIRE( g->loadFromJson( tmpdir + "customGraph.json" ) );
249 g->destroy();
250 delete g;
251
// Once the custom factory is unregistered, loading the same file is expected to fail
// (presumably because its custom nodes can no longer be created).
253 auto unregistered = NodeFactoriesManager::unregister_factory( customFactory->name() );
254 REQUIRE( unregistered == true );
255
// NOTE(review): this graph is deleted without destroy(), unlike the ones above, and
// leaks if the REQUIRE fails.
256 g = new DataflowGraph( "" );
257 REQUIRE( !g->loadFromJson( tmpdir + "customGraph.json" ) );
258 delete g;
259
260 std::filesystem::remove_all( tmpdir );
261 }
262}
T capacity(T... args)
[Develop a custom node]
bool execute() override
Executes the node.
bool fromJsonInternal(const nlohmann::json &data) override
Restores the node's internal state from its JSON representation.
void toJsonInternal(nlohmann::json &data) const override
Writes the node's internal state into its JSON representation.
This class implements ContainerIntrospectionInterface for AlignedStdVector.
Represents a set of connected nodes that define a Directed Acyclic Computational Graph. Ownership of node...
Filter on an iterable collection.
Base abstract class for all the nodes added and used by the node system.
Definition Node.hpp:40
PortIndex add_output(PortBaseOutPtr out)
Convenience alias to add_port(outputs(), out)
Definition Node.hpp:488
PortIndex add_input(PortBaseInPtr in)
Convenience alias to add_port(inputs(), in)
Definition Node.hpp:484
Base class for nodes that will store the result of a computation graph.
Definition SinkNode.hpp:17
Base class for nodes that will give access to some input data to the graph. This class can be used to...
T data(T... args)
T make_shared(T... args)
T max_element(T... args)
auto unregister_factory(const NodeFactorySet::key_type &name) -> bool
Unregister the factory from the manager.
auto default_factory() -> NodeFactorySet::mapped_type
Gets the "default" factory for nodes exported by the Core dataflow library.
auto create_factory(const NodeFactorySet::key_type &name) -> NodeFactorySet::mapped_type
Create and register a factory to the manager.
T dynamic_pointer_cast(T... args)
T push_back(T... args)
T reserve(T... args)
T size(T... args)