From 512faeea2894a6c37c4d4f3eebb34a5e1587b23f Mon Sep 17 00:00:00 2001 From: Nick Bolton Date: Wed, 16 Oct 2024 16:12:17 +0100 Subject: [PATCH] test: Modularize hello back logic and add tests --- cspell.json | 4 + src/lib/client/Client.cpp | 80 ++---- src/lib/client/Client.h | 4 +- src/lib/client/HelloBack.cpp | 103 ++++++++ src/lib/client/HelloBack.h | 70 ++++++ src/lib/deskflow/ProtocolUtil.cpp | 2 +- src/lib/deskflow/protocol_types.cpp | 2 + src/lib/deskflow/protocol_types.h | 12 + src/lib/server/ClientProxyUnknown.cpp | 7 +- src/test/mock/io/MockStream.h | 4 + src/test/unittests/client/HelloBackTests.cpp | 229 ++++++++++++++++++ .../unittests/gui/ipc/QIpcClientTests.cpp | 4 + 12 files changed, 454 insertions(+), 67 deletions(-) create mode 100644 src/lib/client/HelloBack.cpp create mode 100644 src/lib/client/HelloBack.h create mode 100644 src/test/unittests/client/HelloBackTests.cpp diff --git a/cspell.json b/cspell.json index da81bfedc..2f40e8ea5 100644 --- a/cspell.json +++ b/cspell.json @@ -48,12 +48,14 @@ "pacman", "Petroules", "Pixmap", + "Pointee", "Poschta", "Povilas", "Priddy", "psutil", "pyproject", "qputenv", + "readf", "Regen", "Repology", "runas", @@ -67,12 +69,14 @@ "Trixie", "unittests", "Valgrind", + "vbuffer", "vcpkg", "venv", "vmactions", "Volker", "whot", "winget", + "writef", "XWINDOWS" ], "ignoreWords": [], diff --git a/src/lib/client/Client.cpp b/src/lib/client/Client.cpp index 573552f8c..987b047d0 100644 --- a/src/lib/client/Client.cpp +++ b/src/lib/client/Client.cpp @@ -47,8 +47,11 @@ #include #include #include +#include #include +using namespace deskflow::client; + // // Client // @@ -97,6 +100,19 @@ Client::Client( new TMethodEventJob(this, &Client::handleFileRecieveCompleted) ); } + + m_pHelloBack = std::make_unique(std::make_shared( + [this]() { + sendConnectionFailedEvent("got invalid hello message from server"); + cleanupTimer(); + cleanupConnection(); + }, + [this](int major, int minor) { + sendConnectionFailedEvent(XIncompatibleClient(major, minor).what()); + cleanupTimer(); + cleanupConnection(); + } + )); } Client::~Client() @@ -570,7 +586,7 @@ void Client::cleanupStream() void Client::handleConnected(const Event &, void *) { - LOG((CLOG_DEBUG1 "connected; wait for hello")); + LOG((CLOG_DEBUG1 "connected, waiting for hello")); cleanupConnecting(); setupConnection(); @@ -651,69 +667,9 @@ void Client::handleClipboardGrabbed(const Event &event, void *) } } -bool Client::isCompatible(int major, int minor) const -{ - const std::map> compatibleTable{ - {6, {7, 8}}, // 1.6 is compatible with 1.7 and 1.8 - {7, {8}} // 1.7 is compatible with 1.8 - }; - - bool isCompatible = false; - - if (major == kProtocolMajorVersion) { - auto versions = compatibleTable.find(minor); - if (versions != compatibleTable.end()) { - auto compatibleVersions = versions->second; - isCompatible = compatibleVersions.find(kProtocolMinorVersion) != compatibleVersions.end(); - } - } - - return isCompatible; -} - void Client::handleHello(const Event &, void *) { - SInt16 major, minor; - - // as luck would have it, both "Synergy" and "Barrier" are 7 chars, - // so we eat 7 chars and then test for either protocol name. - // we cannot re-use `readf` to check for various hello messages, - // as `readf` eats bytes (advances the stream position reference). - std::string protocolName; - ProtocolUtil::readf(m_stream, kMsgHello, &protocolName, &major, &minor); - - if (protocolName != kSynergyProtocolName && protocolName != kBarrierProtocolName) { - sendConnectionFailedEvent("got invalid hello message from server"); - cleanupTimer(); - cleanupConnection(); - return; - } - - // check versions - LOG_DEBUG("got hello version %s, %d.%d", protocolName.c_str(), major, minor); - - SInt16 helloBackMajor = kProtocolMajorVersion; - SInt16 helloBackMinor = kProtocolMinorVersion; - - if (isCompatible(major, minor)) { - // because 1.6 is compatable with 1.7 and 1.8 - downgrading protocol for - // server - LOG_NOTE("downgrading protocol version for server"); - helloBackMinor = minor; - } else if (major < kProtocolMajorVersion || (major == kProtocolMajorVersion && minor < kProtocolMinorVersion)) { - sendConnectionFailedEvent(XIncompatibleClient(major, minor).what()); - cleanupTimer(); - cleanupConnection(); - return; - } - - // say hello back - LOG_DEBUG("say hello version %s %d.%d", protocolName.c_str(), helloBackMajor, helloBackMinor); - - // dynamically build write format for hello back since `writef` doesn't - // support fixed length strings yet. - std::string helloBackMessage = protocolName + "%2i%2i%s"; - ProtocolUtil::writef(m_stream, helloBackMessage.c_str(), helloBackMajor, helloBackMinor, &m_name); + m_pHelloBack->handleHello(m_stream, m_name); // now connected but waiting to complete handshake setupScreen(); diff --git a/src/lib/client/Client.h b/src/lib/client/Client.h index 7460b0bb1..f55d8b659 100644 --- a/src/lib/client/Client.h +++ b/src/lib/client/Client.h @@ -20,6 +20,7 @@ #include "deskflow/IClient.h" +#include "HelloBack.h" #include "base/EventTypes.h" #include "deskflow/ClientArgs.h" #include "deskflow/Clipboard.h" @@ -27,6 +28,7 @@ #include "deskflow/INode.h" #include "mt/CondVar.h" #include "net/NetworkAddress.h" + #include class EventQueueTimer; @@ -217,7 +219,6 @@ private: void handleDisconnected(const Event &, void *); void handleShapeChanged(const Event &, void *); void handleClipboardGrabbed(const Event &, void *); - bool isCompatible(int major, int minor) const; void handleHello(const Event &, void *); void handleSuspend(const Event &event, void *); void handleResume(const Event &event, void *); @@ -260,4 +261,5 @@ private: size_t m_maximumClipboardSize; deskflow::ClientArgs m_args; size_t m_resolvedAddressesCount = 0; + std::unique_ptr m_pHelloBack; }; diff --git a/src/lib/client/HelloBack.cpp b/src/lib/client/HelloBack.cpp new file mode 100644 index 000000000..65d5b0909 --- /dev/null +++ b/src/lib/client/HelloBack.cpp @@ -0,0 +1,103 @@ +/* + * Deskflow -- mouse and keyboard sharing utility + * + * SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd. + * SPDX-License-Identifier: GPL-2.0 + */ + +#include "HelloBack.h" + +#include "base/Log.h" +#include "deskflow/ProtocolUtil.h" +#include "deskflow/protocol_types.h" + +#include +#include + +namespace deskflow::client { + +// +// HelloBack::Deps +// + +void HelloBack::Deps::invalidHello() +{ + m_invalidHello(); +} + +void HelloBack::Deps::incompatible(int major, int minor) +{ + m_incompatible(major, minor); +} + +// +// HelloBack +// + +void HelloBack::handleHello(deskflow::IStream *stream, const std::string &clientName) const +{ + SInt16 serverMajor; + SInt16 serverMinor; + + // as luck would have it, both "Synergy" and "Barrier" are 7 chars, + // so we eat 7 chars and then test for either protocol name. + // we cannot re-use `readf` to check for various hello messages, + // as `readf` eats bytes (advances the stream position reference). + std::string protocolName; + ProtocolUtil::readf(stream, kMsgHello, &protocolName, &serverMajor, &serverMinor); + + if (protocolName != kSynergyProtocolName && protocolName != kBarrierProtocolName) { + m_deps->invalidHello(); + return; + } + + // check versions + LOG_DEBUG("got hello version %s, %d.%d", protocolName.c_str(), serverMajor, serverMinor); + + const auto helloBackMajor = m_majorVersion; + auto helloBackMinor = m_minorVersion; + + if (shouldDowngrade(serverMajor, serverMinor)) { + LOG_NOTE("downgrading to %d.%d protocol for server", serverMajor, serverMinor); + helloBackMinor = serverMinor; + } else if (serverMajor < m_majorVersion || (serverMajor == m_majorVersion && serverMinor < m_minorVersion)) { + m_deps->incompatible(serverMajor, serverMinor); + return; + } + + // say hello back with same protocol name and version + LOG_DEBUG( + "saying hello back with version %s %d.%d", // + protocolName.c_str(), helloBackMajor, helloBackMinor + ); + + // dynamically build write format for hello back since `ProtocolUtil::writef` + // doesn't support formatting fixed length strings yet. + std::string helloBackMessage = protocolName + kMsgHelloBackArgs; + ProtocolUtil::writef(stream, helloBackMessage.c_str(), helloBackMajor, helloBackMinor, &clientName); +} + +bool HelloBack::shouldDowngrade(int major, int minor) const +{ + const std::map> map{ + // 1.6 is compatible with 1.7 and 1.8 + {6, {7, 8}}, + + // 1.7 is compatible with 1.8 + {7, {8}}, + }; + + if (major == m_majorVersion) { + auto versions = map.find(minor); + if (versions != map.end()) { + auto compatibleVersions = versions->second; + if (compatibleVersions.find(m_minorVersion) != compatibleVersions.end()) { + return true; + } + } + } + + return false; +} + +} // namespace deskflow::client diff --git a/src/lib/client/HelloBack.h b/src/lib/client/HelloBack.h new file mode 100644 index 000000000..88e576f2b --- /dev/null +++ b/src/lib/client/HelloBack.h @@ -0,0 +1,70 @@ +/* + * Deskflow -- mouse and keyboard sharing utility + * + * SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd. + * SPDX-License-Identifier: GPL-2.0 + */ + +#pragma once + +#include "common/basic_types.h" +#include "deskflow/protocol_types.h" +#include "io/IStream.h" + +#include +#include + +namespace deskflow::client { + +class HelloBack +{ +public: + struct Deps + { + Deps() = default; + explicit Deps(std::function invalidHello, std::function incompatible) + : m_invalidHello(std::move(invalidHello)), + m_incompatible(std::move(incompatible)) + { + } + virtual ~Deps() = default; + + /** + * @brief Call when invalid hello message received from server. + */ + virtual void invalidHello(); + + /** + * @brief Call when the client is incompatible with the server. + */ + virtual void incompatible(int major, int minor); + + private: + std::function m_invalidHello; + std::function m_incompatible; + }; + + explicit HelloBack( + std::shared_ptr deps, const SInt16 majorVersion = kProtocolMajorVersion, + const SInt16 minorVersion = kProtocolMinorVersion + ) + : m_deps(deps), + m_majorVersion(majorVersion), + m_minorVersion(minorVersion) + { + } + + /** + * @brief Handle hello message from server and reply with hello back. + */ + void handleHello(deskflow::IStream *stream, const std::string &clientName) const; + +private: + bool shouldDowngrade(int major, int minor) const; + + std::shared_ptr m_deps; + SInt16 m_majorVersion; + SInt16 m_minorVersion; +}; + +} // namespace deskflow::client diff --git a/src/lib/deskflow/ProtocolUtil.cpp b/src/lib/deskflow/ProtocolUtil.cpp index db9cfae9d..9aeb9c0ea 100644 --- a/src/lib/deskflow/ProtocolUtil.cpp +++ b/src/lib/deskflow/ProtocolUtil.cpp @@ -506,7 +506,7 @@ void ProtocolUtil::readBytes(deskflow::IStream *stream, UInt32 len, String *dest UInt8 buffer[128]; // when string length is 0, this implies that the size of the string is - // variable and embedded and will be found in the stream. + // variable and will be embedded in the stream. if (len == 0) { len = read4BytesInt(stream); } diff --git a/src/lib/deskflow/protocol_types.cpp b/src/lib/deskflow/protocol_types.cpp index 7f58c91cf..a57266a23 100644 --- a/src/lib/deskflow/protocol_types.cpp +++ b/src/lib/deskflow/protocol_types.cpp @@ -24,7 +24,9 @@ const char *const kBarrierProtocolName = "Barrier"; // The protocol name string within the hello and hello back messages must be // 7 chars for backward compatibility (Synergy and Barrier are 7 chars). const char *const kMsgHello = "%7s%2i%2i"; +const char *const kMsgHelloArgs = "%2i%2i"; const char *const kMsgHelloBack = "%7s%2i%2i%s"; +const char *const kMsgHelloBackArgs = "%2i%2i%s"; const char *const kMsgCNoop = "CNOP"; const char *const kMsgCClose = "CBYE"; const char *const kMsgCEnter = "CINN%2i%2i%4i%2i"; diff --git a/src/lib/deskflow/protocol_types.h b/src/lib/deskflow/protocol_types.h index c13157e0e..0a7d6e99f 100644 --- a/src/lib/deskflow/protocol_types.h +++ b/src/lib/deskflow/protocol_types.h @@ -118,12 +118,24 @@ extern const char *const kBarrierProtocolName; // keyboard layout list. extern const char *const kMsgHello; +// args part of kMsgHello. +// used as part of a dynamic hello message. +// this can be superseded by kMsgHello once `ProtocolUtil::writef` +// supports fixed length strings (e.g. %7s). +extern const char *const kMsgHelloArgs; + // respond to hello from server; secondary -> primary // $1 = protocol major version number supported by client. $2 = // protocol minor version number supported by client. $3 = client // name. extern const char *const kMsgHelloBack; +// args part of kMsgHelloBack. +// used as part of a dynamic hello message. +// this can be superseded by kMsgHelloBack once `ProtocolUtil::writef` +// supports fixed length strings (e.g. %7s). +extern const char *const kMsgHelloBackArgs; + // // command codes // diff --git a/src/lib/server/ClientProxyUnknown.cpp b/src/lib/server/ClientProxyUnknown.cpp index 277f7da6c..0ed18fe33 100644 --- a/src/lib/server/ClientProxyUnknown.cpp +++ b/src/lib/server/ClientProxyUnknown.cpp @@ -60,8 +60,10 @@ ClientProxyUnknown::ClientProxyUnknown(deskflow::IStream *stream, double timeout m_timer = m_events->newOneShotTimer(timeout, this); addStreamHandlers(); + std::string helloMessage = std::string(kSynergyProtocolName).append(kMsgHelloArgs); + LOG((CLOG_DEBUG1 "saying hello")); - ProtocolUtil::writef(m_stream, kMsgHello, kProtocolMajorVersion, kProtocolMinorVersion); + ProtocolUtil::writef(m_stream, helloMessage.c_str(), kProtocolMajorVersion, kProtocolMinorVersion); } ClientProxyUnknown::~ClientProxyUnknown() @@ -225,8 +227,7 @@ void ClientProxyUnknown::handleData(const Event &, void *) // parse the reply to hello SInt16 major, minor; std::string protocolName; - if (!ProtocolUtil::readf( - m_stream, kMsgHelloBack, &protocolName, &major, &minor, &name)) { + if (!ProtocolUtil::readf(m_stream, kMsgHelloBack, &protocolName, &major, &minor, &name)) { throw XBadClient(); } diff --git a/src/test/mock/io/MockStream.h b/src/test/mock/io/MockStream.h index 9c4aca6f6..add4f7ac8 100644 --- a/src/test/mock/io/MockStream.h +++ b/src/test/mock/io/MockStream.h @@ -24,6 +24,8 @@ class IEventQueue; +namespace { + class MockStream : public deskflow::IStream { public: @@ -44,3 +46,5 @@ public: MOCK_METHOD(bool, isReady, (), (const, override)); MOCK_METHOD(UInt32, getSize, (), (const, override)); }; + +} // namespace diff --git a/src/test/unittests/client/HelloBackTests.cpp b/src/test/unittests/client/HelloBackTests.cpp new file mode 100644 index 000000000..1e8ed9a09 --- /dev/null +++ b/src/test/unittests/client/HelloBackTests.cpp @@ -0,0 +1,229 @@ +/* + * Deskflow -- mouse and keyboard sharing utility + * + * SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd. + * SPDX-License-Identifier: GPL-2.0 + */ + +#include "client/HelloBack.h" + +#include "common/basic_types.h" +#include "mock/io/MockStream.h" + +#include +#include +#include +#include +#include +#include + +using HelloBack = deskflow::client::HelloBack; +using namespace testing; + +namespace { + +class MockDeps : public HelloBack::Deps +{ +public: + ~MockDeps() override = default; + MOCK_METHOD(void, invalidHello, (), (override)); + MOCK_METHOD(void, incompatible, (int major, int minor), (override)); +}; + +void intTo2ByteBuf(SInt16 value, std::array &buf) +{ + buf[0] = static_cast((value >> 8) & 0xFF); // MSB + buf[1] = static_cast(value & 0xFF); // LSB +} + +void intTo4ByteBuf(SInt32 value, std::array &buf) +{ + buf[0] = static_cast((value >> 24) & 0xFF); // MSB + buf[1] = static_cast((value >> 16) & 0xFF); + buf[2] = static_cast((value >> 8) & 0xFF); + buf[3] = static_cast(value & 0xFF); // LSB +} + +std::string printAsHex(const char *buffer, size_t size) +{ + std::ostringstream hexStream; + for (size_t i = 0; i < size; ++i) { + hexStream << std::hex << std::setw(2) << std::setfill('0') + << static_cast(static_cast(buffer[i])) << " "; + } + return hexStream.str(); +} + +void setupMockHelloRead( + MockStream &stream, const std::string &protocolName, const SInt16 majorVersion, const SInt16 minorVersion +) +{ + + std::array majorBuf; + std::array minorBuf; + intTo2ByteBuf(majorVersion, majorBuf); + intTo2ByteBuf(minorVersion, minorBuf); + + EXPECT_CALL(stream, read(_, _)) + .WillOnce(DoAll( + WithArg<0>([protocolName](void *vbuffer) { + auto buffer = static_cast(vbuffer); + std::copy(protocolName.begin(), protocolName.end(), buffer); + }), + Return(7) + )) + .WillOnce(DoAll( + WithArg<0>([majorBuf](void *vbuffer) { + auto buffer = static_cast(vbuffer); + std::memcpy(buffer, majorBuf.data(), majorBuf.size()); + }), + Return(2) + )) + .WillOnce(DoAll( + WithArg<0>([minorBuf](void *vbuffer) { + auto buffer = static_cast(vbuffer); + std::memcpy(buffer, minorBuf.data(), minorBuf.size()); + }), + Return(2) + )); +} + +void setupMockHelloBackWrite( + MockStream &stream, const std::string &protocolName, const SInt16 majorVersion, const SInt16 minorVersion, + const std::string &name +) +{ + + std::array majorBuf; + std::array minorBuf; + std::array nameLenBuf; + intTo2ByteBuf(majorVersion, majorBuf); + intTo2ByteBuf(minorVersion, minorBuf); + intTo4ByteBuf(static_cast(name.size()), nameLenBuf); + + const auto versionIntSize = 4; + const auto clientNameIntSize = 4; + const UInt32 helloBackSize = + static_cast(protocolName.size() + versionIntSize + clientNameIntSize + name.size()); + + std::vector expect; + expect.reserve(helloBackSize); + expect.insert(expect.end(), protocolName.begin(), protocolName.end()); + expect.insert(expect.end(), majorBuf.begin(), majorBuf.end()); + expect.insert(expect.end(), minorBuf.begin(), minorBuf.end()); + expect.insert(expect.end(), nameLenBuf.begin(), nameLenBuf.end()); + expect.insert(expect.end(), name.begin(), name.end()); + + EXPECT_CALL(stream, write(_, helloBackSize)).WillOnce(WithArg<0>([expect, helloBackSize](const void *vbuffer) { + const auto buffer = static_cast(vbuffer); + + EXPECT_TRUE(std::memcmp(expect.data(), buffer, helloBackSize) == 0) + << "Buffer mismatch\n" + << "Expected: " << printAsHex(expect.data(), helloBackSize) << "\n" + << "Actual: " << printAsHex(buffer, helloBackSize) << "\n"; + })); +} + +} // namespace + +TEST(HelloBackTests, handleHello_nastyProtocol_invalidHello) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps); + NiceMock stream; + const std::string clientName = "stub"; + + setupMockHelloRead(stream, "ShareMouse", 0, 0); + + EXPECT_CALL(*deps, invalidHello()).Times(1); + + helloBack.handleHello(&stream, clientName); +} + +TEST(HelloBackTests, handleHello_synergyProtocolCurrent_validMessage) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 2); + NiceMock stream; + const std::string clientName = "stub"; + + setupMockHelloRead(stream, "Synergy", 1, 2); + + EXPECT_CALL(*deps, incompatible(_, _)).Times(0); + EXPECT_CALL(*deps, invalidHello()).Times(0); + + helloBack.handleHello(&stream, clientName); +} + +TEST(HelloBackTests, handleHello_barrierProtocolCurrent_validMessage) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 2); + NiceMock stream; + const std::string clientName = "stub"; + + setupMockHelloRead(stream, "Barrier", 1, 2); + + EXPECT_CALL(*deps, incompatible(_, _)).Times(0); + EXPECT_CALL(*deps, invalidHello()).Times(0); + + helloBack.handleHello(&stream, clientName); +} + +TEST(HelloBackTests, handleHello_synergyProtocolOlder_validMessage) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 2); + NiceMock stream; + const std::string clientName = "stub"; + + setupMockHelloRead(stream, "Synergy", 1, 1); + + EXPECT_CALL(*deps, incompatible(1, 1)).Times(1); + + helloBack.handleHello(&stream, clientName); +} + +TEST(HelloBackTests, handleHello_synergyProtocolCurrent_wroteHelloBack) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 2); + NiceMock stream; + const std::string clientName = "test client"; + + setupMockHelloRead(stream, "Synergy", 1, 2); + + setupMockHelloBackWrite(stream, "Synergy", 1, 2, "test client"); + + helloBack.handleHello(&stream, clientName); +} + +TEST(HelloBackTests, handleHello_barrierProtocolCurrent_wroteHelloBack) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 2); + NiceMock stream; + const std::string clientName = "test client"; + + setupMockHelloRead(stream, "Barrier", 1, 2); + + setupMockHelloBackWrite(stream, "Barrier", 1, 2, "test client"); + + helloBack.handleHello(&stream, clientName); +} + +// If the client is protocol version 1.8 and the server is 1.6, the client +// should downgrade and respond with the server version. +TEST(HelloBackTests, handleHello_synergyProtocolCompat_wroteHelloBack) +{ + auto deps = std::make_shared>(); + HelloBack helloBack(deps, 1, 8); + NiceMock stream; + const std::string clientName = "test client"; + + setupMockHelloRead(stream, "Synergy", 1, 6); + + setupMockHelloBackWrite(stream, "Synergy", 1, 6, "test client"); + + helloBack.handleHello(&stream, clientName); +} diff --git a/src/test/unittests/gui/ipc/QIpcClientTests.cpp b/src/test/unittests/gui/ipc/QIpcClientTests.cpp index a13a495cd..29ee6d697 100644 --- a/src/test/unittests/gui/ipc/QIpcClientTests.cpp +++ b/src/test/unittests/gui/ipc/QIpcClientTests.cpp @@ -23,12 +23,16 @@ using testing::_; using testing::StrEq; +namespace { + class MockStream : public QDataStreamProxy { public: MOCK_METHOD(qint64, writeRawData, (const char *, int), (override)); }; +} // namespace + TEST(QIpcClientTests, sendCommand_anyCommand_commandSent) { auto mockStream = std::make_shared();