From f6bfa9812e8fdc876d78a565b141efeba8ea6b41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beat=20K=C3=BCng?= Date: Tue, 29 Oct 2024 14:50:08 +0100 Subject: [PATCH] msg: add message translation node for ROS --- Tools/astyle/files_to_check_code_style.sh | 1 + Tools/copy_to_ros_ws.sh | 34 + msg/px4_msgs_old/CMakeLists.txt | 76 +++ msg/px4_msgs_old/package.xml | 25 + msg/px4_msgs_old/rename_msg_type.py.in | 39 ++ msg/translation_node/CMakeLists.txt | 82 +++ msg/translation_node/README.md | 6 + msg/translation_node/package.xml | 27 + msg/translation_node/src/graph.h | 293 ++++++++ msg/translation_node/src/main.cpp | 39 ++ msg/translation_node/src/monitor.cpp | 60 ++ msg/translation_node/src/monitor.h | 23 + msg/translation_node/src/pub_sub_graph.cpp | 195 ++++++ msg/translation_node/src/pub_sub_graph.h | 58 ++ msg/translation_node/src/service_graph.cpp | 230 +++++++ msg/translation_node/src/service_graph.h | 76 +++ msg/translation_node/src/template_util.h | 64 ++ msg/translation_node/src/translation_util.h | 386 +++++++++++ msg/translation_node/src/translations.cpp | 5 + msg/translation_node/src/translations.h | 91 +++ msg/translation_node/src/util.h | 51 ++ msg/translation_node/test/graph.cpp | 623 ++++++++++++++++++ msg/translation_node/test/main.cpp | 16 + msg/translation_node/test/pub_sub.cpp | 350 ++++++++++ msg/translation_node/test/services.cpp | 215 ++++++ msg/translation_node/test/srv/TestV0.srv | 4 + msg/translation_node/test/srv/TestV1.srv | 4 + msg/translation_node/test/srv/TestV2.srv | 6 + .../translations/all_translations.h | 11 + .../example_translation_direct_v1.h | 30 + .../example_translation_multi_v2.h | 42 ++ .../example_translation_service_v1.h | 38 ++ 32 files changed, 3200 insertions(+) create mode 100755 Tools/copy_to_ros_ws.sh create mode 100644 msg/px4_msgs_old/CMakeLists.txt create mode 100644 msg/px4_msgs_old/package.xml create mode 100755 msg/px4_msgs_old/rename_msg_type.py.in create mode 100644 msg/translation_node/CMakeLists.txt create mode 100644 msg/translation_node/README.md create mode 100644 msg/translation_node/package.xml create mode 100644 msg/translation_node/src/graph.h create mode 100644 msg/translation_node/src/main.cpp create mode 100644 msg/translation_node/src/monitor.cpp create mode 100644 msg/translation_node/src/monitor.h create mode 100644 msg/translation_node/src/pub_sub_graph.cpp create mode 100644 msg/translation_node/src/pub_sub_graph.h create mode 100644 msg/translation_node/src/service_graph.cpp create mode 100644 msg/translation_node/src/service_graph.h create mode 100644 msg/translation_node/src/template_util.h create mode 100644 msg/translation_node/src/translation_util.h create mode 100644 msg/translation_node/src/translations.cpp create mode 100644 msg/translation_node/src/translations.h create mode 100644 msg/translation_node/src/util.h create mode 100644 msg/translation_node/test/graph.cpp create mode 100644 msg/translation_node/test/main.cpp create mode 100644 msg/translation_node/test/pub_sub.cpp create mode 100644 msg/translation_node/test/services.cpp create mode 100644 msg/translation_node/test/srv/TestV0.srv create mode 100644 msg/translation_node/test/srv/TestV1.srv create mode 100644 msg/translation_node/test/srv/TestV2.srv create mode 100644 msg/translation_node/translations/all_translations.h create mode 100644 msg/translation_node/translations/example_translation_direct_v1.h create mode 100644 msg/translation_node/translations/example_translation_multi_v2.h create mode 100644 msg/translation_node/translations/example_translation_service_v1.h diff --git a/Tools/astyle/files_to_check_code_style.sh b/Tools/astyle/files_to_check_code_style.sh index bec6856b75..266760d7af 100755 --- a/Tools/astyle/files_to_check_code_style.sh +++ b/Tools/astyle/files_to_check_code_style.sh @@ -8,6 +8,7 @@ if [ $# -gt 0 ]; then fi exec find boards msg src platforms test \ + -path msg/translation_node -prune -o \ -path platforms/nuttx/NuttX -prune -o \ -path platforms/qurt/dspal -prune -o \ -path src/drivers/ins/vectornav/libvnc -prune -o \ diff --git a/Tools/copy_to_ros_ws.sh b/Tools/copy_to_ros_ws.sh new file mode 100755 index 0000000000..ef663aaaa9 --- /dev/null +++ b/Tools/copy_to_ros_ws.sh @@ -0,0 +1,34 @@ +#! /bin/bash +# Copy msgs and the message translation node into a ROS workspace directory + +DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) + +PX4_SRC_DIR="$DIR/.." + +WS_DIR="$1" +if [ ! -e "${WS_DIR}" ] +then + echo "Usage: $0 " + exit 1 +fi +WS_DIR="$WS_DIR"/src +if [ ! -e "${WS_DIR}" ] +then + echo "'src' directory not found inside ROS workspace (${WS_DIR})" + exit 1 +fi + +cp -ar "${PX4_SRC_DIR}"/msg/translation_node "${WS_DIR}" +cp -ar "${PX4_SRC_DIR}"/msg/px4_msgs_old "${WS_DIR}" +PX4_MSGS_DIR="${WS_DIR}"/px4_msgs +if [ ! -e "${PX4_MSGS_DIR}" ] +then + git clone https://github.com/PX4/px4_msgs.git "${PX4_MSGS_DIR}" + rm -rf "${PX4_MSGS_DIR}"/msg/*.msg + rm -rf "${PX4_MSGS_DIR}"/msg/versioned/*.msg + rm -rf "${PX4_MSGS_DIR}"/srv/*.srv +fi +cp -ar "${PX4_SRC_DIR}"/msg/*.msg "${PX4_MSGS_DIR}"/msg +mkdir -p "${PX4_MSGS_DIR}"/msg/versioned +cp -ar "${PX4_SRC_DIR}"/msg/versioned/*.msg "${PX4_MSGS_DIR}"/msg/versioned +cp -ar "${PX4_SRC_DIR}"/srv/*.srv "${PX4_MSGS_DIR}"/srv diff --git a/msg/px4_msgs_old/CMakeLists.txt b/msg/px4_msgs_old/CMakeLists.txt new file mode 100644 index 0000000000..add08bc670 --- /dev/null +++ b/msg/px4_msgs_old/CMakeLists.txt @@ -0,0 +1,76 @@ +cmake_minimum_required(VERSION 3.5) + +project(px4_msgs_old) + +list(INSERT CMAKE_MODULE_PATH 0 "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra) +endif() + +find_package(ament_cmake REQUIRED) +find_package(builtin_interfaces REQUIRED) +find_package(rosidl_default_generators REQUIRED) + +# ############################################################################## +# Generate ROS messages, ROS2 interfaces and IDL files # +# ############################################################################## + +# get all msg files +set(MSGS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/msg") +file(GLOB PX4_MSGS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "${MSGS_DIR}/*.msg") + +# get all srv files +set(SRVS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/srv") +file(GLOB PX4_SRVS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "${SRVS_DIR}/*.srv") + + + +# For the versioned topics, replace the namespace (px4_msgs_old -> px4_msgs) and message type name (Vx -> ), +# so that DDS does not reject the subscription/publication due to mismatching type +# rosidl_typesupport_fastrtps_cpp +set(rosidl_typesupport_fastrtps_cpp_BIN ${CMAKE_CURRENT_BINARY_DIR}/rosidl_typesupport_fastrtps_cpp_wrapper.py) +file(TOUCH ${rosidl_typesupport_fastrtps_cpp_BIN}) + +# rosidl_typesupport_fastrtps_c +set(rosidl_typesupport_fastrtps_c_BIN ${CMAKE_CURRENT_BINARY_DIR}/rosidl_typesupport_fastrtps_c_wrapper.py) +file(TOUCH ${rosidl_typesupport_fastrtps_c_BIN}) + +# rosidl_typesupport_introspection_cpp (for cyclonedds) +set(rosidl_typesupport_introspection_cpp_BIN ${CMAKE_CURRENT_BINARY_DIR}/rosidl_typesupport_introspection_cpp_wrapper.py) +file(TOUCH ${rosidl_typesupport_introspection_cpp_BIN}) + +# Generate introspection typesupport for C and C++ and IDL files +if(PX4_MSGS) + rosidl_generate_interfaces(${PROJECT_NAME} + ${PX4_MSGS} + ${PX4_SRVS} + DEPENDENCIES builtin_interfaces + ADD_LINTER_TESTS + ) +endif() + +# rosidl_typesupport_fastrtps_cpp +set(rosidl_typesupport_fastrtps_cpp_orig ${rosidl_typesupport_fastrtps_cpp_DIR}) +string(REPLACE "share/rosidl_typesupport_fastrtps_cpp/cmake" "lib/rosidl_typesupport_fastrtps_cpp/rosidl_typesupport_fastrtps_cpp" + rosidl_typesupport_fastrtps_cpp_orig ${rosidl_typesupport_fastrtps_cpp_DIR}) +set(original_script_path ${rosidl_typesupport_fastrtps_cpp_orig}) +configure_file(rename_msg_type.py.in ${rosidl_typesupport_fastrtps_cpp_BIN} @ONLY) + +# rosidl_typesupport_fastrtps_c +set(rosidl_typesupport_fastrtps_c_orig ${rosidl_typesupport_fastrtps_c_DIR}) +string(REPLACE "share/rosidl_typesupport_fastrtps_c/cmake" "lib/rosidl_typesupport_fastrtps_c/rosidl_typesupport_fastrtps_c" + rosidl_typesupport_fastrtps_c_orig ${rosidl_typesupport_fastrtps_c_DIR}) +set(original_script_path ${rosidl_typesupport_fastrtps_c_orig}) +configure_file(rename_msg_type.py.in ${rosidl_typesupport_fastrtps_c_BIN} @ONLY) + +# rosidl_typesupport_introspection_cpp +set(rosidl_typesupport_introspection_cpp_orig ${rosidl_typesupport_introspection_cpp_DIR}) +string(REPLACE "share/rosidl_typesupport_introspection_cpp/cmake" "lib/rosidl_typesupport_introspection_cpp/rosidl_typesupport_introspection_cpp" + rosidl_typesupport_introspection_cpp_orig ${rosidl_typesupport_introspection_cpp_DIR}) +set(original_script_path ${rosidl_typesupport_introspection_cpp_orig}) +configure_file(rename_msg_type.py.in ${rosidl_typesupport_introspection_cpp_BIN} @ONLY) + +ament_export_dependencies(rosidl_default_runtime) + +ament_package() diff --git a/msg/px4_msgs_old/package.xml b/msg/px4_msgs_old/package.xml new file mode 100644 index 0000000000..42b83c6f9a --- /dev/null +++ b/msg/px4_msgs_old/package.xml @@ -0,0 +1,25 @@ + + + + px4_msgs_old + 2.0.1 + Package with the ROS-equivalent of PX4 uORB msgs (old message definitions) + PX4 + BSD 3-Clause + + ament_cmake + rosidl_default_generators + + builtin_interfaces + ros_environment + + rosidl_default_runtime + + ament_lint_common + + rosidl_interface_packages + + + ament_cmake + + diff --git a/msg/px4_msgs_old/rename_msg_type.py.in b/msg/px4_msgs_old/rename_msg_type.py.in new file mode 100755 index 0000000000..31d8054e95 --- /dev/null +++ b/msg/px4_msgs_old/rename_msg_type.py.in @@ -0,0 +1,39 @@ +#! /bin/python +import sys +import subprocess +import json +import os +import re + +original_script = "@original_script_path@" +args = sys.argv[1:] + +json_file = [arg for arg in args if arg.endswith('.json')][0] + +proc = subprocess.run(['python3', original_script] + args) +proc.check_returncode() + +def replace_namespace_and_type(content: str): + # Replace namespace type + content = content.replace('"px4_msgs_old"', '"px4_msgs"') + content = content.replace('"px4_msgs_old::msg"', '"px4_msgs::msg"') + # Replace versioned type with non-versioned one + content = re.sub(r'("[a-zA-Z0-9]+)V[0-9]+"', '\\1"', content) + # Services + content = content.replace('"px4_msgs_old::srv"', '"px4_msgs::srv"') + content = re.sub(r'("[a-zA-Z0-9]+)V[0-9]+_Request"', '\\1_Request"', content) + content = re.sub(r'("[a-zA-Z0-9]+)V[0-9]+_Response"', '\\1_Response"', content) + return content + +with open(json_file, 'r') as f: + data = json.load(f) + output_dir = data['output_dir'] + + # Iterate files recursively + for root, dirs, files in os.walk(output_dir): + for file in files: + with open(os.path.join(root, file), 'r+') as f: + content = f.read() + f.seek(0) + f.write(replace_namespace_and_type(content)) + f.truncate() diff --git a/msg/translation_node/CMakeLists.txt b/msg/translation_node/CMakeLists.txt new file mode 100644 index 0000000000..4880709ce2 --- /dev/null +++ b/msg/translation_node/CMakeLists.txt @@ -0,0 +1,82 @@ +cmake_minimum_required(VERSION 3.8) +project(translation_node) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic -Wno-unused-parameter -Werror) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +find_package(rclcpp REQUIRED) +find_package(px4_msgs REQUIRED) +find_package(px4_msgs_old REQUIRED) + +if(DEFINED ENV{ROS_DISTRO}) + set(ROS_DISTRO $ENV{ROS_DISTRO}) +else() + set(ROS_DISTRO "rolling") +endif() + + +add_library(${PROJECT_NAME}_lib + src/monitor.cpp + src/pub_sub_graph.cpp + src/service_graph.cpp + src/translations.cpp +) +ament_target_dependencies(${PROJECT_NAME}_lib rclcpp px4_msgs px4_msgs_old) +add_executable(${PROJECT_NAME}_bin + src/main.cpp +) +target_link_libraries(${PROJECT_NAME}_bin ${PROJECT_NAME}_lib) +target_include_directories(${PROJECT_NAME}_bin PUBLIC src) +ament_target_dependencies(${PROJECT_NAME}_bin rclcpp px4_msgs px4_msgs_old) +install(TARGETS + ${PROJECT_NAME}_bin + DESTINATION lib/${PROJECT_NAME}) + +option(DISABLE_SERVICES "Disable services" OFF) +if(${ROS_DISTRO} STREQUAL "humble") + message(WARNING "Disabling services for ROS humble (API is not supported)") + target_compile_definitions(${PROJECT_NAME}_lib PRIVATE DISABLE_SERVICES) + set(DISABLE_SERVICES ON) +endif() + +if(BUILD_TESTING) + find_package(std_msgs REQUIRED) + find_package(ament_lint_auto REQUIRED) + find_package(ament_cmake_gtest REQUIRED) + find_package(rosidl_default_generators REQUIRED) + ament_lint_auto_find_test_dependencies() + + set(SRV_FILES + test/srv/TestV0.srv + test/srv/TestV1.srv + test/srv/TestV2.srv + ) + rosidl_generate_interfaces(${PROJECT_NAME} ${SRV_FILES}) + + # Unit tests + set(TEST_SRC + test/graph.cpp + test/main.cpp + test/pub_sub.cpp + ) + if (NOT DISABLE_SERVICES) + list(APPEND TEST_SRC test/services.cpp) + endif() + ament_add_gtest(${PROJECT_NAME}_unit_tests + ${TEST_SRC} + ) + target_include_directories(${PROJECT_NAME}_unit_tests PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_compile_options(${PROJECT_NAME}_unit_tests PRIVATE -Wno-error=sign-compare) # There is a warning from gtest internal + target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME}_lib) + rosidl_get_typesupport_target(cpp_typesupport_target ${PROJECT_NAME} "rosidl_typesupport_cpp") + target_link_libraries(${PROJECT_NAME}_unit_tests "${cpp_typesupport_target}") + ament_target_dependencies(${PROJECT_NAME}_unit_tests + std_msgs + rclcpp + ) +endif() + +ament_package() diff --git a/msg/translation_node/README.md b/msg/translation_node/README.md new file mode 100644 index 0000000000..69e0843bb5 --- /dev/null +++ b/msg/translation_node/README.md @@ -0,0 +1,6 @@ +# Message Translations + +This package contains a message translation node and a set of old message conversion methods. +It allows to run applications that are compiled with one set of message versions against a PX4 with another set of message versions, without having to change either the application or the PX4 side. + +For details, see https://docs.px4.io/main/en/ros2/px4_ros2_msg_translation_node.html. diff --git a/msg/translation_node/package.xml b/msg/translation_node/package.xml new file mode 100644 index 0000000000..7d6e03511b --- /dev/null +++ b/msg/translation_node/package.xml @@ -0,0 +1,27 @@ + + + + translation_node + 0.0.0 + Message version translation node + PX4 + BSD 3-Clause + + ament_cmake + rosidl_default_generators + + rosidl_interface_packages + + ament_lint_auto + ament_lint_common + ament_cmake_gtest + std_msgs + + rclcpp + px4_msgs + px4_msgs_old + + + ament_cmake + + diff --git a/msg/translation_node/src/graph.h b/msg/translation_node/src/graph.h new file mode 100644 index 0000000000..0feac52fa9 --- /dev/null +++ b/msg/translation_node/src/graph.h @@ -0,0 +1,293 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include "util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// This implements a directed graph with potential cycles used for translation. +// There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and +// translations. Translation nodes are always in between message nodes, and can have N input messages +// and M output messages. + +struct MessageIdentifier { + std::string topic_name; + MessageVersionType version; + + bool operator==(const MessageIdentifier& other) const { + return topic_name == other.topic_name && version == other.version; + } + bool operator!=(const MessageIdentifier& other) const { + return !(*this == other); + } +}; + +template<> +struct std::hash +{ + std::size_t operator()(const MessageIdentifier& s) const noexcept + { + std::size_t h1 = std::hash{}(s.topic_name); + std::size_t h2 = std::hash{}(s.version); + return h1 ^ (h2 << 1); + } +}; + + +using MessageBuffer = std::shared_ptr; + +template +class MessageNode; +template +class Graph; + +template +using MessageNodePtrT = std::shared_ptr>; + +template +class TranslationNode { +public: + using TranslationCB = std::function&, std::vector&)>; + + TranslationNode(std::vector> inputs, + std::vector> outputs, + TranslationCB translation_db) + : _inputs(std::move(inputs)), _outputs(std::move(outputs)), _translation_cb(std::move(translation_db)) { + assert(_inputs.size() <= kMaxNumInputs); + + _input_buffers.resize(_inputs.size()); + for (unsigned i = 0; i < _inputs.size(); ++i) { + _input_buffers[i] = _inputs[i]->buffer(); + } + + _output_buffers.resize(_outputs.size()); + for (unsigned i = 0; i < _outputs.size(); ++i) { + _output_buffers[i] = _outputs[i]->buffer(); + } + } + + void setInputReady(unsigned index) { + _inputs_ready.set(index); + } + + bool translate() { + if (_inputs_ready.count() == _input_buffers.size()) { + _translation_cb(_input_buffers, _output_buffers); + _inputs_ready.reset(); + return true; + } + return false; + } + + const std::vector>& inputs() const { return _inputs; } + const std::vector>& outputs() const { return _outputs; } + +private: + static constexpr int kMaxNumInputs = 32; + + const std::vector> _inputs; + std::vector _input_buffers; ///< Cached buffers from _inputs.buffer() + const std::vector> _outputs; + std::vector _output_buffers; + const TranslationCB _translation_cb; + + std::bitset _inputs_ready; +}; + +template +using TranslationNodePtrT = std::shared_ptr>; + + +template +class MessageNode { +public: + + explicit MessageNode(NodeData node_data, size_t index, MessageBuffer message_buffer) + : _buffer(std::move(message_buffer)), _data(std::move(node_data)), _index(index) {} + + MessageBuffer& buffer() { return _buffer; } + + void addTranslationInput(TranslationNodePtrT node, unsigned input_index) { + _translations.push_back(Translation{std::move(node), input_index}); + } + + NodeData& data() { return _data; } + + void resetNodes() { + _translations.clear(); + } + +private: + struct Translation { + TranslationNodePtrT node; ///< Counterpart to the TranslationNode::_inputs + unsigned input_index; ///< Index into the TranslationNode::_inputs + }; + MessageBuffer _buffer; + std::vector _translations; + + NodeData _data; + + const size_t _index; + + friend class Graph; +}; + +template +class Graph { +public: + using MessageNodePtr = MessageNodePtrT; + + ~Graph() { + // Explicitly reset the nodes array to break up potential cycles and prevent memory leaks + for (auto& [id, node] : _nodes) { + node->resetNodes(); + } + } + + /** + * @brief Add a message node if it does not exist already + */ + bool addNodeIfNotExists(const IdType& id, NodeData node_data, const MessageBuffer& message_buffer) { + if (_nodes.find(id) != _nodes.end()) { + return false; + } + // Node that we cannot remove nodes due to using the index as an array index + const size_t index = _nodes.size(); + _nodes.insert({id, std::make_shared>(std::move(node_data), index, message_buffer)}); + return true; + } + + /** + * @brief Add a translation edge with N inputs and M output nodes. All nodes must already exist. + */ + void addTranslation(const typename TranslationNode::TranslationCB& translation_cb, + const std::vector& inputs, const std::vector& outputs) { + auto init = [this](const std::vector& from, std::vector>& to) { + for (unsigned i=0; i < from.size(); ++i) { + auto node_iter = _nodes.find(from[i]); + assert(node_iter != _nodes.end()); + to[i] = node_iter->second; + } + }; + std::vector> input_nodes(inputs.size()); + init(inputs, input_nodes); + std::vector> output_nodes(outputs.size()); + init(outputs, output_nodes); + + auto translation_node = std::make_shared>(std::move(input_nodes), std::move(output_nodes), translation_cb); + for (unsigned i=0; i < translation_node->inputs().size(); ++i) { + translation_node->inputs()[i]->addTranslationInput(translation_node, i); + } + } + + + /** + * @brief Translate a message node in the graph. + * + * @param node The message node to translate. + * @param on_translated A callback function that is called for translated nodes (with an updated message buffer). + * This will not be called for the provided node. + */ + void translate(const MessageNodePtr& node, + const std::function& on_translated) { + resetNodesVisited(); + + // Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm, + // while using translation nodes as barriers (only continue when all inputs are ready) + + std::queue queue; + _node_visited[node->_index] = true; + queue.push(node); + + while (!queue.empty()) { + MessageNodePtr current = queue.front(); + queue.pop(); + for (auto& translation : current->_translations) { + const bool any_output_visited = + std::any_of(translation.node->outputs().begin(), translation.node->outputs().end(), [&](const MessageNodePtr& next_node) { + return _node_visited[next_node->_index]; + }); + // If any output node has already been visited, skip this translation node (prevents translating + // backwards, from where we came from already) + if (any_output_visited) { + continue; + } + translation.node->setInputReady(translation.input_index); + // Iterate the output nodes only if the translation node is ready + if (translation.node->translate()) { + + for (auto &next_node : translation.node->outputs()) { + if (_node_visited[next_node->_index]) { + continue; + } + _node_visited[next_node->_index] = true; + on_translated(next_node); + queue.push(next_node); + } + } + } + } + } + + std::optional findNode(const IdType& id) const { + auto iter = _nodes.find(id); + if (iter == _nodes.end()) { + return std::nullopt; + } + return iter->second; + } + + void iterateNodes(const std::function& cb) const { + for (const auto& [id, node] : _nodes) { + cb(id, node); + } + } + + /** + * Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm + */ + void iterateBFS(const MessageNodePtr& node, const std::function& cb) { + resetNodesVisited(); + + std::queue queue; + _node_visited[node->_index] = true; + queue.push(node); + cb(node); + + while (!queue.empty()) { + MessageNodePtr current = queue.front(); + queue.pop(); + for (auto& translation : current->_translations) { + for (auto& next_node : translation.node->outputs()) { + if (_node_visited[next_node->_index]) { + continue; + } + _node_visited[next_node->_index] = true; + queue.push(next_node); + + cb(next_node); + } + } + } + } + + +private: + void resetNodesVisited() { + _node_visited.resize(_nodes.size()); + std::fill(_node_visited.begin(), _node_visited.end(), false); + } + + std::unordered_map _nodes; + std::vector _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration +}; diff --git a/msg/translation_node/src/main.cpp b/msg/translation_node/src/main.cpp new file mode 100644 index 0000000000..218583d85d --- /dev/null +++ b/msg/translation_node/src/main.cpp @@ -0,0 +1,39 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include + +#include + +#include "../translations/all_translations.h" +#include "pub_sub_graph.h" +#include "service_graph.h" +#include "monitor.h" + +using namespace std::chrono_literals; + +class RosTranslationNode : public rclcpp::Node +{ +public: + RosTranslationNode() : Node("translation_node") + { + _pub_sub_graph = std::make_unique(*this, RegisteredTranslations::instance().topicTranslations()); + _service_graph = std::make_unique(*this, RegisteredTranslations::instance().serviceTranslations()); + _monitor = std::make_unique(*this, _pub_sub_graph.get(), _service_graph.get()); + } + +private: + std::unique_ptr _pub_sub_graph; + std::unique_ptr _service_graph; + rclcpp::TimerBase::SharedPtr _node_update_timer; + std::unique_ptr _monitor; +}; + +int main(int argc, char * argv[]) +{ + rclcpp::init(argc, argv); + rclcpp::spin(std::make_shared()); + rclcpp::shutdown(); + return 0; +} diff --git a/msg/translation_node/src/monitor.cpp b/msg/translation_node/src/monitor.cpp new file mode 100644 index 0000000000..9746ec3f89 --- /dev/null +++ b/msg/translation_node/src/monitor.cpp @@ -0,0 +1,60 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include "monitor.h" +using namespace std::chrono_literals; + +Monitor::Monitor(rclcpp::Node &node, PubSubGraph* pub_sub_graph, ServiceGraph* service_graph) + : _node(node), _pub_sub_graph(pub_sub_graph), _service_graph(service_graph) { + + // Monitor subscriptions & publishers + // TODO: event-based + _node_update_timer = _node.create_wall_timer(1s, [this]() { + updateNow(); + }); +} + +void Monitor::updateNow() { + + // Topics + if (_pub_sub_graph != nullptr) { + std::vector topic_info; + const auto topics = _node.get_topic_names_and_types(); + for (const auto &[topic_name, topic_types]: topics) { + auto publishers = _node.get_publishers_info_by_topic(topic_name); + auto subscribers = _node.get_subscriptions_info_by_topic(topic_name); + // Filter out self + int num_publishers = 0; + for (const auto &publisher: publishers) { + num_publishers += publisher.node_name() != _node.get_name(); + } + int num_subscribers = 0; + for (const auto &subscriber: subscribers) { + num_subscribers += subscriber.node_name() != _node.get_name(); + } + + if (num_subscribers > 0 || num_publishers > 0) { + topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers}); + } + } + _pub_sub_graph->updateCurrentTopics(topic_info); + } + + // Services +#ifndef DISABLE_SERVICES // ROS Humble does not support the count_services() call + if (_service_graph != nullptr) { + std::vector service_info; + const auto services = _node.get_service_names_and_types(); + for (const auto& [service_name, service_types] : services) { + const int num_services = _node.get_node_graph_interface()->count_services(service_name); + const int num_clients = _node.get_node_graph_interface()->count_clients(service_name); + // We cannot filter out our own node, as we don't have that info. + // We could use `get_service_names_and_types_by_node`, but then we would not get + // services by non-ros nodes (e.g. microxrce dds bridge) + service_info.emplace_back(ServiceGraph::ServiceInfo{service_name, num_services, num_clients}); + } + _service_graph->updateCurrentServices(service_info); + } +#endif +} diff --git a/msg/translation_node/src/monitor.h b/msg/translation_node/src/monitor.h new file mode 100644 index 0000000000..cd335aa83b --- /dev/null +++ b/msg/translation_node/src/monitor.h @@ -0,0 +1,23 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include "pub_sub_graph.h" +#include "service_graph.h" +#include + +class Monitor { +public: + explicit Monitor(rclcpp::Node &node, PubSubGraph* pub_sub_graph, ServiceGraph* service_graph); + + void updateNow(); + +private: + rclcpp::Node &_node; + PubSubGraph* _pub_sub_graph{nullptr}; + ServiceGraph* _service_graph{nullptr}; + rclcpp::TimerBase::SharedPtr _node_update_timer; +}; diff --git a/msg/translation_node/src/pub_sub_graph.cpp b/msg/translation_node/src/pub_sub_graph.cpp new file mode 100644 index 0000000000..f75d20fe96 --- /dev/null +++ b/msg/translation_node/src/pub_sub_graph.cpp @@ -0,0 +1,195 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include "pub_sub_graph.h" +#include "util.h" + +PubSubGraph::PubSubGraph(rclcpp::Node &node, const TopicTranslations &translations) : _node(node) { + + std::unordered_map> known_versions; + + for (const auto& topic : translations.topics()) { + const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), topic.id.topic_name); + _known_topics_warned.insert({full_topic_name, false}); + + const MessageIdentifier id{full_topic_name, topic.id.version}; + NodeDataPubSub node_data{topic.subscription_factory, topic.publication_factory, id, topic.max_serialized_message_size}; + _pub_sub_graph.addNodeIfNotExists(id, std::move(node_data), topic.message_buffer); + known_versions[full_topic_name].insert(id.version); + } + + auto get_full_topic_names = [this](std::vector ids) { + for (auto& id : ids) { + id.topic_name = getFullTopicName(_node.get_effective_namespace(), id.topic_name); + } + return ids; + }; + + for (const auto& translation : translations.translations()) { + const std::vector inputs = get_full_topic_names(translation.inputs); + const std::vector outputs = get_full_topic_names(translation.outputs); + _pub_sub_graph.addTranslation(translation.cb, inputs, outputs); + } + + printTopicInfo(known_versions); + handleLargestTopic(known_versions); +} + +void PubSubGraph::updateCurrentTopics(const std::vector &topics) { + + _pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph::MessageNodePtr& node) { + node->data().has_external_publisher = false; + node->data().has_external_subscriber = false; + node->data().visited = false; + }); + + for (const auto& info : topics) { + const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.topic_name); + auto maybe_node = _pub_sub_graph.findNode({non_versioned_topic_name, version}); + if (!maybe_node) { + auto known_topic_iter = _known_topics_warned.find(non_versioned_topic_name); + if (known_topic_iter != _known_topics_warned.end() && !known_topic_iter->second) { + RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of topic %s", version, non_versioned_topic_name.c_str()); + known_topic_iter->second = true; + } + continue; + } + const auto& node = maybe_node.value(); + + if (info.num_publishers > 0) { + node->data().has_external_publisher = true; + } + if (info.num_subscribers > 0) { + node->data().has_external_subscriber = true; + } + } + + // Iterate connected graph segments + _pub_sub_graph.iterateNodes([this](const MessageIdentifier& type, const Graph::MessageNodePtr& node) { + if (node->data().visited) { + return; + } + node->data().visited = true; + + // Count the number of external subscribers and publishers for each connected graph + int num_publishers = 0; + int num_subscribers = 0; + int num_subscribers_without_publisher = 0; + + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + if (node->data().has_external_publisher) { + ++num_publishers; + } + if (node->data().has_external_subscriber) { + ++num_subscribers; + if (!node->data().has_external_publisher) { + ++num_subscribers_without_publisher; + } + } + }); + + // We need to instantiate publishers and subscribers if: + // - there are multiple publishers and at least 1 subscriber + // - there is 1 publisher and at least 1 subscriber on another node + // Note that in case of splitting or merging topics, this might create more entities than actually needed + const bool require_translation = (num_publishers >= 2 && num_subscribers >= 1) + || (num_publishers == 1 && num_subscribers_without_publisher >= 1); + if (require_translation) { + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + node->data().visited = true; + // Has subscriber(s)? + if (node->data().has_external_subscriber && !node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", node->data().topic_name.c_str(), node->data().version); + node->data().publication = node->data().publication_factory(_node); + } else if (!node->data().has_external_subscriber && node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", node->data().topic_name.c_str(), node->data().version); + node->data().publication.reset(); + } + // Has publisher(s)? + if (node->data().has_external_publisher && !node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", node->data().topic_name.c_str(), node->data().version); + node->data().subscription = node->data().subscription_factory(_node, [this, node_cpy=node]() { + onSubscriptionUpdate(node_cpy); + }); + } else if (!node->data().has_external_publisher && node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", node->data().topic_name.c_str(), node->data().version); + node->data().subscription.reset(); + } + }); + + } else { + // Reset any publishers or subscribers + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + node->data().visited = true; + if (node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "Removing publisher for topic '%s', version: %i", + node->data().topic_name.c_str(), node->data().version); + node->data().publication.reset(); + } + if (node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "Removing subscriber for topic '%s', version: %i", + node->data().topic_name.c_str(), node->data().version); + node->data().subscription.reset(); + } + }); + } + }); +} + +void PubSubGraph::onSubscriptionUpdate(const Graph::MessageNodePtr& node) { + _pub_sub_graph.translate( + node, + [this](const Graph::MessageNodePtr& node) { + if (node->data().publication != nullptr) { + const auto ret = rcl_publish(node->data().publication->get_publisher_handle().get(), + node->buffer().get(), nullptr); + if (ret != RCL_RET_OK) { + RCLCPP_WARN_ONCE(_node.get_logger(), "Failed to publish on topic '%s', version: %i", + node->data().topic_name.c_str(), node->data().version); + } + } + }); + +} + +void PubSubGraph::printTopicInfo(const std::unordered_map>& known_versions) const { + // Print info about known versions + RCLCPP_INFO(_node.get_logger(), "Registered pub/sub topics and versions:"); + for (const auto& [topic_name, version_set] : known_versions) { + if (version_set.empty()) { + continue; + } + const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(), + std::to_string(*version_set.begin()), // start with first element + [](std::string a, auto&& b) { + return std::move(a) + ", " + std::to_string(b); + }); + RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str()); + } +} + + +void PubSubGraph::handleLargestTopic(const std::unordered_map> &known_versions) { + // FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given + // type. The information that is relevant for us is the maximum serialized message size. + // Since different versions can have different sizes, we need to ensure the first publication or subscription + // happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered. + // And the type must continue to be in use (so we cannot delete it) + for (const auto& [topic_name, versions] : known_versions) { + size_t max_serialized_message_size = 0; + const PublicationFactoryCB* publication_factory_for_max = nullptr; + for (auto version : versions) { + const auto& node = _pub_sub_graph.findNode(MessageIdentifier{topic_name, version}); + assert(node); + const auto& node_data = node.value()->data(); + if (node_data.max_serialized_message_size > max_serialized_message_size) { + max_serialized_message_size = node_data.max_serialized_message_size; + publication_factory_for_max = &node_data.publication_factory; + } + } + if (publication_factory_for_max) { + _largest_topic_publications.emplace_back((*publication_factory_for_max)(_node)); + } + } +} diff --git a/msg/translation_node/src/pub_sub_graph.h b/msg/translation_node/src/pub_sub_graph.h new file mode 100644 index 0000000000..3dcc96fa9e --- /dev/null +++ b/msg/translation_node/src/pub_sub_graph.h @@ -0,0 +1,58 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include +#include "translations.h" +#include "translation_util.h" +#include "graph.h" + +class PubSubGraph { +public: + struct TopicInfo { + std::string topic_name; ///< fully qualified topic name (with namespace) + int num_subscribers; ///< does not include this node's subscribers + int num_publishers; ///< does not include this node's publishers + }; + + PubSubGraph(rclcpp::Node& node, const TopicTranslations& translations); + + void updateCurrentTopics(const std::vector& topics); + +private: + struct NodeDataPubSub { + explicit NodeDataPubSub(SubscriptionFactoryCB subscription_factory, PublicationFactoryCB publication_factory, + const MessageIdentifier& id, size_t max_serialized_message_size) + : subscription_factory(std::move(subscription_factory)), publication_factory(std::move(publication_factory)), + topic_name(id.topic_name), version(id.version), max_serialized_message_size(max_serialized_message_size) + { } + + const SubscriptionFactoryCB subscription_factory; + const PublicationFactoryCB publication_factory; + const std::string topic_name; + const MessageVersionType version; + const size_t max_serialized_message_size; + + // Keep track if there's currently a publisher/subscriber + bool has_external_publisher{false}; + bool has_external_subscriber{false}; + + rclcpp::SubscriptionBase::SharedPtr subscription; + rclcpp::PublisherBase::SharedPtr publication; + + bool visited{false}; + }; + + void onSubscriptionUpdate(const Graph::MessageNodePtr& node); + void printTopicInfo(const std::unordered_map>& known_versions) const; + void handleLargestTopic(const std::unordered_map>& known_versions); + + rclcpp::Node& _node; + Graph _pub_sub_graph; + std::unordered_map _known_topics_warned; + + std::vector _largest_topic_publications; +}; diff --git a/msg/translation_node/src/service_graph.cpp b/msg/translation_node/src/service_graph.cpp new file mode 100644 index 0000000000..66a389647c --- /dev/null +++ b/msg/translation_node/src/service_graph.cpp @@ -0,0 +1,230 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include "service_graph.h" + +#include + +using namespace std::chrono_literals; + +ServiceGraph::ServiceGraph(rclcpp::Node &node, const ServiceTranslations& translations) + : _node(node) { + + std::unordered_map> known_versions; + + for (const auto& service : translations.nodes()) { + const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), service.id.topic_name); + _known_services_warned.insert({full_topic_name, false}); + + const MessageIdentifier id{full_topic_name, service.id.version}; + auto node_data = std::make_shared(service, id); + _request_graph.addNodeIfNotExists(id, node_data, service.message_buffer_request); + _response_graph.addNodeIfNotExists(id, node_data, service.message_buffer_response); + known_versions[full_topic_name].insert(id.version); + } + + auto get_full_topic_names = [this](std::vector ids) { + for (auto& id : ids) { + id.topic_name = getFullTopicName(_node.get_effective_namespace(), id.topic_name); + } + return ids; + }; + + for (const auto& translation : translations.requestTranslations()) { + const std::vector inputs = get_full_topic_names(translation.inputs); + const std::vector outputs = get_full_topic_names(translation.outputs); + _request_graph.addTranslation(translation.cb, inputs, outputs); + } + for (const auto& translation : translations.responseTranslations()) { + const std::vector inputs = get_full_topic_names(translation.inputs); + const std::vector outputs = get_full_topic_names(translation.outputs); + _response_graph.addTranslation(translation.cb, inputs, outputs); + } + + printServiceInfo(known_versions); + handleLargestTopic(known_versions); + + _cleanup_timer = _node.create_wall_timer(10s, [this]() { + cleanupStaleRequests(); + }); +} + +void ServiceGraph::updateCurrentServices(const std::vector &services) { + _request_graph.iterateNodes([](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) { + node->data()->has_service = false; + node->data()->has_client = false; + node->data()->visited = false; + }); + + for (const auto& info : services) { + const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.service_name); + auto maybe_node = _request_graph.findNode({non_versioned_topic_name, version}); + if (!maybe_node) { + auto known_topic_iter = _known_services_warned.find(non_versioned_topic_name); + if (known_topic_iter != _known_services_warned.end() && !known_topic_iter->second) { + RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of service %s", version, non_versioned_topic_name.c_str()); + known_topic_iter->second = true; + } + continue; + } + const auto& node = maybe_node.value(); + + if (info.num_services > 0) { + node->data()->has_service = true; + } + if (info.num_clients > 0) { + node->data()->has_client = true; + } + } + + // Iterate connected graph segments + _request_graph.iterateNodes([this](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) { + if (node->data()->visited) { + return; + } + node->data()->visited = true; + + // Check if there's a reachable node with a service + int num_services = 0; + + _request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) { + if (node->data()->has_service && !node->data()->service) { + ++num_services; + } + }); + + // We need to instantiate a service and clients if there's exactly one external service. + if (num_services > 1 ) { + RCLCPP_ERROR_ONCE(_node.get_logger(), "Found %i services for service '%s', skipping this service", + num_services, node->data()->service_name.c_str()); + } else if (num_services == 1) { + _request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) { + node->data()->visited = true; + if (node->data()->has_service && !node->data()->client && !node->data()->service) { + RCLCPP_INFO(_node.get_logger(), "Found service for '%s', version: %i, adding client", node->data()->service_name.c_str(), node->data()->version); + auto tuple = node->data()->client_factory(_node, [this, tmp_node=node](rmw_request_id_t& request) { + onResponse(request, tmp_node); + }); + node->data()->client = std::get<0>(tuple); + node->data()->client_send_cb = std::get<1>(tuple); + + } else if (!node->data()->has_service && !node->data()->service && node->data()->has_client) { + RCLCPP_INFO(_node.get_logger(), "Found client for '%s', version: %i, adding service", node->data()->service_name.c_str(), node->data()->version); + node->data()->service = node->data()->service_factory(_node, [this, tmp_node=node](std::shared_ptr req_id) { + onNewRequest(std::move(req_id), tmp_node); + }); + } + }); + + } else { + // Reset any service or client + _request_graph.iterateBFS(node, [&](const GraphForService::MessageNodePtr& node) { + node->data()->visited = true; + if (node->data()->service) { + RCLCPP_INFO(_node.get_logger(), "Removing service for '%s', version: %i", + node->data()->service_name.c_str(), node->data()->version); + node->data()->service.reset(); + } + if (node->data()->client) { + RCLCPP_INFO(_node.get_logger(), "Removing client for '%s', version: %i", + node->data()->service_name.c_str(), node->data()->version); + node->data()->client.reset(); + } + }); + } + }); +} + +void ServiceGraph::printServiceInfo(const std::unordered_map>& known_versions) const { + // Print info about known versions + RCLCPP_INFO(_node.get_logger(), "Registered services and versions:"); + for (const auto& [topic_name, version_set] : known_versions) { + if (version_set.empty()) { + continue; + } + const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(), + std::to_string(*version_set.begin()), // start with first element + [](std::string a, auto&& b) { + return std::move(a) + ", " + std::to_string(b); + }); + RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str()); + } +} + +void ServiceGraph::handleLargestTopic(const std::unordered_map> &known_versions) { + // See PubSubGraph::handleLargestTopic for an explanation why this is needed + unsigned index = 0; + for (const auto& [topic_name, versions] : known_versions) { + std::array max_serialized_message_size{0, 0}; + std::array publication_factory_for_max{nullptr, nullptr}; + for (auto version : versions) { + const auto& node = _request_graph.findNode(MessageIdentifier{topic_name, version}); + assert(node); + const auto& node_data = node.value()->data(); + for (unsigned i = 0; i < max_serialized_message_size.size(); ++i) { + if (node_data->max_serialized_message_size[i] > max_serialized_message_size[i]) { + max_serialized_message_size[i] = node_data->max_serialized_message_size[i]; + publication_factory_for_max[i] = &node_data->publication_factory[i]; + } + } + } + for (unsigned i = 0; i < max_serialized_message_size.size(); ++i) { + if (publication_factory_for_max[i]) { + const std::string tmp_topic_name = "dummy_topic" + std::to_string(index++); + _largest_topic_publications.emplace_back((*publication_factory_for_max[i])(_node, tmp_topic_name)); + } + } + } +} + +void ServiceGraph::onNewRequest(std::shared_ptr req_id, GraphForService::MessageNodePtr node) { + bool service_called = false; + _request_graph.translate(node, [this, &service_called, &req_id, original_node=node](const GraphForService::MessageNodePtr& node) { + if (node->data()->client && node->data()->client_send_cb && !service_called) { + service_called = true; + const int64_t client_request_id = node->data()->client_send_cb(node->buffer()); + node->data()->ongoing_requests[client_request_id] = Request{req_id, original_node->data(), _node.now()}; + } + }); +} + +void ServiceGraph::onResponse(rmw_request_id_t &req_id, GraphForService::MessageNodePtr node) { + auto iter = node->data()->ongoing_requests.find(req_id.sequence_number); + if (iter == node->data()->ongoing_requests.end()) { + RCLCPP_ERROR(_node.get_logger(), "Got response with unknown request %li", req_id.sequence_number); + return; + } + bool service_called = false; + auto response_node = _response_graph.findNode({node->data()->service_name, node->data()->version}); + assert(response_node); + _response_graph.translate(response_node.value(), [this, &service_called, &iter](const GraphForService::MessageNodePtr &node) { + if (node->data()->service && !service_called && iter->second.original_node_data == node->data()) { + const rcl_ret_t ret = rcl_send_response(node->data()->service->get_service_handle().get(), + iter->second.original_request_id.get(), node->buffer().get()); + if (ret != RCL_RET_OK) { + RCLCPP_ERROR(_node.get_logger(), "Failed to send response: %s", rcl_get_error_string().str); + } + service_called = true; + } + }); + + node->data()->ongoing_requests.erase(iter); +} + +void ServiceGraph::cleanupStaleRequests() { + static const auto kRequestTimeout = 20s; + _request_graph.iterateNodes([this](const MessageIdentifier& type, const GraphForService::MessageNodePtr& node) { + for (auto it = node->data()->ongoing_requests.begin(); it != node->data()->ongoing_requests.end();) { + const auto& request = it->second; + if (_node.now() - request.timestamp_received > kRequestTimeout) { + RCLCPP_INFO(_node.get_logger(), "Request timed out, dropping ongoing request for '%s', version: %i, request id: %li", + node->data()->service_name.c_str(), node->data()->version, request.original_request_id->sequence_number); + it = node->data()->ongoing_requests.erase(it); + } else { + ++it; + } + } + }); +} diff --git a/msg/translation_node/src/service_graph.h b/msg/translation_node/src/service_graph.h new file mode 100644 index 0000000000..f290bc1f35 --- /dev/null +++ b/msg/translation_node/src/service_graph.h @@ -0,0 +1,76 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include +#include "translations.h" +#include "translation_util.h" +#include "graph.h" + +class ServiceGraph { +public: + struct ServiceInfo { + std::string service_name; ///< fully qualified service name (with namespace) + int num_services; ///< This can include a service created by the translation node + int num_clients; ///< This can include a client created by the translation node + }; + + ServiceGraph(rclcpp::Node &node, const ServiceTranslations& translations); + + void updateCurrentServices(const std::vector& services); + +private: + struct NodeDataService; + using GraphForService = Graph>; + + void printServiceInfo(const std::unordered_map> &known_versions) const; + void handleLargestTopic(const std::unordered_map>& known_versions); + + void onNewRequest(std::shared_ptr req_id, GraphForService::MessageNodePtr node); + void onResponse(rmw_request_id_t& req_id, GraphForService::MessageNodePtr node); + void cleanupStaleRequests(); + + struct Request { + std::shared_ptr original_request_id; + std::shared_ptr original_node_data{nullptr}; + rclcpp::Time timestamp_received; + }; + struct NodeDataService { + explicit NodeDataService(const Service& service, const MessageIdentifier& id) + : service_factory(service.service_factory), client_factory(service.client_factory), + service_name(id.topic_name), version(id.version), + publication_factory{service.publication_factory_request, service.publication_factory_response}, + max_serialized_message_size{service.max_serialized_message_size_request, service.max_serialized_message_size_response} + { } + + const ServiceFactoryCB service_factory; + const ClientFactoryCB client_factory; + const std::string service_name; + const MessageVersionType version; + const std::array publication_factory; // Request/Response + const std::array max_serialized_message_size; + + // Keep track if there's currently a client/service + bool has_service{false}; + bool has_client{false}; + + rclcpp::ClientBase::SharedPtr client; + ClientSendCB client_send_cb; + rclcpp::ServiceBase::SharedPtr service; + + std::unordered_map ongoing_requests; ///< Ongoing service calls for this node + + bool visited{false}; + }; + + rclcpp::Node& _node; + GraphForService _request_graph; + GraphForService _response_graph; + std::unordered_map _known_services_warned; + rclcpp::TimerBase::SharedPtr _cleanup_timer; + + std::vector _largest_topic_publications; +}; diff --git a/msg/translation_node/src/template_util.h b/msg/translation_node/src/template_util.h new file mode 100644 index 0000000000..8566f34d5e --- /dev/null +++ b/msg/translation_node/src/template_util.h @@ -0,0 +1,64 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include +#include +#include + +/** + * Helper struct to store template parameter packs + */ +template +struct Pack { +}; + +/** + * Struct for a template parameter pack with access to the individual types + */ +template +struct TypesArray { + template + struct TypeHelper { + using Type = T; + using Next = TypeHelper; + }; + + using Type1 = typename TypeHelper::Type; + using Type2 = typename TypeHelper::Next::Type; + using Type3 = typename TypeHelper::Next::Next::Type; + using Type4 = typename TypeHelper::Next::Next::Next::Type; + using Type5 = typename TypeHelper::Next::Next::Next::Next::Type; + using Type6 = typename TypeHelper::Next::Next::Next::Next::Next::Type; + + using args = Pack; +}; + + +/** + * Helper for call_translation_function() + */ +template +inline void call_translation_function_impl(F f, Pack, Pack, + const std::vector>& messages_in, + std::vector>& messages_out, + std::integer_sequence, std::integer_sequence) +{ + f(*static_cast(messages_in[Is].get())..., *static_cast(messages_out[Os].get())...); +} + +/** + * Call a translation function F which takes the arguments (const ArgsIn&..., ArgsOut&...), + * by passing messages_in and messages_out as arguments. + * Note that sizeof(ArgsIn) == messages_in.length() && sizeof(ArgsOut) == messages_out.length() must hold. + */ +template +inline void call_translation_function(F f, Pack pack_in, Pack pack_out, + const std::vector>& messages_in, + std::vector>& messages_out) { + call_translation_function_impl(f, pack_in, pack_out, messages_in, messages_out, + std::index_sequence_for{}, std::index_sequence_for{}); +} diff --git a/msg/translation_node/src/translation_util.h b/msg/translation_node/src/translation_util.h new file mode 100644 index 0000000000..c5ac4f8153 --- /dev/null +++ b/msg/translation_node/src/translation_util.h @@ -0,0 +1,386 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include "translations.h" +#include "util.h" +#include "template_util.h" + +#include +#include + +class RegisteredTranslations { +public: + + RegisteredTranslations(RegisteredTranslations const&) = delete; + void operator=(RegisteredTranslations const&) = delete; + + + static RegisteredTranslations& instance() { + static RegisteredTranslations instance; + return instance; + } + + /** + * @brief Register a translation class with 1 input and 1 output message. + * + * The translation class has the form: + * + * ``` + * class MyTranslation { + * public: + * using MessageOlder = px4_msgs_old::msg::VehicleAttitudeV2; + * + * using MessageNewer = px4_msgs::msg::VehicleAttitude; + * + * static constexpr const char* kTopic = "fmu/out/vehicle_attitude"; + * + * static void fromOlder(const MessageOlder &msg_older, MessageNewer &msg_newer) { + * // set msg_newer from msg_older + * } + * + * static void toOlder(const MessageNewer &msg_newer, MessageOlder &msg_older) { + * // set msg_older from msg_newer + * } + * }; + * ``` + */ + template + void registerDirectTranslation() { + const std::string topic_name = T::kTopic; + _topic_translations.addTopic(getTopicForMessageType(topic_name)); + _topic_translations.addTopic(getTopicForMessageType(topic_name)); + + // Translation callbacks + auto translation_cb_from_older = [](const std::vector& older_msg, std::vector& newer_msg) { + T::fromOlder(*(const typename T::MessageOlder*)older_msg[0].get(), *(typename T::MessageNewer*)newer_msg[0].get()); + }; + auto translation_cb_to_older = [](const std::vector& newer_msg, std::vector& older_msg) { + T::toOlder(*(const typename T::MessageNewer*)newer_msg[0].get(), *(typename T::MessageOlder*)older_msg[0].get()); + }; + _topic_translations.addTranslation({translation_cb_from_older, + {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}}); + _topic_translations.addTranslation({translation_cb_to_older, + {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}}); + } + + /** + * @brief Register a translation class for a service. + * + * The translation class has the form: + * + * ``` + * class MyServiceTranslation { + * public: + * using MessageOlder = px4_msgs_old::srv::VehicleCommandV0; + * using MessageNewer = px4_msgs::srv::VehicleCommand; + * + * static constexpr const char* kTopic = "fmu/vehicle_command"; + * + * static void fromOlder(const MessageOlder::Request &msg_older, MessageNewer::Request &msg_newer) { + * // set msg_newer from msg_older + * } + * + * static void toOlder(const MessageNewer::Request &msg_newer, MessageOlder::Request &msg_older) { + * // set msg_older from msg_newer + * } + * + * static void fromOlder(const MessageOlder::Response &msg_older, MessageNewer::Response &msg_newer) { + * // set msg_newer from msg_older + * } + * + * static void toOlder(const MessageNewer::Response &msg_newer, MessageOlder::Response &msg_older) { + * // set msg_older from msg_newer + * } + * }; + * ``` + */ + template + void registerServiceDirectTranslation() { + const std::string topic_name = T::kTopic; + _service_translations.addNode(getServiceForMessageType(topic_name)); + _service_translations.addNode(getServiceForMessageType(topic_name)); + // Add translations + { // Request + auto translation_cb_from_older = [](const std::vector &older_msg, + std::vector &newer_msg) { + T::fromOlder(*(const typename T::MessageOlder::Request *) older_msg[0].get(), + *(typename T::MessageNewer::Request *) newer_msg[0].get()); + }; + auto translation_cb_to_older = [](const std::vector &newer_msg, + std::vector &older_msg) { + T::toOlder(*(const typename T::MessageNewer::Request *) newer_msg[0].get(), + *(typename T::MessageOlder::Request *) older_msg[0].get()); + }; + _service_translations.addRequestTranslation({translation_cb_from_older, + {MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}}); + _service_translations.addRequestTranslation({translation_cb_to_older, + {MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}}); + } + { // Response + auto translation_cb_from_older = [](const std::vector &older_msg, + std::vector &newer_msg) { + T::fromOlder(*(const typename T::MessageOlder::Response *) older_msg[0].get(), + *(typename T::MessageNewer::Response *) newer_msg[0].get()); + }; + auto translation_cb_to_older = [](const std::vector &newer_msg, + std::vector &older_msg) { + T::toOlder(*(const typename T::MessageNewer::Response *) newer_msg[0].get(), + *(typename T::MessageOlder::Response *) older_msg[0].get()); + }; + _service_translations.addResponseTranslation({translation_cb_from_older, + {MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}}); + _service_translations.addResponseTranslation({translation_cb_to_older, + {MessageIdentifier{topic_name, T::MessageNewer::Request::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageOlder::Request::MESSAGE_VERSION}}}); + } + + } + + /** + * @brief Register a translation class with N input and M output messages. + * + * The translation class has the form: + * ``` + * class MyTranslation { + * public: + * using MessagesOlder = TypesArray; + * static constexpr const char* kTopicsOlder[] = { + * "fmu/out/vehicle_global_position", + * "fmu/out/vehicle_local_position", + * ... + * }; + * + * using MessagesNewer = TypesArray; + * static constexpr const char* kTopicsNewer[] = { + * "fmu/out/vehicle_global_position", + * "fmu/out/vehicle_local_position", + * ... + * }; + * + * static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, ... + * MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2, ...) { + * // Set msg_newerX from msg_olderX + * } + * + * static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, ... + * MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2, ...) { + * // Set msg_olderX from msg_newerX + * } + * }; + * ``` + */ + template + void registerTranslation() { + const auto topics_older = getTopicsForMessageType(typename T::MessagesOlder::args(), T::kTopicsOlder); + std::vector topics_older_identifiers; + for (const auto& topic : topics_older) { + _topic_translations.addTopic(topic); + topics_older_identifiers.emplace_back(topic.id); + } + const auto topics_newer = getTopicsForMessageType(typename T::MessagesNewer::args(),T::kTopicsNewer); + std::vector topics_newer_identifiers; + for (const auto& topic : topics_newer) { + _topic_translations.addTopic(topic); + topics_newer_identifiers.emplace_back(topic.id); + } + + // Translation callbacks + const auto translation_cb_from_older = [](const std::vector& older_msgs, std::vector& newer_msgs) { + call_translation_function(&T::fromOlder, typename T::MessagesOlder::args(), typename T::MessagesNewer::args(), older_msgs, newer_msgs); + }; + const auto translation_cb_to_older = [](const std::vector& newer_msgs, std::vector& older_msgs) { + call_translation_function(&T::toOlder, typename T::MessagesNewer::args(), typename T::MessagesOlder::args(), newer_msgs, older_msgs); + }; + { + // Older -> Newer + Translation translation; + translation.cb = translation_cb_from_older; + translation.inputs = topics_older_identifiers; + translation.outputs = topics_newer_identifiers; + _topic_translations.addTranslation(std::move(translation)); + } + { + // Newer -> Older + Translation translation; + translation.cb = translation_cb_to_older; + translation.inputs = topics_newer_identifiers; + translation.outputs = topics_older_identifiers; + _topic_translations.addTranslation(std::move(translation)); + } + } + + const TopicTranslations& topicTranslations() const { return _topic_translations; } + const ServiceTranslations& serviceTranslations() const { return _service_translations; } + +protected: + RegisteredTranslations() = default; +private: + template + static size_t getMaxSerializedMessageSize() { + const auto type_handle = rclcpp::get_message_type_support_handle(); + const auto fastrtps_handle = rosidl_typesupport_cpp::get_message_typesupport_handle_function(&type_handle, "rosidl_typesupport_fastrtps_cpp"); + if (fastrtps_handle) { + const auto *callbacks = static_cast(fastrtps_handle->data); + char bound_info; + return callbacks->max_serialized_size(bound_info); + } + return 0; + } + + template + static Topic getTopicForMessageType(const std::string& topic_name) { + Topic ret{}; + ret.id.topic_name = topic_name; + ret.id.version = RosMessageType::MESSAGE_VERSION; + auto message_buffer = std::make_shared(); + ret.message_buffer = std::static_pointer_cast(message_buffer); + + // Subscription/Publication factory methods + const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.id.version); + ret.subscription_factory = [topic_name_versioned, message_buffer](rclcpp::Node& node, + const std::function& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr { + return std::dynamic_pointer_cast( + // Note: template instantiation of subscriptions slows down compilation considerably, see + // https://github.com/ros2/rclcpp/issues/1949 + node.create_subscription(topic_name_versioned, rclcpp::QoS(1).best_effort(), + [on_topic_cb=on_topic_cb, message_buffer](typename RosMessageType::UniquePtr msg) -> void { + *message_buffer = *msg; + on_topic_cb(); + })); + }; + ret.publication_factory = [topic_name_versioned](rclcpp::Node& node) -> rclcpp::PublisherBase::SharedPtr { + return std::dynamic_pointer_cast( + node.create_publisher(topic_name_versioned, rclcpp::QoS(1).best_effort())); + }; + + ret.max_serialized_message_size = getMaxSerializedMessageSize(); + + return ret; + } + + template + static Service getServiceForMessageType(const std::string& topic_name) { + Service ret{}; + ret.id.topic_name = topic_name; + ret.id.version = RosMessageType::Request::MESSAGE_VERSION; + auto message_buffer_request = std::make_shared(); + ret.message_buffer_request = std::static_pointer_cast(message_buffer_request); + auto message_buffer_response = std::make_shared(); + ret.message_buffer_response = std::static_pointer_cast(message_buffer_response); + + // Service/client factory methods + const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.id.version); + ret.service_factory = [topic_name_versioned, message_buffer_request](rclcpp::Node& node, + const std::function req_id)>& on_request_cb) -> rclcpp::ServiceBase::SharedPtr { + return std::dynamic_pointer_cast( + node.create_service(topic_name_versioned, + [on_request_cb=on_request_cb, message_buffer_request]( + typename rclcpp::Service::SharedPtr service, + std::shared_ptr req_id, + const std::shared_ptr request + ) -> void { + *message_buffer_request = *request; + on_request_cb(std::move(req_id)); + })); + }; + ret.client_factory = [topic_name_versioned, message_buffer_response](rclcpp::Node& node, + const std::function& on_response_cb) { + auto client = node.create_client(topic_name_versioned); + client->set_on_new_response_callback([client, message_buffer_response, on_response_cb](size_t num) { + for (size_t i = 0; i < num; i++) { + rmw_request_id_t request_id{}; + if (client->take_response(*message_buffer_response, request_id)) { + on_response_cb(request_id); + } + } + }); + const auto send_request = [client](MessageBuffer request) { + auto result = client->async_send_request(std::static_pointer_cast(request)); + // We don't need the client to keep track of ongoing requests, so we remove it right away + // to prevent leaks + client->remove_pending_request(result.request_id); + return result.request_id; + }; + return std::make_tuple(std::dynamic_pointer_cast(client), send_request); + }; + + ret.publication_factory_request = [](rclcpp::Node& node, const std::string& topic_name) -> rclcpp::PublisherBase::SharedPtr { + return std::dynamic_pointer_cast( + node.create_publisher( + topic_name,rclcpp::QoS(1).best_effort().avoid_ros_namespace_conventions(true))); + }; + ret.publication_factory_response = [](rclcpp::Node& node, const std::string& topic_name) -> rclcpp::PublisherBase::SharedPtr { + return std::dynamic_pointer_cast( + node.create_publisher( + topic_name,rclcpp::QoS(1).best_effort().avoid_ros_namespace_conventions(true))); + }; + + ret.max_serialized_message_size_request = getMaxSerializedMessageSize(); + ret.max_serialized_message_size_response = getMaxSerializedMessageSize(); + + return ret; + } + + template + static std::vector getTopicsForMessageTypeImpl(const char* const topics[], std::integer_sequence) { + std::vector ret { + getTopicForMessageType(topics[Is])... + }; + return ret; + } + + template + static std::vector getTopicsForMessageType(Pack, const char* const (&topics)[N]) { + static_assert(N == sizeof...(RosMessageTypes), "Number of topics does not match number of message types"); + return getTopicsForMessageTypeImpl(topics, std::index_sequence_for{}); + } + + TopicTranslations _topic_translations; + ServiceTranslations _service_translations; +}; + +template +class RegistrationHelperDirect { +public: + explicit RegistrationHelperDirect(const char* dummy) { + // There's something strange: when there is no argument passed, the + // compiler removes the static object completely. I don't know + // why but this dummy variable prevents that. + (void)dummy; + RegisteredTranslations::instance().registerDirectTranslation(); + } + explicit RegistrationHelperDirect(const char* dummy, bool for_service) { + (void)dummy; + RegisteredTranslations::instance().registerServiceDirectTranslation(); + } + RegistrationHelperDirect(RegistrationHelperDirect const&) = delete; + void operator=(RegistrationHelperDirect const&) = delete; +}; + +#define REGISTER_TOPIC_TRANSLATION_DIRECT(class_name) \ + RegistrationHelperDirect class_name##_registration_direct("dummy"); + +#define REGISTER_SERVICE_TRANSLATION_DIRECT(class_name) \ + RegistrationHelperDirect class_name##_service_registration_direct("dummy", true); + +template +class TopicRegistrationHelperGeneric { +public: + explicit TopicRegistrationHelperGeneric(const char* dummy) { + (void)dummy; + RegisteredTranslations::instance().registerTranslation(); + } + TopicRegistrationHelperGeneric(TopicRegistrationHelperGeneric const&) = delete; + void operator=(TopicRegistrationHelperGeneric const&) = delete; +}; + +#define REGISTER_TOPIC_TRANSLATION(class_name) \ + TopicRegistrationHelperGeneric class_name##_registration_generic("dummy"); diff --git a/msg/translation_node/src/translations.cpp b/msg/translation_node/src/translations.cpp new file mode 100644 index 0000000000..0b7d8b0c93 --- /dev/null +++ b/msg/translation_node/src/translations.cpp @@ -0,0 +1,5 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include "translations.h" diff --git a/msg/translation_node/src/translations.h b/msg/translation_node/src/translations.h new file mode 100644 index 0000000000..3982d6a1c0 --- /dev/null +++ b/msg/translation_node/src/translations.h @@ -0,0 +1,91 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.h" +#include "graph.h" + +#include + + +using TranslationCB = std::function&, std::vector&)>; +using SubscriptionFactoryCB = std::function& on_topic_cb)>; +using PublicationFactoryCB = std::function; +using NamedPublicationFactoryCB = std::function; +using ServiceFactoryCB = std::function req_id)>& on_request_cb)>; +using ClientSendCB = std::function; +using ClientFactoryCB = std::function(rclcpp::Node&, const std::function& on_response_cb)>; + +struct Topic { + MessageIdentifier id; + + SubscriptionFactoryCB subscription_factory; + PublicationFactoryCB publication_factory; + + std::shared_ptr message_buffer; + size_t max_serialized_message_size{}; +}; + +struct Service { + MessageIdentifier id; + + ServiceFactoryCB service_factory; + ClientFactoryCB client_factory; + + NamedPublicationFactoryCB publication_factory_request; + NamedPublicationFactoryCB publication_factory_response; + + std::shared_ptr message_buffer_request; + size_t max_serialized_message_size_request{}; + + std::shared_ptr message_buffer_response; + size_t max_serialized_message_size_response{}; +}; + +struct Translation { + TranslationCB cb; + std::vector inputs; + std::vector outputs; +}; + +class TopicTranslations { +public: + TopicTranslations() = default; + + void addTopic(Topic topic) { _topics.push_back(std::move(topic)); } + void addTranslation(Translation translation) { _translations.push_back(std::move(translation)); } + + const std::vector& topics() const { return _topics; } + const std::vector& translations() const { return _translations; } +private: + std::vector _topics; + std::vector _translations; +}; + +class ServiceTranslations { +public: + ServiceTranslations() = default; + + void addNode(Service node) { _nodes.push_back(std::move(node)); } + void addRequestTranslation(Translation translation) { _request_translations.push_back(std::move(translation)); } + void addResponseTranslation(Translation translation) { _response_translations.push_back(std::move(translation)); } + + const std::vector& nodes() const { return _nodes; } + const std::vector& requestTranslations() const { return _request_translations; } + const std::vector& responseTranslations() const { return _response_translations; } +private: + std::vector _nodes; + std::vector _request_translations; + std::vector _response_translations; +}; diff --git a/msg/translation_node/src/util.h b/msg/translation_node/src/util.h new file mode 100644 index 0000000000..8d584598e9 --- /dev/null +++ b/msg/translation_node/src/util.h @@ -0,0 +1,51 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include + +using MessageVersionType = uint32_t; + +static inline std::string getVersionedTopicName(const std::string& topic_name, MessageVersionType version) { + // version == 0 can be used to transition from non-versioned topics to versioned ones + if (version == 0) { + return topic_name; + } + return topic_name + "_v" + std::to_string(version); +} + +static inline std::pair getNonVersionedTopicName(const std::string& topic_name) { + // topic name has the form _v, or just (with version=0) + auto pos = topic_name.find_last_of("_v"); + // Ensure there's at least one more char after the found string + if (pos == std::string::npos || pos + 2 > topic_name.length()) { + return std::make_pair(topic_name, 0); + } + std::string non_versioned_topic_name = topic_name.substr(0, pos - 1); + std::string version = topic_name.substr(pos + 1); + // Ensure only digits are in the version string + for (char c : version) { + if (!std::isdigit(c)) { + return std::make_pair(topic_name, 0); + } + } + return std::make_pair(non_versioned_topic_name, std::stol(version)); +} + +/** + * Get the full topic name, including namespace from a topic name. + * namespace_name should be set to Node::get_effective_namespace() + */ +static inline std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) { + std::string full_topic_name = topic_name; + if (!full_topic_name.empty() && full_topic_name[0] != '/') { + if (namespace_name.empty() || namespace_name.back() != '/') { + full_topic_name = '/' + full_topic_name; + } + full_topic_name = namespace_name + full_topic_name; + } + return full_topic_name; +} diff --git a/msg/translation_node/test/graph.cpp b/msg/translation_node/test/graph.cpp new file mode 100644 index 0000000000..e0d1f360cb --- /dev/null +++ b/msg/translation_node/test/graph.cpp @@ -0,0 +1,623 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include + + +TEST(graph, basic) +{ + struct NodeData { + bool iterated{false}; + bool translated{false}; + }; + Graph graph; + + const int32_t message1_value = 3; + const int32_t offset = 4; + + // Add 2 nodes + const MessageIdentifier id1{"topic_name", 1}; + auto buffer1 = std::make_shared(); + *buffer1 = message1_value; + EXPECT_TRUE(graph.addNodeIfNotExists(id1, {}, buffer1)); + EXPECT_FALSE(graph.addNodeIfNotExists(id1, {}, std::make_shared())); + const MessageIdentifier id2{"topic_name", 4}; + auto buffer2 = std::make_shared(); + *buffer2 = 773; + EXPECT_TRUE(graph.addNodeIfNotExists(id2, {}, buffer2)); + + // Search nodes + EXPECT_TRUE(graph.findNode(id1).has_value()); + EXPECT_TRUE(graph.findNode(id2).has_value()); + + // Add 1 translation + auto translation_cb = [&offset](const std::vector& a, std::vector& b) { + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value + offset; + }; + graph.addTranslation(translation_cb, {id1}, {id2}); + + // Iteration from id1 must reach id2 + auto node1 = graph.findNode(id1).value(); + auto node2 = graph.findNode(id2).value(); + auto iterate_cb = [](const Graph::MessageNodePtr& node) { + node->data().iterated = true; + }; + graph.iterateBFS(node1, iterate_cb); + EXPECT_TRUE(node1->data().iterated); + EXPECT_TRUE(node2->data().iterated); + node1->data().iterated = false; + node2->data().iterated = false; + + // Iteration from id2 must not reach id1 + graph.iterateBFS(node2, iterate_cb); + EXPECT_FALSE(node1->data().iterated); + EXPECT_TRUE(node2->data().iterated); + + // Test translation + graph.translate(node1, + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + EXPECT_FALSE(node1->data().translated); + EXPECT_TRUE(node2->data().translated); + EXPECT_EQ(*buffer1, message1_value); + EXPECT_EQ(*buffer2, message1_value + offset); +} + + +TEST(graph, multi_path) +{ + // Multiple paths with cycles + struct NodeData { + unsigned iterated_idx{0}; + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 6; + std::array ids{{ + {"topic_name", 1}, + {"topic_name", 2}, + {"topic_name", 3}, + {"topic_name", 4}, + {"topic_name", 5}, + {"topic_name", 6}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + // Translations + std::bitset<32> translated; + + auto get_translation_cb = [&translated](unsigned bit) { + auto translation_cb = [&translated, bit](const std::vector &a, std::vector &b) { + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value | (1 << bit); + translated.set(bit); + }; + return translation_cb; + }; + + // Graph: + // ___ 2 -- 3 -- 4 + // | | + // 1 _______| + // | + // 5 + // | + // 6 + + unsigned next_bit = 0; + // Connect each node to the previous and next, except the last 3 + for (unsigned i=0; i < num_nodes - 3; ++i) { + graph.addTranslation(get_translation_cb(next_bit++), {ids[i]}, {ids[i+1]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[i+1]}, {ids[i]}); + } + + // Connect the first to the 3rd as well + graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[2]}, {ids[0]}); + + // Connect the second last to the first one + graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[num_nodes-2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[0]}); + + // Connect the second last to the last one + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-1]}, {ids[num_nodes-2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[num_nodes-1]}); + + unsigned iteration_idx = 1; + graph.iterateBFS(graph.findNode(ids[0]).value(), [&iteration_idx](const Graph::MessageNodePtr& node) { + assert(node->data().iterated_idx == 0); + node->data().iterated_idx = iteration_idx++; + }); + + EXPECT_EQ(graph.findNode(ids[0]).value()->data().iterated_idx, 1); + // We're a bit stricter than we would have to be: ids[1,2,4] would be allowed to have any of the values (2,3,4) + EXPECT_EQ(graph.findNode(ids[1]).value()->data().iterated_idx, 2); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().iterated_idx, 3); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().iterated_idx, 4); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().iterated_idx, 5); + + + // Translation + graph.translate(graph.findNode(ids[0]).value(), + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + + // All nodes should be translated except the first + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + for (unsigned i = 1; i < num_nodes; ++i) { + EXPECT_EQ(graph.findNode(ids[i]).value()->data().translated, true) << "node[" << i << "]"; + } + + // Ensure the correct edges were used for translations + EXPECT_EQ("00000000000000000000100101010001", translated.to_string()); + + // Ensure correct translation path taken for each node (which is stored in the buffers), + // and translation callback got called + EXPECT_EQ(*buffers[0], 0); + EXPECT_EQ(*buffers[1], 0b1); + EXPECT_EQ(*buffers[2], 0b1000000); + EXPECT_EQ(*buffers[3], 0b1010000); + EXPECT_EQ(*buffers[4], 0b100000000); + EXPECT_EQ(*buffers[5], 0b100100000000); + + for (unsigned i=0; i < num_nodes; ++i) { + printf("node[%i]: translated: %i, buffer: %i\n", i, graph.findNode(ids[i]).value()->data().translated, + *buffers[i]); + } +} + +TEST(graph, multi_links) { + // Multiple topics (merging / splitting) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 6; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic1", 2}, + {"topic3", 1}, + {"topic4", 1}, + {"topic1", 3}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ + // 1 - | | --- + // | | - 3 - | | - 6 + // 2 - | | --- + // | --- + // | ___ + // --- | | - 4 + // | | - 5 + // --- + + // Translations + auto translation_cb_merge = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value1 | *a_value2; + }; + auto translation_cb_split = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 2); + auto a_value = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value & 0x0000ffffu; + *b_value2 = *a_value & 0xffff0000u; + }; + auto translation_cb_direct = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 1); + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value; + }; + + auto addTranslation = [&](const std::vector& inputs, const std::vector& outputs) { + assert(inputs.size() <= 2); + assert(outputs.size() <= 2); + if (inputs.size() == 1) { + if (outputs.size() == 1) { + graph.addTranslation(translation_cb_direct, inputs, outputs); + graph.addTranslation(translation_cb_direct, outputs, inputs); + } else { + graph.addTranslation(translation_cb_split, inputs, outputs); + graph.addTranslation(translation_cb_merge, outputs, inputs); + } + } else { + assert(outputs.size() == 1); + graph.addTranslation(translation_cb_merge, inputs, outputs); + graph.addTranslation(translation_cb_split, outputs, inputs); + } + }; + addTranslation({ids[0], ids[1]}, {ids[2]}); + addTranslation({ids[1]}, {ids[3], ids[4]}); + addTranslation({ids[2]}, {ids[5]}); + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating node 2 should trigger an output for nodes 4+5 (splitting) + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0f00000f; + translate_node(ids[1]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + EXPECT_EQ(*buffers[3], 0x0000000f); + EXPECT_EQ(*buffers[4], 0x0f000000); + + reset_translated(); + + // Now updating node 1 should update nodes 3+6 (merging, both inputs available now) + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(*buffers[2], 0xaf0000bf); + EXPECT_EQ(*buffers[5], 0xaf0000bf); + + reset_translated(); + + // Another update must not trigger any other updates + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + + reset_translated(); + + // Backwards: updating node 6 should trigger updates for 1+2, but also 4+5 + *buffers[5] = 0xc00000d0; + translate_node(ids[5]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + EXPECT_EQ(*buffers[0], 0x000000d0); + EXPECT_EQ(*buffers[1], 0xc0000000); + EXPECT_EQ(*buffers[2], 0xc00000d0); + EXPECT_EQ(*buffers[3], 0); + EXPECT_EQ(*buffers[4], 0xc0000000); + EXPECT_EQ(*buffers[5], 0xc00000d0); +} + +TEST(graph, multi_links2) { + // Multiple topics (merging / splitting) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 8; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic3", 1}, + {"topic1", 2}, + {"topic2", 2}, + {"topic1", 3}, + {"topic2", 3}, + {"topic3", 3}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ ___ + // 1 - | | | | - 6 + // | | - 4 - | | + // 2 - | | | | - 7 + // | | - 5 - | | + // 3 - | | | | - 8 + // --- --- + + // Translations + auto translation_cb_32 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 3); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto a_value3 = static_cast(a[2].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1 | *a_value2; + *b_value2 = *a_value3; + }; + auto translation_cb_23 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 3); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + auto b_value3 = static_cast(b[2].get()); + *b_value1 = *a_value1 & 0x0000ffffu; + *b_value2 = *a_value1 & 0xffff0000u; + *b_value3 = *a_value2; + }; + graph.addTranslation(translation_cb_32, {ids[0], ids[1], ids[2]}, {ids[3], ids[4]}); + graph.addTranslation(translation_cb_23, {ids[3], ids[4]}, {ids[0], ids[1], ids[2]}); + + graph.addTranslation(translation_cb_23, {ids[3], ids[4]}, {ids[5], ids[6], ids[7]}); + graph.addTranslation(translation_cb_32, {ids[5], ids[6], ids[7]}, {ids[3], ids[4]}); + + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating nodes 1+2+3 should update nodes 6+7+8 + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0f00000f; + *buffers[2] = 0x0c00000c; + translate_node(ids[1]); + translate_node(ids[0]); + translate_node(ids[2]); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[7]).value()->data().translated, true); + EXPECT_EQ(*buffers[3], 0xa00000b0 | 0x0f00000f); + EXPECT_EQ(*buffers[4], 0x0c00000c); + EXPECT_EQ(*buffers[5], (0xa00000b0 | 0x0f00000f) & 0x0000ffffu); + EXPECT_EQ(*buffers[6], (0xa00000b0 | 0x0f00000f) & 0xffff0000u); + EXPECT_EQ(*buffers[7], 0x0c00000c); + + reset_translated(); + + // Now updating nodes 6+7+8 should update nodes 1+2+3 + *buffers[5] = 0xa00000b0; + *buffers[6] = 0x0f00000f; + *buffers[7] = 0x0c00000c; + translate_node(ids[5]); + translate_node(ids[6]); + translate_node(ids[7]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(*buffers[3], 0xa00000b0 | 0x0f00000f); + EXPECT_EQ(*buffers[4], 0x0c00000c); + EXPECT_EQ(*buffers[0], (0xa00000b0 | 0x0f00000f) & 0x0000ffffu); + EXPECT_EQ(*buffers[1], (0xa00000b0 | 0x0f00000f) & 0xffff0000u); + EXPECT_EQ(*buffers[2], 0x0c00000c); +} + +TEST(graph, multi_links3) { + // Multiple topics (cannot use the shortest path) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 7; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic1", 2}, + {"topic1", 3}, + {"topic1", 4}, + {"topic2", 4}, + {"topic1", 5}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ ___ ___ ___ + // 1 - | | - 3 - | | - 4 - | | - 5 - | | - 7 + // | | --- --- | | + // | | | | + // 2 - | | --------------------- 6 - | | + // --- --- + + // Translations + auto translation_cb_21 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + *b_value1 = *a_value1 | *a_value2; + }; + auto translation_cb_22 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1; + *b_value2 = *a_value2; + }; + auto translation_cb_12 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1 & 0x0000ffffu; + *b_value2 = *a_value1 & 0xffff0000u; + }; + auto translation_cb_11 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + *b_value1 = *a_value1 + 1; + }; + graph.addTranslation(translation_cb_22, {ids[0], ids[1]}, {ids[2], ids[5]}); + graph.addTranslation(translation_cb_22, {ids[2], ids[5]}, {ids[0], ids[1]}); + graph.addTranslation(translation_cb_11, {ids[2]}, {ids[3]}); + graph.addTranslation(translation_cb_11, {ids[3]}, {ids[2]}); + graph.addTranslation(translation_cb_11, {ids[3]}, {ids[4]}); + graph.addTranslation(translation_cb_11, {ids[4]}, {ids[3]}); + graph.addTranslation(translation_cb_21, {ids[4], ids[5]}, {ids[6]}); + graph.addTranslation(translation_cb_12, {ids[6]}, {ids[4], ids[5]}); + + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [](auto&& node) { + assert(!node->data().translated); + assert(!node->data().translated); + node->data().translated = true; + }); + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating nodes 1+2 should update node 7 + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0a00000b; + translate_node(ids[1]); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, false); + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(*buffers[2], 0xa00000b0); + EXPECT_EQ(*buffers[3], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[4], 0xa00000b0 + 2); + EXPECT_EQ(*buffers[5], 0x0a00000b); + EXPECT_EQ(*buffers[6], ((0xa00000b0 + 2) | 0x0a00000b)); + + reset_translated(); + + // Now updating nodes 4+6 should update the rest + *buffers[3] = 0xa00000b0; + *buffers[5] = 0x0f00000f; + translate_node(ids[3]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, false); + translate_node(ids[5]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(*buffers[0], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[1], 0x0f00000f); + EXPECT_EQ(*buffers[2], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[4], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[6], (0xa00000b0 + 1) | 0x0f00000f); +} diff --git a/msg/translation_node/test/main.cpp b/msg/translation_node/test/main.cpp new file mode 100644 index 0000000000..9d7e484c70 --- /dev/null +++ b/msg/translation_node/test/main.cpp @@ -0,0 +1,16 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include + +int main(int argc, char ** argv) +{ + rclcpp::init(argc, argv); + testing::InitGoogleTest(&argc, argv); + const int ret = RUN_ALL_TESTS(); + rclcpp::shutdown(); + return ret; +} diff --git a/msg/translation_node/test/pub_sub.cpp b/msg/translation_node/test/pub_sub.cpp new file mode 100644 index 0000000000..3ccd02a9b8 --- /dev/null +++ b/msg/translation_node/test/pub_sub.cpp @@ -0,0 +1,350 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include +using namespace std::chrono_literals; + +// Define a custom struct with MESSAGE_VERSION field that can be used in ROS pubs and subs +#define DEFINE_VERSIONED_ROS_MESSAGE_TYPE(CUSTOM_TYPE_NAME, ROS_TYPE_NAME, THIS_MESSAGE_VERSION) \ + struct CUSTOM_TYPE_NAME : public ROS_TYPE_NAME { \ + CUSTOM_TYPE_NAME() = default; \ + CUSTOM_TYPE_NAME(const ROS_TYPE_NAME& msg) : ROS_TYPE_NAME(msg) {} \ + static constexpr uint32_t MESSAGE_VERSION = THIS_MESSAGE_VERSION; \ + }; \ + template<> \ + struct rclcpp::TypeAdapter \ + { \ + using is_specialized = std::true_type; \ + using custom_type = CUSTOM_TYPE_NAME; \ + using ros_message_type = ROS_TYPE_NAME; \ + static void convert_to_ros_message(const custom_type & source, ros_message_type & destination) \ + { \ + destination = source; \ + } \ + static void convert_to_custom(const ros_message_type & source, custom_type & destination) \ + { \ + destination = source; \ + } \ + }; \ + RCLCPP_USING_CUSTOM_TYPE_AS_ROS_MESSAGE_TYPE(CUSTOM_TYPE_NAME, ROS_TYPE_NAME); + +class PubSubGraphTest : public testing::Test +{ +protected: + void SetUp() override + { + _test_node = std::make_shared("test_node"); + _app_node = std::make_shared("app_node"); + _executor.add_node(_test_node); + _executor.add_node(_app_node); + + for (auto& node : {_app_node, _test_node}) { + auto ret = rcutils_logging_set_logger_level( + node->get_logger().get_name(), RCUTILS_LOG_SEVERITY_DEBUG); + if (ret != RCUTILS_RET_OK) { + RCLCPP_ERROR( + node->get_logger(), "Error setting severity: %s", + rcutils_get_error_string().str); + rcutils_reset_error(); + } + } + } + + bool spinWithTimeout(const std::function& predicate) { + const auto start = _app_node->now(); + while (_app_node->now() - start < 5s) { + _executor.spin_some(); + if (predicate()) { + return true; + } + } + return false; + } + + std::shared_ptr _test_node; + std::shared_ptr _app_node; + rclcpp::executors::SingleThreadedExecutor _executor; +}; + +class RegisteredTranslationsTest : public RegisteredTranslations { +public: + RegisteredTranslationsTest() = default; +}; + + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(Float32Versioned, std_msgs::msg::Float32, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(ColorRGBAVersioned, std_msgs::msg::ColorRGBA, 2u); + +class DirectTranslationTest { +public: + using MessageOlder = Float32Versioned; + using MessageNewer = ColorRGBAVersioned; + + static constexpr const char* kTopic = "test/direct_translation"; + + static void fromOlder(const MessageOlder &msg_older, MessageNewer &msg_newer) { + msg_newer.r = 1.f; + msg_newer.g = msg_older.data; + msg_newer.b = 2.f; + } + + static void toOlder(const MessageNewer &msg_newer, MessageOlder &msg_older) { + msg_older.data = msg_newer.r + msg_newer.g + msg_newer.b; + } +}; + + +TEST_F(PubSubGraphTest, DirectTranslation) +{ + RegisteredTranslationsTest registered_translations; + registered_translations.registerDirectTranslation(); + + PubSubGraph graph(*_test_node, registered_translations.topicTranslations()); + Monitor monitor(*_test_node, &graph, nullptr); + + const std::string topic_name = DirectTranslationTest::kTopic; + const std::string topic_name_older_version = getVersionedTopicName(topic_name, DirectTranslationTest::MessageOlder::MESSAGE_VERSION); + const std::string topic_name_newer_version = getVersionedTopicName(topic_name, DirectTranslationTest::MessageNewer::MESSAGE_VERSION); + + { + // Create publisher + subscriber + int num_topic_updates = 0; + DirectTranslationTest::MessageNewer latest_data{}; + auto publisher = _app_node->create_publisher(topic_name_older_version, + rclcpp::QoS(1).best_effort()); + auto subscriber = _app_node->create_subscription(topic_name_newer_version, + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data, this]( + DirectTranslationTest::MessageNewer::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated: %.3f", (double) msg->g); + latest_data = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&subscriber, &publisher]() { + return subscriber->get_publisher_count() > 0 && publisher->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + DirectTranslationTest::MessageOlder msg_older; + msg_older.data = (float) i; + publisher->publish(msg_older); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == i + 1; + })) << "Timeout, topic update not received, i=" << i; + + // Check data + EXPECT_FLOAT_EQ(latest_data.r, 1.f); + EXPECT_FLOAT_EQ(latest_data.g, (float) i); + EXPECT_FLOAT_EQ(latest_data.b, 2.f); + } + } + + // Now check the translation into the other direction + { + int num_topic_updates = 0; + DirectTranslationTest::MessageOlder latest_data{}; + auto publisher = _app_node->create_publisher(topic_name_newer_version, + rclcpp::QoS(1).best_effort()); + auto subscriber = _app_node->create_subscription(topic_name_older_version, + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data, this]( + DirectTranslationTest::MessageOlder::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated: %.3f", (double) msg->data); + latest_data = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&subscriber, &publisher]() { + return subscriber->get_publisher_count() > 0 && publisher->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + DirectTranslationTest::MessageNewer msg_newer; + msg_newer.r = (float)i; + msg_newer.g = (float)i * 10.f; + msg_newer.b = (float)i * 100.f; + publisher->publish(msg_newer); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == i + 1; + })) << "Timeout, topic update not received, i=" << i; + + // Check data + EXPECT_FLOAT_EQ(latest_data.data, 111.f * (float)i); + } + } +} + + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV1, std_msgs::msg::Float32, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV1, std_msgs::msg::Float64, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeCV1, std_msgs::msg::Int64, 1u); + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV2, std_msgs::msg::ColorRGBA, 2u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV2, std_msgs::msg::Int64, 2u); + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV3, std_msgs::msg::Float64, 3u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV3, std_msgs::msg::Int64, 3u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeCV3, std_msgs::msg::Float32, 3u); + +class TranslationMultiTestV2 { +public: + using MessagesOlder = TypesArray; + static constexpr const char* kTopicsOlder[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + "test/multi_translation_topic_c", + }; + static_assert(MessageTypeAV1::MESSAGE_VERSION == 1); + static_assert(MessageTypeBV1::MESSAGE_VERSION == 1); + static_assert(MessageTypeCV1::MESSAGE_VERSION == 1); + + using MessagesNewer = TypesArray; + static constexpr const char* kTopicsNewer[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + }; + static_assert(MessageTypeAV2::MESSAGE_VERSION == 2); + static_assert(MessageTypeBV2::MESSAGE_VERSION == 2); + + static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, + const MessagesOlder::Type3 &msg_older3, + MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2) { + msg_newer1.r = msg_older1.data; + msg_newer1.g = (float)msg_older2.data; + msg_newer1.b = (float)msg_older3.data; + msg_newer2.data = msg_older3.data * 10; + } + static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, + MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2, MessagesOlder::Type3 &msg_older3) { + msg_older1.data = msg_newer1.r; + msg_older2.data = msg_newer1.g; + msg_older3.data = msg_newer2.data / 10; + } +}; + +class TranslationMultiTestV3 { +public: + using MessagesOlder = TypesArray; + static constexpr const char* kTopicsOlder[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + }; + + using MessagesNewer = TypesArray; + static constexpr const char* kTopicsNewer[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + "test/multi_translation_topic_c", + }; + + static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, + MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2, MessagesNewer::Type3 &msg_newer3) { + msg_newer1.data = msg_older1.r; + msg_newer2.data = (int64_t)msg_older1.g; + msg_newer3.data = (float)msg_older2.data + 100; + } + static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, const MessagesNewer::Type3 &msg_newer3, + MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2) { + msg_older1.r = (float)msg_newer1.data; + msg_older1.g = (float)msg_newer2.data; + msg_older2.data = (int64_t)msg_newer3.data - 100; + } +}; + +TEST_F(PubSubGraphTest, TranslationMulti) { + RegisteredTranslationsTest registered_translations; + // Register 3 different message versions, with 3 types -> 2 types -> 3 types + registered_translations.registerTranslation(); + registered_translations.registerTranslation(); + + PubSubGraph graph(*_test_node, registered_translations.topicTranslations()); + Monitor monitor(*_test_node, &graph, nullptr); + + const std::string topic_name_a = TranslationMultiTestV2::kTopicsOlder[0]; + const std::string topic_name_b = TranslationMultiTestV2::kTopicsOlder[1]; + const std::string topic_name_c = TranslationMultiTestV2::kTopicsOlder[2]; + + // Create publishers for version 1 + subscribers for version 3 + int num_topic_updates = 0; + MessageTypeAV3 latest_data_a{}; + MessageTypeBV3 latest_data_b{}; + MessageTypeCV3 latest_data_c{}; + auto publisher_a = _app_node->create_publisher(getVersionedTopicName(topic_name_a, MessageTypeAV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto publisher_b = _app_node->create_publisher(getVersionedTopicName(topic_name_b, MessageTypeBV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto publisher_c = _app_node->create_publisher(getVersionedTopicName(topic_name_c, MessageTypeCV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto subscriber_a = _app_node->create_subscription(getVersionedTopicName(topic_name_a, MessageTypeAV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_a, this]( + MessageTypeAV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (A): %.3f", (double) msg->data); + latest_data_a = *msg; + ++num_topic_updates; + }); + auto subscriber_b = _app_node->create_subscription(getVersionedTopicName(topic_name_b, MessageTypeBV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_b, this]( + MessageTypeBV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (B): %.3f", (double) msg->data); + latest_data_b = *msg; + ++num_topic_updates; + }); + auto subscriber_c = _app_node->create_subscription(getVersionedTopicName(topic_name_c, MessageTypeCV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_c, this]( + MessageTypeCV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (C): %.3f", (double) msg->data); + latest_data_c = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&]() { + return subscriber_a->get_publisher_count() > 0 && subscriber_b->get_publisher_count() > 0 && subscriber_c->get_publisher_count() > 0 && + publisher_a->get_subscription_count() > 0 && publisher_b->get_subscription_count() > 0 && publisher_c->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + MessageTypeAV1 msg_older_a; + msg_older_a.data = (float) i; + publisher_a->publish(msg_older_a); + + MessageTypeBV1 msg_older_b; + msg_older_b.data = (float) i * 10.f; + publisher_b->publish(msg_older_b); + + MessageTypeCV1 msg_older_c; + msg_older_c.data = i * 100; + publisher_c->publish(msg_older_c); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == (i + 1) * 3; + })) << "Timeout, topic update not received, i=" << i << ", num updates=" << num_topic_updates; + + // Check data + EXPECT_FLOAT_EQ(latest_data_a.data, (float)i); + EXPECT_FLOAT_EQ(latest_data_b.data, (float)i * 10.f); + EXPECT_FLOAT_EQ(latest_data_c.data, ((float)i * 100.f) * 10.f + 100.f); + } +} diff --git a/msg/translation_node/test/services.cpp b/msg/translation_node/test/services.cpp new file mode 100644 index 0000000000..f98569b55d --- /dev/null +++ b/msg/translation_node/test/services.cpp @@ -0,0 +1,215 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include + +using namespace std::chrono_literals; + + +class ServiceTest : public testing::Test +{ +protected: + void SetUp() override + { + _test_node = std::make_shared("test_node"); + _app_node = std::make_shared("app_node"); + _executor.add_node(_test_node); + _executor.add_node(_app_node); + + for (auto& node : {_app_node, _test_node}) { + auto ret = rcutils_logging_set_logger_level( + node->get_logger().get_name(), RCUTILS_LOG_SEVERITY_DEBUG); + if (ret != RCUTILS_RET_OK) { + RCLCPP_ERROR( + node->get_logger(), "Error setting severity: %s", + rcutils_get_error_string().str); + rcutils_reset_error(); + } + } + } + + bool spinWithTimeout(const std::function& predicate) { + const auto start = _app_node->now(); + while (_app_node->now() - start < 5s) { + _executor.spin_some(); + if (predicate()) { + return true; + } + } + return false; + } + + std::shared_ptr _test_node; + std::shared_ptr _app_node; + rclcpp::executors::SingleThreadedExecutor _executor; +}; + +class RegisteredTranslationsTest : public RegisteredTranslations { +public: + RegisteredTranslationsTest() = default; +}; + + +class ServiceTestV0V1 { +public: + using MessageOlder = translation_node::srv::TestV0; + using MessageNewer = translation_node::srv::TestV1; + + static constexpr const char* kTopic = "test/service"; + + static void fromOlder(const MessageOlder::Request &msg_older, MessageNewer::Request &msg_newer) { + msg_newer.request_a = msg_older.request_a; + } + + static void toOlder(const MessageNewer::Request &msg_newer, MessageOlder::Request &msg_older) { + msg_older.request_a = msg_newer.request_a; + } + + static void fromOlder(const MessageOlder::Response &msg_older, MessageNewer::Response &msg_newer) { + msg_newer.response_a = msg_older.response_a; + } + + static void toOlder(const MessageNewer::Response &msg_newer, MessageOlder::Response &msg_older) { + msg_older.response_a = msg_newer.response_a; + } +}; + +class ServiceTestV1V2 { +public: + using MessageOlder = translation_node::srv::TestV1; + using MessageNewer = translation_node::srv::TestV2; + + static constexpr const char* kTopic = "test/service"; + + static void fromOlder(const MessageOlder::Request &msg_older, MessageNewer::Request &msg_newer) { + msg_newer.request_a = msg_older.request_a; + msg_newer.request_b = 1234; + } + + static void toOlder(const MessageNewer::Request &msg_newer, MessageOlder::Request &msg_older) { + msg_older.request_a = msg_newer.request_a + msg_newer.request_b; + } + + static void fromOlder(const MessageOlder::Response &msg_older, MessageNewer::Response &msg_newer) { + msg_newer.response_a = msg_older.response_a; + msg_newer.response_b = 32; + } + + static void toOlder(const MessageNewer::Response &msg_newer, MessageOlder::Response &msg_older) { + msg_older.response_a = msg_newer.response_a + msg_newer.response_b; + } +}; + + +TEST_F(ServiceTest, Test) +{ + RegisteredTranslationsTest registered_translations; + registered_translations.registerServiceDirectTranslation(); + registered_translations.registerServiceDirectTranslation(); + + ServiceGraph graph(*_test_node, registered_translations.serviceTranslations()); + Monitor monitor(*_test_node, nullptr, &graph); + + const std::string topic_name = ServiceTestV1V2::kTopic; + const std::string topic_name_v0 = getVersionedTopicName(topic_name, ServiceTestV0V1::MessageOlder::Request::MESSAGE_VERSION); + const std::string topic_name_v1 = getVersionedTopicName(topic_name, ServiceTestV0V1::MessageNewer::Request::MESSAGE_VERSION); + const std::string topic_name_v2 = getVersionedTopicName(topic_name, ServiceTestV1V2::MessageNewer::Request::MESSAGE_VERSION); + + + // Create service + clients + int num_service_requests = 0; + auto service = _app_node->create_service(topic_name_v0, [&num_service_requests]( + const ServiceTestV0V1::MessageOlder::Request::SharedPtr request, ServiceTestV0V1::MessageOlder::Response::SharedPtr response) { + response->response_a = request->request_a + 1; + ++num_service_requests; + }); + auto client0 = _app_node->create_client(topic_name_v0); + auto client1 = _app_node->create_client(topic_name_v1); + auto client2 = _app_node->create_client(topic_name_v2); + + monitor.updateNow(); + + // Wait until there is a service for each client + ASSERT_TRUE(spinWithTimeout([&client0, &client1, &client2]() { + return client0->service_is_ready() && client1->service_is_ready() && client2->service_is_ready(); + })) << "Timeout, no service for clients found: " << client0->service_is_ready() << client1->service_is_ready() << client2->service_is_ready(); + + + + // Make some requests + int expected_num_service_requests = 1; + + // Client 1 + for (int i = 0; i < 10; ++i) { + auto request = std::make_shared(); + ServiceTestV0V1::MessageNewer::Response response; + request->request_a = i; + bool got_response = false; + client1->async_send_request(request, [&got_response, &response](rclcpp::Client::SharedFuture result) { + got_response = true; + response = *result.get(); + }); + + ASSERT_TRUE(spinWithTimeout([&got_response]() { + return got_response; + })) << "Timeout, reply not received, i=" << i; + + // Check data + EXPECT_EQ(response.response_a, i + 1); + EXPECT_EQ(num_service_requests, expected_num_service_requests); + ++expected_num_service_requests; + } + + // Client 0 + for (int i = 0; i < 10; ++i) { + auto request = std::make_shared(); + ServiceTestV0V1::MessageOlder::Response response; + request->request_a = i * 10; + bool got_response = false; + client0->async_send_request(request, [&got_response, &response](rclcpp::Client::SharedFuture result) { + got_response = true; + response = *result.get(); + }); + + ASSERT_TRUE(spinWithTimeout([&got_response]() { + return got_response; + })) << "Timeout, reply not received, i=" << i; + + // Check data + EXPECT_EQ(response.response_a, i * 10 + 1); + EXPECT_EQ(num_service_requests, expected_num_service_requests); + ++expected_num_service_requests; + } + + // Client 2 + for (int i = 0; i < 10; ++i) { + auto request = std::make_shared(); + ServiceTestV1V2::MessageNewer::Response response; + request->request_a = i * 10; + request->request_b = i; + bool got_response = false; + client2->async_send_request(request, [&got_response, &response](rclcpp::Client::SharedFuture result) { + got_response = true; + response = *result.get(); + }); + + ASSERT_TRUE(spinWithTimeout([&got_response]() { + return got_response; + })) << "Timeout, reply not received, i=" << i; + + // Check data + EXPECT_EQ(response.response_a, i + i * 10 + 1); + EXPECT_EQ(response.response_b, 32); + EXPECT_EQ(num_service_requests, expected_num_service_requests); + ++expected_num_service_requests; + } +} diff --git a/msg/translation_node/test/srv/TestV0.srv b/msg/translation_node/test/srv/TestV0.srv new file mode 100644 index 0000000000..4b06c4af35 --- /dev/null +++ b/msg/translation_node/test/srv/TestV0.srv @@ -0,0 +1,4 @@ +uint32 MESSAGE_VERSION = 0 +uint8 request_a +--- +uint64 response_a diff --git a/msg/translation_node/test/srv/TestV1.srv b/msg/translation_node/test/srv/TestV1.srv new file mode 100644 index 0000000000..d5efaac707 --- /dev/null +++ b/msg/translation_node/test/srv/TestV1.srv @@ -0,0 +1,4 @@ +uint32 MESSAGE_VERSION = 1 +uint64 request_a +--- +uint8 response_a diff --git a/msg/translation_node/test/srv/TestV2.srv b/msg/translation_node/test/srv/TestV2.srv new file mode 100644 index 0000000000..da8b79325f --- /dev/null +++ b/msg/translation_node/test/srv/TestV2.srv @@ -0,0 +1,6 @@ +uint32 MESSAGE_VERSION = 2 +uint8 request_a +uint64 request_b +--- +uint16 response_a +uint64 response_b diff --git a/msg/translation_node/translations/all_translations.h b/msg/translation_node/translations/all_translations.h new file mode 100644 index 0000000000..2b3c2030b4 --- /dev/null +++ b/msg/translation_node/translations/all_translations.h @@ -0,0 +1,11 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include + +//#include "example_translation_direct_v1.h" +//#include "example_translation_multi_v2.h" +//#include "example_translation_service_v1.h" diff --git a/msg/translation_node/translations/example_translation_direct_v1.h b/msg/translation_node/translations/example_translation_direct_v1.h new file mode 100644 index 0000000000..b2c51aa628 --- /dev/null +++ b/msg/translation_node/translations/example_translation_direct_v1.h @@ -0,0 +1,30 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +// Translate ExampleTopic v0 <--> v1 +#include +#include + +class ExampleTopicV1Translation { +public: + using MessageOlder = px4_msgs_old::msg::ExampleTopicV0; + static_assert(MessageOlder::MESSAGE_VERSION == 0); + + using MessageNewer = px4_msgs::msg::ExampleTopic; + static_assert(MessageNewer::MESSAGE_VERSION == 1); + + static constexpr const char* kTopic = "fmu/out/example_topic"; + + static void fromOlder(const MessageOlder &msg_older, MessageNewer &msg_newer) { + // Set msg_newer from msg_older + } + + static void toOlder(const MessageNewer &msg_newer, MessageOlder &msg_older) { + // Set msg_older from msg_newer + } +}; + +REGISTER_TOPIC_TRANSLATION_DIRECT(ExampleTopicV1Translation); diff --git a/msg/translation_node/translations/example_translation_multi_v2.h b/msg/translation_node/translations/example_translation_multi_v2.h new file mode 100644 index 0000000000..9c55deace6 --- /dev/null +++ b/msg/translation_node/translations/example_translation_multi_v2.h @@ -0,0 +1,42 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +// Translate ExampleTopic and OtherTopic v1 <--> v2 +#include +#include +#include +#include + +class ExampleTopicOtherTopicV2Translation { +public: + using MessagesOlder = TypesArray; + static constexpr const char* kTopicsOlder[] = { + "fmu/out/example_topic", + "fmu/out/other_topic", + }; + static_assert(px4_msgs_old::msg::ExampleTopicV1::MESSAGE_VERSION == 1); + static_assert(px4_msgs_old::msg::OtherTopicV1::MESSAGE_VERSION == 1); + + using MessagesNewer = TypesArray; + static constexpr const char* kTopicsNewer[] = { + "fmu/out/example_topic", + "fmu/out/other_topic", + }; + static_assert(px4_msgs::msg::ExampleTopic::MESSAGE_VERSION == 2); + static_assert(px4_msgs::msg::OtherTopic::MESSAGE_VERSION == 2); + + static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, + MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2) { + // Set msg_newer1, msg_newer2 from msg_older1, msg_older2 + } + + static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, + MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2) { + // Set msg_older1, msg_older2 from msg_newer1, msg_newer2 + } +}; + +REGISTER_TOPIC_TRANSLATION(ExampleTopicOtherTopicV2Translation); diff --git a/msg/translation_node/translations/example_translation_service_v1.h b/msg/translation_node/translations/example_translation_service_v1.h new file mode 100644 index 0000000000..e98599ab01 --- /dev/null +++ b/msg/translation_node/translations/example_translation_service_v1.h @@ -0,0 +1,38 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +// Translate ExampleService v0 <--> v1 +#include +#include + +class ExampleServiceV1Translation { +public: + using MessageOlder = px4_msgs_old::srv::ExampleServiceV0; + static_assert(MessageOlder::Request::MESSAGE_VERSION == 0); + + using MessageNewer = px4_msgs::srv::ExampleService; + static_assert(MessageNewer::Request::MESSAGE_VERSION == 1); + + static constexpr const char* kTopic = "fmu/example_service"; + + static void fromOlder(const MessageOlder::Request &msg_older, MessageNewer::Request &msg_newer) { + // Request: set msg_newer from msg_older + } + + static void toOlder(const MessageNewer::Request &msg_newer, MessageOlder::Request &msg_older) { + // Request: set msg_older from msg_newer + } + + static void fromOlder(const MessageOlder::Response &msg_older, MessageNewer::Response &msg_newer) { + // Response: set msg_newer from msg_older + } + + static void toOlder(const MessageNewer::Response &msg_newer, MessageOlder::Response &msg_older) { + // Response: set msg_older from msg_newer + } +}; + +REGISTER_SERVICE_TRANSLATION_DIRECT(ExampleServiceV1Translation);