test: Modularize hello back logic and add tests

This commit is contained in:
Nick Bolton
2024-10-16 16:12:17 +01:00
parent 47366f1272
commit 512faeea28
12 changed files with 454 additions and 67 deletions

View File

@ -48,12 +48,14 @@
"pacman",
"Petroules",
"Pixmap",
"Pointee",
"Poschta",
"Povilas",
"Priddy",
"psutil",
"pyproject",
"qputenv",
"readf",
"Regen",
"Repology",
"runas",
@ -67,12 +69,14 @@
"Trixie",
"unittests",
"Valgrind",
"vbuffer",
"vcpkg",
"venv",
"vmactions",
"Volker",
"whot",
"winget",
"writef",
"XWINDOWS"
],
"ignoreWords": [],

View File

@ -47,8 +47,11 @@
#include <cstring>
#include <fstream>
#include <iterator>
#include <memory>
#include <sstream>
using namespace deskflow::client;
//
// Client
//
@ -97,6 +100,19 @@ Client::Client(
new TMethodEventJob<Client>(this, &Client::handleFileRecieveCompleted)
);
}
m_pHelloBack = std::make_unique<HelloBack>(std::make_shared<HelloBack::Deps>(
[this]() {
sendConnectionFailedEvent("got invalid hello message from server");
cleanupTimer();
cleanupConnection();
},
[this](int major, int minor) {
sendConnectionFailedEvent(XIncompatibleClient(major, minor).what());
cleanupTimer();
cleanupConnection();
}
));
}
Client::~Client()
@ -570,7 +586,7 @@ void Client::cleanupStream()
void Client::handleConnected(const Event &, void *)
{
LOG((CLOG_DEBUG1 "connected; wait for hello"));
LOG((CLOG_DEBUG1 "connected, waiting for hello"));
cleanupConnecting();
setupConnection();
@ -651,69 +667,9 @@ void Client::handleClipboardGrabbed(const Event &event, void *)
}
}
bool Client::isCompatible(int major, int minor) const
{
const std::map<int, std::set<int>> compatibleTable{
{6, {7, 8}}, // 1.6 is compatible with 1.7 and 1.8
{7, {8}} // 1.7 is compatible with 1.8
};
bool isCompatible = false;
if (major == kProtocolMajorVersion) {
auto versions = compatibleTable.find(minor);
if (versions != compatibleTable.end()) {
auto compatibleVersions = versions->second;
isCompatible = compatibleVersions.find(kProtocolMinorVersion) != compatibleVersions.end();
}
}
return isCompatible;
}
void Client::handleHello(const Event &, void *)
{
SInt16 major, minor;
// as luck would have it, both "Synergy" and "Barrier" are 7 chars,
// so we eat 7 chars and then test for either protocol name.
// we cannot re-use `readf` to check for various hello messages,
// as `readf` eats bytes (advances the stream position reference).
std::string protocolName;
ProtocolUtil::readf(m_stream, kMsgHello, &protocolName, &major, &minor);
if (protocolName != kSynergyProtocolName && protocolName != kBarrierProtocolName) {
sendConnectionFailedEvent("got invalid hello message from server");
cleanupTimer();
cleanupConnection();
return;
}
// check versions
LOG_DEBUG("got hello version %s, %d.%d", protocolName.c_str(), major, minor);
SInt16 helloBackMajor = kProtocolMajorVersion;
SInt16 helloBackMinor = kProtocolMinorVersion;
if (isCompatible(major, minor)) {
// because 1.6 is compatable with 1.7 and 1.8 - downgrading protocol for
// server
LOG_NOTE("downgrading protocol version for server");
helloBackMinor = minor;
} else if (major < kProtocolMajorVersion || (major == kProtocolMajorVersion && minor < kProtocolMinorVersion)) {
sendConnectionFailedEvent(XIncompatibleClient(major, minor).what());
cleanupTimer();
cleanupConnection();
return;
}
// say hello back
LOG_DEBUG("say hello version %s %d.%d", protocolName.c_str(), helloBackMajor, helloBackMinor);
// dynamically build write format for hello back since `writef` doesn't
// support fixed length strings yet.
std::string helloBackMessage = protocolName + "%2i%2i%s";
ProtocolUtil::writef(m_stream, helloBackMessage.c_str(), helloBackMajor, helloBackMinor, &m_name);
m_pHelloBack->handleHello(m_stream, m_name);
// now connected but waiting to complete handshake
setupScreen();

View File

@ -20,6 +20,7 @@
#include "deskflow/IClient.h"
#include "HelloBack.h"
#include "base/EventTypes.h"
#include "deskflow/ClientArgs.h"
#include "deskflow/Clipboard.h"
@ -27,6 +28,7 @@
#include "deskflow/INode.h"
#include "mt/CondVar.h"
#include "net/NetworkAddress.h"
#include <memory>
class EventQueueTimer;
@ -217,7 +219,6 @@ private:
void handleDisconnected(const Event &, void *);
void handleShapeChanged(const Event &, void *);
void handleClipboardGrabbed(const Event &, void *);
bool isCompatible(int major, int minor) const;
void handleHello(const Event &, void *);
void handleSuspend(const Event &event, void *);
void handleResume(const Event &event, void *);
@ -260,4 +261,5 @@ private:
size_t m_maximumClipboardSize;
deskflow::ClientArgs m_args;
size_t m_resolvedAddressesCount = 0;
std::unique_ptr<deskflow::client::HelloBack> m_pHelloBack;
};

View File

@ -0,0 +1,103 @@
/*
* Deskflow -- mouse and keyboard sharing utility
*
* SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd.
* SPDX-License-Identifier: GPL-2.0
*/
#include "HelloBack.h"
#include "base/Log.h"
#include "deskflow/ProtocolUtil.h"
#include "deskflow/protocol_types.h"
#include <map>
#include <set>
namespace deskflow::client {
//
// HelloBack::Deps
//
void HelloBack::Deps::invalidHello()
{
m_invalidHello();
}
void HelloBack::Deps::incompatible(int major, int minor)
{
m_incompatible(major, minor);
}
//
// HelloBack
//
void HelloBack::handleHello(deskflow::IStream *stream, const std::string &clientName) const
{
SInt16 serverMajor;
SInt16 serverMinor;
// as luck would have it, both "Synergy" and "Barrier" are 7 chars,
// so we eat 7 chars and then test for either protocol name.
// we cannot re-use `readf` to check for various hello messages,
// as `readf` eats bytes (advances the stream position reference).
std::string protocolName;
ProtocolUtil::readf(stream, kMsgHello, &protocolName, &serverMajor, &serverMinor);
if (protocolName != kSynergyProtocolName && protocolName != kBarrierProtocolName) {
m_deps->invalidHello();
return;
}
// check versions
LOG_DEBUG("got hello version %s, %d.%d", protocolName.c_str(), serverMajor, serverMinor);
const auto helloBackMajor = m_majorVersion;
auto helloBackMinor = m_minorVersion;
if (shouldDowngrade(serverMajor, serverMinor)) {
LOG_NOTE("downgrading to %d.%d protocol for server", serverMajor, serverMinor);
helloBackMinor = serverMinor;
} else if (serverMajor < m_majorVersion || (serverMajor == m_majorVersion && serverMinor < m_minorVersion)) {
m_deps->incompatible(serverMajor, serverMinor);
return;
}
// say hello back with same protocol name and version
LOG_DEBUG(
"saying hello back with version %s %d.%d", //
protocolName.c_str(), helloBackMajor, helloBackMinor
);
// dynamically build write format for hello back since `ProtocolUtil::writef`
// doesn't support formatting fixed length strings yet.
std::string helloBackMessage = protocolName + kMsgHelloBackArgs;
ProtocolUtil::writef(stream, helloBackMessage.c_str(), helloBackMajor, helloBackMinor, &clientName);
}
bool HelloBack::shouldDowngrade(int major, int minor) const
{
const std::map<int, std::set<int>> map{
// 1.6 is compatible with 1.7 and 1.8
{6, {7, 8}},
// 1.7 is compatible with 1.8
{7, {8}},
};
if (major == m_majorVersion) {
auto versions = map.find(minor);
if (versions != map.end()) {
auto compatibleVersions = versions->second;
if (compatibleVersions.find(m_minorVersion) != compatibleVersions.end()) {
return true;
}
}
}
return false;
}
} // namespace deskflow::client

View File

@ -0,0 +1,70 @@
/*
* Deskflow -- mouse and keyboard sharing utility
*
* SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd.
* SPDX-License-Identifier: GPL-2.0
*/
#pragma once
#include "common/basic_types.h"
#include "deskflow/protocol_types.h"
#include "io/IStream.h"
#include <functional>
#include <memory>
namespace deskflow::client {
class HelloBack
{
public:
struct Deps
{
Deps() = default;
explicit Deps(std::function<void()> invalidHello, std::function<void(int, int)> incompatible)
: m_invalidHello(std::move(invalidHello)),
m_incompatible(std::move(incompatible))
{
}
virtual ~Deps() = default;
/**
* @brief Call when invalid hello message received from server.
*/
virtual void invalidHello();
/**
* @brief Call when the client is incompatible with the server.
*/
virtual void incompatible(int major, int minor);
private:
std::function<void()> m_invalidHello;
std::function<void(int, int)> m_incompatible;
};
explicit HelloBack(
std::shared_ptr<Deps> deps, const SInt16 majorVersion = kProtocolMajorVersion,
const SInt16 minorVersion = kProtocolMinorVersion
)
: m_deps(deps),
m_majorVersion(majorVersion),
m_minorVersion(minorVersion)
{
}
/**
* @brief Handle hello message from server and reply with hello back.
*/
void handleHello(deskflow::IStream *stream, const std::string &clientName) const;
private:
bool shouldDowngrade(int major, int minor) const;
std::shared_ptr<Deps> m_deps;
SInt16 m_majorVersion;
SInt16 m_minorVersion;
};
} // namespace deskflow::client

View File

@ -506,7 +506,7 @@ void ProtocolUtil::readBytes(deskflow::IStream *stream, UInt32 len, String *dest
UInt8 buffer[128];
// when string length is 0, this implies that the size of the string is
// variable and embedded and will be found in the stream.
// variable and will be embedded in the stream.
if (len == 0) {
len = read4BytesInt(stream);
}

View File

@ -24,7 +24,9 @@ const char *const kBarrierProtocolName = "Barrier";
// The protocol name string within the hello and hello back messages must be
// 7 chars for backward compatibility (Synergy and Barrier are 7 chars).
const char *const kMsgHello = "%7s%2i%2i";
const char *const kMsgHelloArgs = "%2i%2i";
const char *const kMsgHelloBack = "%7s%2i%2i%s";
const char *const kMsgHelloBackArgs = "%2i%2i%s";
const char *const kMsgCNoop = "CNOP";
const char *const kMsgCClose = "CBYE";
const char *const kMsgCEnter = "CINN%2i%2i%4i%2i";

View File

@ -118,12 +118,24 @@ extern const char *const kBarrierProtocolName;
// keyboard layout list.
extern const char *const kMsgHello;
// args part of kMsgHello.
// used as part of a dynamic hello message.
// this can be superseded by kMsgHello once `ProtocolUtil::writef`
// supports fixed length strings (e.g. %7s).
extern const char *const kMsgHelloArgs;
// respond to hello from server; secondary -> primary
// $1 = protocol major version number supported by client. $2 =
// protocol minor version number supported by client. $3 = client
// name.
extern const char *const kMsgHelloBack;
// args part of kMsgHelloBack.
// used as part of a dynamic hello message.
// this can be superseded by kMsgHelloBack once `ProtocolUtil::writef`
// supports fixed length strings (e.g. %7s).
extern const char *const kMsgHelloBackArgs;
//
// command codes
//

View File

@ -60,8 +60,10 @@ ClientProxyUnknown::ClientProxyUnknown(deskflow::IStream *stream, double timeout
m_timer = m_events->newOneShotTimer(timeout, this);
addStreamHandlers();
std::string helloMessage = std::string(kSynergyProtocolName).append(kMsgHelloArgs);
LOG((CLOG_DEBUG1 "saying hello"));
ProtocolUtil::writef(m_stream, kMsgHello, kProtocolMajorVersion, kProtocolMinorVersion);
ProtocolUtil::writef(m_stream, helloMessage.c_str(), kProtocolMajorVersion, kProtocolMinorVersion);
}
ClientProxyUnknown::~ClientProxyUnknown()
@ -225,8 +227,7 @@ void ClientProxyUnknown::handleData(const Event &, void *)
// parse the reply to hello
SInt16 major, minor;
std::string protocolName;
if (!ProtocolUtil::readf(
m_stream, kMsgHelloBack, &protocolName, &major, &minor, &name)) {
if (!ProtocolUtil::readf(m_stream, kMsgHelloBack, &protocolName, &major, &minor, &name)) {
throw XBadClient();
}

View File

@ -24,6 +24,8 @@
class IEventQueue;
namespace {
class MockStream : public deskflow::IStream
{
public:
@ -44,3 +46,5 @@ public:
MOCK_METHOD(bool, isReady, (), (const, override));
MOCK_METHOD(UInt32, getSize, (), (const, override));
};
} // namespace

View File

@ -0,0 +1,229 @@
/*
* Deskflow -- mouse and keyboard sharing utility
*
* SPDX-FileCopyrightText: Copyright (C) 2024 Symless Ltd.
* SPDX-License-Identifier: GPL-2.0
*/
#include "client/HelloBack.h"
#include "common/basic_types.h"
#include "mock/io/MockStream.h"
#include <array>
#include <cstring>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <iomanip>
#include <sstream>
using HelloBack = deskflow::client::HelloBack;
using namespace testing;
namespace {
class MockDeps : public HelloBack::Deps
{
public:
~MockDeps() override = default;
MOCK_METHOD(void, invalidHello, (), (override));
MOCK_METHOD(void, incompatible, (int major, int minor), (override));
};
void intTo2ByteBuf(SInt16 value, std::array<char, 2> &buf)
{
buf[0] = static_cast<char>((value >> 8) & 0xFF); // MSB
buf[1] = static_cast<char>(value & 0xFF); // LSB
}
void intTo4ByteBuf(SInt32 value, std::array<char, 4> &buf)
{
buf[0] = static_cast<char>((value >> 24) & 0xFF); // MSB
buf[1] = static_cast<char>((value >> 16) & 0xFF);
buf[2] = static_cast<char>((value >> 8) & 0xFF);
buf[3] = static_cast<char>(value & 0xFF); // LSB
}
std::string printAsHex(const char *buffer, size_t size)
{
std::ostringstream hexStream;
for (size_t i = 0; i < size; ++i) {
hexStream << std::hex << std::setw(2) << std::setfill('0')
<< static_cast<int>(static_cast<unsigned char>(buffer[i])) << " ";
}
return hexStream.str();
}
void setupMockHelloRead(
MockStream &stream, const std::string &protocolName, const SInt16 majorVersion, const SInt16 minorVersion
)
{
std::array<char, 2> majorBuf;
std::array<char, 2> minorBuf;
intTo2ByteBuf(majorVersion, majorBuf);
intTo2ByteBuf(minorVersion, minorBuf);
EXPECT_CALL(stream, read(_, _))
.WillOnce(DoAll(
WithArg<0>([protocolName](void *vbuffer) {
auto buffer = static_cast<char *>(vbuffer);
std::copy(protocolName.begin(), protocolName.end(), buffer);
}),
Return(7)
))
.WillOnce(DoAll(
WithArg<0>([majorBuf](void *vbuffer) {
auto buffer = static_cast<char *>(vbuffer);
std::memcpy(buffer, majorBuf.data(), majorBuf.size());
}),
Return(2)
))
.WillOnce(DoAll(
WithArg<0>([minorBuf](void *vbuffer) {
auto buffer = static_cast<char *>(vbuffer);
std::memcpy(buffer, minorBuf.data(), minorBuf.size());
}),
Return(2)
));
}
void setupMockHelloBackWrite(
MockStream &stream, const std::string &protocolName, const SInt16 majorVersion, const SInt16 minorVersion,
const std::string &name
)
{
std::array<char, 2> majorBuf;
std::array<char, 2> minorBuf;
std::array<char, 4> nameLenBuf;
intTo2ByteBuf(majorVersion, majorBuf);
intTo2ByteBuf(minorVersion, minorBuf);
intTo4ByteBuf(static_cast<SInt32>(name.size()), nameLenBuf);
const auto versionIntSize = 4;
const auto clientNameIntSize = 4;
const UInt32 helloBackSize =
static_cast<UInt32>(protocolName.size() + versionIntSize + clientNameIntSize + name.size());
std::vector<char> expect;
expect.reserve(helloBackSize);
expect.insert(expect.end(), protocolName.begin(), protocolName.end());
expect.insert(expect.end(), majorBuf.begin(), majorBuf.end());
expect.insert(expect.end(), minorBuf.begin(), minorBuf.end());
expect.insert(expect.end(), nameLenBuf.begin(), nameLenBuf.end());
expect.insert(expect.end(), name.begin(), name.end());
EXPECT_CALL(stream, write(_, helloBackSize)).WillOnce(WithArg<0>([expect, helloBackSize](const void *vbuffer) {
const auto buffer = static_cast<const char *>(vbuffer);
EXPECT_TRUE(std::memcmp(expect.data(), buffer, helloBackSize) == 0)
<< "Buffer mismatch\n"
<< "Expected: " << printAsHex(expect.data(), helloBackSize) << "\n"
<< "Actual: " << printAsHex(buffer, helloBackSize) << "\n";
}));
}
} // namespace
TEST(HelloBackTests, handleHello_nastyProtocol_invalidHello)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps);
NiceMock<MockStream> stream;
const std::string clientName = "stub";
setupMockHelloRead(stream, "ShareMouse", 0, 0);
EXPECT_CALL(*deps, invalidHello()).Times(1);
helloBack.handleHello(&stream, clientName);
}
TEST(HelloBackTests, handleHello_synergyProtocolCurrent_validMessage)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 2);
NiceMock<MockStream> stream;
const std::string clientName = "stub";
setupMockHelloRead(stream, "Synergy", 1, 2);
EXPECT_CALL(*deps, incompatible(_, _)).Times(0);
EXPECT_CALL(*deps, invalidHello()).Times(0);
helloBack.handleHello(&stream, clientName);
}
TEST(HelloBackTests, handleHello_barrierProtocolCurrent_validMessage)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 2);
NiceMock<MockStream> stream;
const std::string clientName = "stub";
setupMockHelloRead(stream, "Barrier", 1, 2);
EXPECT_CALL(*deps, incompatible(_, _)).Times(0);
EXPECT_CALL(*deps, invalidHello()).Times(0);
helloBack.handleHello(&stream, clientName);
}
TEST(HelloBackTests, handleHello_synergyProtocolOlder_validMessage)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 2);
NiceMock<MockStream> stream;
const std::string clientName = "stub";
setupMockHelloRead(stream, "Synergy", 1, 1);
EXPECT_CALL(*deps, incompatible(1, 1)).Times(1);
helloBack.handleHello(&stream, clientName);
}
TEST(HelloBackTests, handleHello_synergyProtocolCurrent_wroteHelloBack)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 2);
NiceMock<MockStream> stream;
const std::string clientName = "test client";
setupMockHelloRead(stream, "Synergy", 1, 2);
setupMockHelloBackWrite(stream, "Synergy", 1, 2, "test client");
helloBack.handleHello(&stream, clientName);
}
TEST(HelloBackTests, handleHello_barrierProtocolCurrent_wroteHelloBack)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 2);
NiceMock<MockStream> stream;
const std::string clientName = "test client";
setupMockHelloRead(stream, "Barrier", 1, 2);
setupMockHelloBackWrite(stream, "Barrier", 1, 2, "test client");
helloBack.handleHello(&stream, clientName);
}
// If the client is protocol version 1.8 and the server is 1.6, the client
// should downgrade and respond with the server version.
TEST(HelloBackTests, handleHello_synergyProtocolCompat_wroteHelloBack)
{
auto deps = std::make_shared<NiceMock<MockDeps>>();
HelloBack helloBack(deps, 1, 8);
NiceMock<MockStream> stream;
const std::string clientName = "test client";
setupMockHelloRead(stream, "Synergy", 1, 6);
setupMockHelloBackWrite(stream, "Synergy", 1, 6, "test client");
helloBack.handleHello(&stream, clientName);
}

View File

@ -23,12 +23,16 @@
using testing::_;
using testing::StrEq;
namespace {
class MockStream : public QDataStreamProxy
{
public:
MOCK_METHOD(qint64, writeRawData, (const char *, int), (override));
};
} // namespace
TEST(QIpcClientTests, sendCommand_anyCommand_commandSent)
{
auto mockStream = std::make_shared<MockStream>();