From ab44559df6ebcd18c4a7fa088dbc608c75981d6e Mon Sep 17 00:00:00 2001 From: Vamshi Maskuri <117595548+varshith257@users.noreply.github.com> Date: Tue, 17 Dec 2024 19:23:22 +0530 Subject: [PATCH] fix(CVE-2021-42076): Enforce maximum message length to prevent memory exhaustion based on barrier: debauchee/barrier@7ab8e01, debauchee/barrier@cc36982, debauchee/barrier@e33c81b, debauchee/barrier@af90f39, debauchee/barrier@fd5295e --- src/lib/base/EventTypes.cpp | 1 + src/lib/base/EventTypes.h | 6 +++++ src/lib/client/ServerProxy.cpp | 30 ++++++++++++++++--------- src/lib/deskflow/PacketStreamFilter.cpp | 20 ++++++++++++----- src/lib/deskflow/PacketStreamFilter.h | 2 +- src/lib/deskflow/ProtocolUtil.cpp | 15 +++++++++++++ src/lib/deskflow/protocol_types.h | 8 +++++++ src/lib/net/SecureSocket.cpp | 6 +++++ src/lib/net/TCPSocket.cpp | 6 +++++ src/lib/server/ClientProxy1_0.cpp | 30 +++++++++++++++++-------- src/lib/server/ClientProxyUnknown.cpp | 5 +++++ 11 files changed, 103 insertions(+), 26 deletions(-) diff --git a/src/lib/base/EventTypes.cpp b/src/lib/base/EventTypes.cpp index 1c0d990b1..9ac976d18 100644 --- a/src/lib/base/EventTypes.cpp +++ b/src/lib/base/EventTypes.cpp @@ -54,6 +54,7 @@ REGISTER_EVENT(IStream, outputFlushed) REGISTER_EVENT(IStream, outputError) REGISTER_EVENT(IStream, inputShutdown) REGISTER_EVENT(IStream, outputShutdown) +REGISTER_EVENT(IStream, inputFormatError) // // IpcClient diff --git a/src/lib/base/EventTypes.h b/src/lib/base/EventTypes.h index 6ddfe8adf..206d4edf6 100644 --- a/src/lib/base/EventTypes.h +++ b/src/lib/base/EventTypes.h @@ -140,6 +140,11 @@ public: */ Event::Type outputShutdown(); + /** Get input format error event type + This is sent when a stream receives an irrecoverable input format error. + */ + Event::Type inputFormatError(); + //@} private: @@ -148,6 +153,7 @@ private: Event::Type m_outputError; Event::Type m_inputShutdown; Event::Type m_outputShutdown; + Event::Type m_inputFormatError; }; class IpcClientEvents : public EventTypes diff --git a/src/lib/client/ServerProxy.cpp b/src/lib/client/ServerProxy.cpp index 5b34f54b5..38d82ff12 100644 --- a/src/lib/client/ServerProxy.cpp +++ b/src/lib/client/ServerProxy.cpp @@ -29,6 +29,7 @@ #include "deskflow/FileChunk.h" #include "deskflow/ProtocolUtil.h" #include "deskflow/StreamChunker.h" +#include "deskflow/XDeskflow.h" #include "deskflow/option_types.h" #include "deskflow/protocol_types.h" #include "io/IStream.h" @@ -121,19 +122,26 @@ void ServerProxy::handleData(const Event &, void *) // parse message LOG((CLOG_DEBUG2 "msg from server: %c%c%c%c", code[0], code[1], code[2], code[3])); - switch ((this->*m_parser)(code)) { - case kOkay: - break; + try { + switch ((this->*m_parser)(code)) { + case kOkay: + break; - case kUnknown: - LOG((CLOG_ERR "invalid message from server: %c%c%c%c", code[0], code[1], code[2], code[3])); - // not possible to determine message boundaries - // read the whole stream to discard unkonwn data - while (m_stream->read(nullptr, 4)) - ; - break; + case kUnknown: + LOG((CLOG_ERR "invalid message from server: %c%c%c%c", code[0], code[1], code[2], code[3])); + // not possible to determine message boundaries + // read the whole stream to discard unkonwn data + while (m_stream->read(nullptr, 4)) + ; + break; - case kDisconnect: + case kDisconnect: + return; + } + } catch (const XBadClient &e) { + LOG((CLOG_ERR "protocol error from server: %s", e.what())); + ProtocolUtil::writef(m_stream, kMsgEBad); + m_client->disconnect("invalid message from server"); return; } diff --git a/src/lib/deskflow/PacketStreamFilter.cpp b/src/lib/deskflow/PacketStreamFilter.cpp index d4db25005..d4d819b8e 100644 --- a/src/lib/deskflow/PacketStreamFilter.cpp +++ b/src/lib/deskflow/PacketStreamFilter.cpp @@ -19,6 +19,7 @@ #include "deskflow/PacketStreamFilter.h" #include "base/IEventQueue.h" #include "base/TMethodEventJob.h" +#include "deskflow/protocol_types.h" #include "mt/Lock.h" #include @@ -125,7 +126,7 @@ bool PacketStreamFilter::isReadyNoLock() const return (m_size != 0 && m_buffer.getSize() >= m_size); } -void PacketStreamFilter::readPacketSize() +bool PacketStreamFilter::readPacketSize() { // note -- m_mutex must be locked on entry @@ -134,7 +135,12 @@ void PacketStreamFilter::readPacketSize() memcpy(buffer, m_buffer.peek(sizeof(buffer)), sizeof(buffer)); m_buffer.pop(sizeof(buffer)); m_size = ((UInt32)buffer[0] << 24) | ((UInt32)buffer[1] << 16) | ((UInt32)buffer[2] << 8) | (UInt32)buffer[3]; + if (m_size > PROTOCOL_MAX_MESSAGE_LENGTH) { + m_events->addEvent(Event(m_events->forIStream().inputFormatError(), getEventTarget())); + return false; + } } + return true; } bool PacketStreamFilter::readMore() @@ -147,13 +153,17 @@ bool PacketStreamFilter::readMore() UInt32 n = getStream()->read(buffer, sizeof(buffer)); while (n > 0) { m_buffer.write(buffer, n); + + // if we don't yet have the next packet size then get it, if possible. + // Note that we can't wait for whole pending data to arrive because it may be huge in + // case of malicious or erroneous peer. + if (!readPacketSize()) { + break; + } + n = getStream()->read(buffer, sizeof(buffer)); } - // if we don't yet have the next packet size then get it, - // if possible. - readPacketSize(); - // note if we now have a whole packet bool isReady = isReadyNoLock(); diff --git a/src/lib/deskflow/PacketStreamFilter.h b/src/lib/deskflow/PacketStreamFilter.h index aa40793a3..5ea66e28c 100644 --- a/src/lib/deskflow/PacketStreamFilter.h +++ b/src/lib/deskflow/PacketStreamFilter.h @@ -48,7 +48,7 @@ protected: private: bool isReadyNoLock() const; - void readPacketSize(); + bool readPacketSize(); bool readMore(); private: diff --git a/src/lib/deskflow/ProtocolUtil.cpp b/src/lib/deskflow/ProtocolUtil.cpp index 9aeb9c0ea..bab69bbbc 100644 --- a/src/lib/deskflow/ProtocolUtil.cpp +++ b/src/lib/deskflow/ProtocolUtil.cpp @@ -19,6 +19,8 @@ #include "deskflow/ProtocolUtil.h" #include "base/Log.h" #include "common/stdvector.h" +#include "deskflow/XDeskflow.h" +#include "deskflow/protocol_types.h" #include "io/IStream.h" #include #include @@ -175,6 +177,13 @@ void ProtocolUtil::vreadf(deskflow::IStream *stream, const char *fmt, va_list ar case 'I': { void *destination = va_arg(args, void *); + UInt32 n = read4BytesInt(stream); + + if (n > PROTOCOL_MAX_LIST_LENGTH) { + LOG((CLOG_ERR "read: vector length exceeds maximum allowed size: %u", n)); + throw XBadClient("Too long message received"); + } + switch (len) { case 1: // 1 byte integer @@ -199,6 +208,12 @@ void ProtocolUtil::vreadf(deskflow::IStream *stream, const char *fmt, va_list ar case 's': { String *destination = va_arg(args, String *); + + if (len > PROTOCOL_MAX_STRING_LENGTH) { + LOG((CLOG_ERR "read: string length exceeds maximum allowed size: %u", len)); + throw XBadClient("Too long message received"); + } + readBytes(stream, len, destination); break; } diff --git a/src/lib/deskflow/protocol_types.h b/src/lib/deskflow/protocol_types.h index 0a7d6e99f..e6ed7af00 100644 --- a/src/lib/deskflow/protocol_types.h +++ b/src/lib/deskflow/protocol_types.h @@ -20,6 +20,8 @@ #include "base/EventTypes.h" +#include + // protocol version number // 1.0: initial protocol // 1.1: adds KeyCode to key press, release, and repeat @@ -53,6 +55,12 @@ static const double kKeepAlivesUntilDeath = 3.0; static const double kHeartRate = -1.0; static const double kHeartBeatsUntilDeath = 3.0; +// Messages of very large size indicate a likely protocol error. We don't parse such messages and +// drop connection instead. Note that e.g. the clipboard messages are already limited to 32kB. +static constexpr std::uint32_t PROTOCOL_MAX_MESSAGE_LENGTH = 4 * 1024 * 1024; +static constexpr std::uint32_t PROTOCOL_MAX_LIST_LENGTH = 1024 * 1024; +static constexpr std::uint32_t PROTOCOL_MAX_STRING_LENGTH = 1024 * 1024; + // direction constants enum EDirection { diff --git a/src/lib/net/SecureSocket.cpp b/src/lib/net/SecureSocket.cpp index e77f70612..dc97bed37 100644 --- a/src/lib/net/SecureSocket.cpp +++ b/src/lib/net/SecureSocket.cpp @@ -41,6 +41,8 @@ #define MAX_ERROR_SIZE 65535 +static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024; + static const float s_retryDelay = 0.01f; enum @@ -147,6 +149,10 @@ TCPSocket::EJobResult SecureSocket::doRead() do { m_inputBuffer.write(buffer, bytesRead); + if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) { + break; + } + status = secureRead(buffer, sizeof(buffer), bytesRead); if (status < 0) { return kBreak; diff --git a/src/lib/net/TCPSocket.cpp b/src/lib/net/TCPSocket.cpp index b145dae3a..714c346ad 100644 --- a/src/lib/net/TCPSocket.cpp +++ b/src/lib/net/TCPSocket.cpp @@ -33,6 +33,8 @@ #include #include +static const std::size_t MAX_INPUT_BUFFER_SIZE = 1024 * 1024; + // // TCPSocket // @@ -324,6 +326,10 @@ TCPSocket::EJobResult TCPSocket::doRead() do { m_inputBuffer.write(buffer, static_cast(bytesRead)); + if (m_inputBuffer.getSize() > MAX_INPUT_BUFFER_SIZE) { + break; + } + bytesRead = ARCH->readSocket(m_socket, buffer, sizeof(buffer)); } while (bytesRead > 0); diff --git a/src/lib/server/ClientProxy1_0.cpp b/src/lib/server/ClientProxy1_0.cpp index 31b51020f..02f0e98b4 100644 --- a/src/lib/server/ClientProxy1_0.cpp +++ b/src/lib/server/ClientProxy1_0.cpp @@ -50,6 +50,10 @@ ClientProxy1_0::ClientProxy1_0(const String &name, deskflow::IStream *stream, IE m_events->forIStream().inputShutdown(), stream->getEventTarget(), new TMethodEventJob(this, &ClientProxy1_0::handleDisconnect, NULL) ); + m_events->adoptHandler( + m_events->forIStream().inputFormatError(), stream->getEventTarget(), + new TMethodEventJob(this, &ClientProxy1_0::handleDisconnect, NULL) + ); m_events->adoptHandler( m_events->forIStream().outputShutdown(), stream->getEventTarget(), new TMethodEventJob(this, &ClientProxy1_0::handleWriteError, NULL) @@ -83,6 +87,7 @@ void ClientProxy1_0::removeHandlers() m_events->removeHandler(m_events->forIStream().outputError(), getStream()->getEventTarget()); m_events->removeHandler(m_events->forIStream().inputShutdown(), getStream()->getEventTarget()); m_events->removeHandler(m_events->forIStream().outputShutdown(), getStream()->getEventTarget()); + m_events->removeHandler(m_events->forIStream().inputFormatError(), getStream()->getEventTarget()); m_events->removeHandler(Event::kTimer, this); // remove timer @@ -135,15 +140,22 @@ void ClientProxy1_0::handleData(const Event &, void *) } // parse message - LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); - if (!(this->*m_parser)(code)) { - LOG(( - CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3] - )); - // not possible to determine message boundaries - // read the whole stream to discard unkonwn data - while (getStream()->read(nullptr, 4)) - ; + try { + LOG((CLOG_DEBUG2 "msg from \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], code[3])); + if (!(this->*m_parser)(code)) { + LOG( + (CLOG_ERR "invalid message from client \"%s\": %c%c%c%c", getName().c_str(), code[0], code[1], code[2], + code[3]) + ); + // not possible to determine message boundaries + // read the whole stream to discard unkonwn data + while (getStream()->read(nullptr, 4)) + ; + } + } catch (const XBadClient &e) { + LOG((CLOG_ERR "protocol error from client \"%s\": %s", getName().c_str(), e.what())); + disconnect(); + return; } // next message diff --git a/src/lib/server/ClientProxyUnknown.cpp b/src/lib/server/ClientProxyUnknown.cpp index eaebd7069..de202660c 100644 --- a/src/lib/server/ClientProxyUnknown.cpp +++ b/src/lib/server/ClientProxyUnknown.cpp @@ -120,6 +120,10 @@ void ClientProxyUnknown::addStreamHandlers() m_events->forIStream().inputShutdown(), m_stream->getEventTarget(), new TMethodEventJob(this, &ClientProxyUnknown::handleDisconnect) ); + m_events->adoptHandler( + m_events->forIStream().inputFormatError(), m_stream->getEventTarget(), + new TMethodEventJob(this, &ClientProxyUnknown::handleDisconnect) + ); m_events->adoptHandler( m_events->forIStream().outputShutdown(), m_stream->getEventTarget(), new TMethodEventJob(this, &ClientProxyUnknown::handleWriteError) @@ -146,6 +150,7 @@ void ClientProxyUnknown::removeHandlers() m_events->removeHandler(m_events->forIStream().inputReady(), m_stream->getEventTarget()); m_events->removeHandler(m_events->forIStream().outputError(), m_stream->getEventTarget()); m_events->removeHandler(m_events->forIStream().inputShutdown(), m_stream->getEventTarget()); + m_events->removeHandler(m_events->forIStream().inputFormatError(), m_stream->getEventTarget()); m_events->removeHandler(m_events->forIStream().outputShutdown(), m_stream->getEventTarget()); } if (m_proxy != NULL) {