Loading [MathJax]/extensions/TeX/AMSsymbols.js
Radium Engine  1.5.28
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DataflowGraph.hpp
1#pragma once
2#include <Dataflow/RaDataflow.hpp>
3
4#include <Dataflow/Core/GraphNodes.hpp>
5#include <Dataflow/Core/Node.hpp>
6#include <Dataflow/Core/NodeFactory.hpp>
7
8#include <Core/Types.hpp>
9#include <Core/Utils/BijectiveAssociation.hpp>
10#include <Core/Utils/Color.hpp>
11#include <Core/Utils/Singleton.hpp>
12
13#include <functional>
14
15namespace Ra {
16namespace Dataflow {
17namespace Core {
18
23class RA_DATAFLOW_CORE_API DataflowGraph : public Node
24{
25 public:
29 explicit DataflowGraph( const std::string& name );
30 virtual ~DataflowGraph() = default;
31
32 void init() override;
33 bool execute() override;
34 void destroy() override;
35
42 bool loadFromJson( const std::string& jsonFilePath );
43
49 void saveToJson( const std::string& jsonFilePath );
50
58 virtual bool add_node( std::shared_ptr<Node> newNode );
60 template <typename T, typename... U>
61 std::shared_ptr<T> add_node( U&&... u );
62
68 virtual bool remove_node( std::shared_ptr<Node> node );
69
83 bool add_link( const std::shared_ptr<Node>& nodeFrom,
84 const std::string& nodeFromOutputName,
85 const std::shared_ptr<Node>& nodeTo,
86 const std::string& nodeToInputName );
88 bool add_link( const std::shared_ptr<Node>& nodeFrom,
89 Node::PortIndex portOutIdx,
90 const std::shared_ptr<Node>& nodeTo,
91 Node::PortIndex portInIdx );
93 bool add_link( Node::PortBaseOutRawPtr outputPort, Node::PortBaseInRawPtr inputPort );
94
96 template <typename T, typename U>
97 bool add_link( const std::shared_ptr<PortOut<T>>& outputPort,
98 const std::shared_ptr<PortIn<U>>& inputPort );
99
101 bool can_link( const std::shared_ptr<Node>& nodeFrom,
102 Node::PortIndex portOutIdx,
103 const std::shared_ptr<Node>& nodeTo,
104 Node::PortIndex portInIdx ) const;
105
107 bool can_link( const Node* nodeFrom,
108 Node::PortIndex portOutIdx,
109 const Node* nodeTo,
110 Node::PortIndex portInIdx ) const;
111
119 bool remove_link( std::shared_ptr<Node> node, const std::string& nodeInputName );
120
127 bool remove_link( std::shared_ptr<Node> node, const PortIndex& in_port_index );
128
130 const std::vector<std::shared_ptr<Node>>& nodes() const { return m_nodes; }
131
134 std::shared_ptr<Node> node( const std::string& instanceNameNode ) const;
135
142 template <typename T>
143 std::shared_ptr<T> node( const std::string& instanceNameNode ) const {
144 return std::dynamic_pointer_cast<T>( node( instanceNameNode ) );
145 }
146
148 const std::vector<std::vector<Node*>>& nodes_by_level() const { return m_nodes_by_level; }
149
158 bool compile() override;
159
165 void generate_ports();
166
168 size_t node_count() const { return m_nodes.size(); }
169
171 virtual void clear_nodes();
172
174 bool is_compiled() const { return m_ready; }
175
178 inline void needs_recompile();
179
190 Node::PortBaseInRawPtr input_node_port( const std::string& nodeName,
191 const std::string& portName );
203 Node::PortBaseOutRawPtr output_node_port( const std::string& nodeName,
204 const std::string& portName );
205
206 bool shouldBeSaved() { return m_should_save; }
207
208 static const std::string& node_typename();
209
218 static std::shared_ptr<DataflowGraph> loadGraphFromJsonFile( const std::string& filename );
219
225 void setNodesAndLinksProtection( bool on ) { m_nodesAndLinksProtected = on; }
226
231 bool nodesAndLinksProtection() const { return m_nodesAndLinksProtected; }
232
233 using Node::add_input;
234 using Node::add_output;
235
244 void add_input_output_nodes();
245
253 void remove_unlinked_input_output_ports();
254
255 std::shared_ptr<GraphOutputNode> output_node() { return m_output_node; }
256 std::shared_ptr<GraphInputNode> input_node() { return m_input_node; }
257
258 protected:
262 DataflowGraph( const std::string& instanceName, const std::string& typeName );
263
264 bool fromJsonInternal( const nlohmann::json& data ) override;
265 void toJsonInternal( nlohmann::json& ) const override;
266
273 bool has_node_by_name( const std::string& instance, const std::string& model ) const;
279 bool contains_node_recursive( const Node* node ) const;
280
281 private:
282 // Internal helper functions
286 void
287 backtrack_graph( Node* current,
288 std::unordered_map<Node*, std::pair<int, std::vector<Node*>>>& infoNodes );
293 int traverse_graph( Node* current,
294 std::unordered_map<Node*, std::pair<int, std::vector<Node*>>>& infoNodes );
295
298 bool check_last_port_io_nodes( const Node* nodeFrom,
299 Node::PortIndex portOutIdx,
300 const Node* nodeTo,
301 Node::PortIndex portInIdx ) const {
302 if ( nodeFrom == m_input_node.get() && portOutIdx == m_input_node->outputs().size() )
303 return true;
304 if ( nodeTo == m_output_node.get() && portInIdx == m_output_node->inputs().size() )
305 return true;
306 return false;
307 }
308
309 bool are_nodes_valids( const Node* nodeFrom, const Node* nodeTo, bool verbose = false ) const;
310 static bool are_ports_compatible( const Node* nodeFrom,
311 const PortBaseOut* portOut,
312 const Node* nodeTo,
313 const PortBaseIn* portIn );
314 class RA_DATAFLOW_CORE_API Log
315 {
316 public:
317 static void already_linked( const Node* node, const PortBase* port );
318 static void link_type_mismatch( const Node* nodeFrom,
319 const PortBase* portOut,
320 const Node* nodeTo,
321 const PortBase* portIn );
322 static void unable_to_find( const std::string& type, const std::string& instanceName );
323 static void bad_port_index( const std::string& type,
324 const std::string& instanceName,
325 Node::PortIndex idx );
326 static void try_to_link_input_to_output();
327 };
328
331 bool m_should_save { false };
332
335 bool m_ready { false };
336
339 std::shared_ptr<GraphOutputNode> m_output_node { nullptr };
340 std::shared_ptr<GraphInputNode> m_input_node { nullptr };
341
344 std::vector<std::vector<Node*>> m_nodes_by_level;
345
346 bool m_nodesAndLinksProtected { false };
347};
348
349// -----------------------------------------------------------------
350// ---------------------- inline methods ---------------------------
351
// Presumably the templated add_node<T>( U&&... ).
// Builds a T with std::make_shared from the forwarded arguments and registers it through
// add_node( std::shared_ptr<Node> ).
// Returns the new node, or nullptr when add_node() refuses it (returns false).
// NOTE(review): the definition line (listing line 353) is missing from this extracted copy.
352template <typename T, typename... U>
354 auto ret = std::make_shared<T>( std::forward<U>( u )... );
355 if ( add_node( ret ) ) return ret;
356 return nullptr;
357}
358
// Typed add_link() overload; its first signature line (listing line 360) is missing from this
// extracted copy.
// Rejects, at compile time, port pairs whose value types differ, then forwards the raw port
// pointers to add_link( outputPort, inputPort ).
359template <typename T, typename U>
361 const std::shared_ptr<PortIn<U>>& inputPort ) {
// NOTE(review): no name from Ra::Core::Utils appears to be used in this body.
362 using namespace Ra::Core::Utils;
363
364 static_assert( std::is_same_v<T, U>, "in and out port's types mismatch" );
365
366 return add_link( outputPort.get(), inputPort.get() );
367}
368
// shared_ptr overload of can_link(); its first line (listing line 369) is missing from this
// extracted copy. It forwards the raw node pointers to the raw-pointer can_link().
370 Node::PortIndex portOutIdx,
371 const std::shared_ptr<Node>& nodeTo,
372 Node::PortIndex portInIdx ) const {
373 return can_link( nodeFrom.get(), portOutIdx, nodeTo.get(), portInIdx );
374}
375
376inline bool DataflowGraph::can_link( const Node* nodeFrom,
377 Node::PortIndex portOutIdx,
378 const Node* nodeTo,
379 Node::PortIndex portInIdx ) const {
380 auto portIn = nodeTo->input_by_index( portInIdx );
381 auto portOut = nodeFrom->output_by_index( portOutIdx );
382
383 if ( !are_nodes_valids( nodeFrom, nodeTo ) ) { return false; }
384 if ( check_last_port_io_nodes( nodeFrom, portOutIdx, nodeTo, portInIdx ) ) {
385 if ( nodeFrom == m_input_node.get() ) return portIn != nullptr;
386 if ( nodeTo == m_output_node.get() ) return portOut != nullptr;
387 }
388
389 // Compare types
390 return portIn && portOut && ( portIn->type() == portOut->type() && !portIn->is_linked() );
391}
392
// Presumably the body of needs_recompile(); its definition line (listing line 393) is missing
// from this extracted copy.
// Flags the graph as needing to be saved (shouldBeSaved() returns true) and as no longer
// compiled (is_compiled() returns false).
394 m_should_save = true;
395 m_ready = false;
396}
397
398inline Node::PortBaseInRawPtr DataflowGraph::input_node_port( const std::string& nodeName,
399 const std::string& portName ) {
400 auto n = node( nodeName );
401 auto p = n->input_by_name( portName );
402 CORE_ASSERT( p.first.isValid(), "invalid port, node: " + nodeName + " port: " + portName );
403 return p.second;
404}
405
406inline Node::PortBaseOutRawPtr DataflowGraph::output_node_port( const std::string& nodeName,
407 const std::string& portName ) {
408 auto n = node( nodeName );
409 auto p = n->output_by_name( portName );
410 CORE_ASSERT( p.first.isValid(), "invalid port, node: " + nodeName + " port: " + portName );
411
412 return p.second;
413}
414
// Presumably the body of add_input_output_nodes(); its definition line (listing line 415) is
// missing from this extracted copy.
// It lazily creates the graph input node ("input") and output node ("output"), attaches both to
// this graph and registers them with add_node().
// NOTE(review): add_node()'s bool result is ignored. On a repeated call, the existing nodes are
// passed to add_node() again — confirm that add_node() rejects or tolerates duplicates.
416 if ( !m_input_node ) { m_input_node = std::make_shared<GraphInputNode>( "input" ); }
417 if ( !m_output_node ) { m_output_node = std::make_shared<GraphOutputNode>( "output" ); }
418 m_input_node->set_graph( this );
419 m_output_node->set_graph( this );
420 add_node( m_input_node );
421 add_node( m_output_node );
422}
423
// Presumably the body of remove_unlinked_input_output_ports(); its definition line (listing
// line 424) is missing from this extracted copy.
// Asks each IO node that exists to drop its unlinked ports.
// NOTE(review): listing line 427 is also missing here; its content is unknown.
425 if ( m_input_node ) { m_input_node->remove_unlinked_ports(); }
426 if ( m_output_node ) { m_output_node->remove_unlinked_ports(); }
428}
429
430inline const std::string& DataflowGraph::node_typename() {
431 static std::string demangledTypeName { "Core DataflowGraph" };
432 return demangledTypeName;
433}
434
435} // namespace Core
436} // namespace Dataflow
437} // namespace Ra
Represents a set of connected nodes that define a Directed Acyclic Computational Graph. Ownership of node...
bool add_link(const std::shared_ptr< Node > &nodeFrom, const std::string &nodeFromOutputName, const std::shared_ptr< Node > &nodeTo, const std::string &nodeToInputName)
Connects two nodes of the graph.
void setNodesAndLinksProtection(bool on)
Protects nodes and links from deletion.
virtual bool add_node(std::shared_ptr< Node > newNode)
Adds a node to the graph.
Node::PortBaseInRawPtr input_node_port(const std::string &nodeName, const std::string &portName)
Gets an input port from a node of the graph.
size_t node_count() const
Gets the number of nodes.
const std::vector< std::shared_ptr< Node > > & nodes() const
Gets the vector of all the nodes of the graph.
std::shared_ptr< Node > node(const std::string &instanceNameNode) const
void generate_ports()
Fills the input and output ports of the graph from its input and output nodes, if they exist.
void remove_unlinked_input_output_ports()
Removes unused (unlinked) input/output ports.
bool is_compiled() const
Tests whether the graph is compiled.
std::shared_ptr< T > node(const std::string &instanceNameNode) const
Node::PortBaseOutRawPtr output_node_port(const std::string &nodeName, const std::string &portName)
Gets an output port from a node of the graph.
const std::vector< std::vector< Node * > > & nodes_by_level() const
Gets the nodes ordered by level (after compilation)
bool can_link(const std::shared_ptr< Node > &nodeFrom, Node::PortIndex portOutIdx, const std::shared_ptr< Node > &nodeTo, Node::PortIndex portInIdx) const
bool nodesAndLinksProtection() const
Gets the protection status (whether nodes and links are protected from deletion).
void add_input_output_nodes()
Creates (if not already created) the input/output nodes of the graph, and fills the graph inputs/outputs.
Base abstract class for all the nodes added and used by the node system.
Definition Node.hpp:40
auto input_by_index(PortIndex index) const
Convenience alias to port_by_index("in", index)
Definition Node.hpp:161
auto output_by_index(PortIndex index) const
Convenience alias to port_by_index("out", index)
Definition Node.hpp:163
Input port accepting data of type T.
Definition PortOut.hpp:17
Forward PortOut classes used by getLink and reflect.
Definition PortOut.hpp:73
T forward(T... args)
T get(T... args)
T make_shared(T... args)
Helper function to manage enums as underlying types in VariableSet.
Definition Cage.cpp:4
T dynamic_pointer_cast(T... args)