diff --git a/src/Alloc.h b/src/Alloc.h index 3797ac7..131d522 100644 --- a/src/Alloc.h +++ b/src/Alloc.h @@ -74,25 +74,40 @@ public: logVerb("alloc create object with old memory %d @%p", allocSize(), obj); return obj; } - UsedMemory += allocSize(); + long size = allocSize(); + if (MaxMemory == 0) { + UsedMemory += size; + } else { + // Reserve memory atomically to avoid race with other threads. + long cur = UsedMemory; + while (true) { + long next = cur + size; + if (next > MaxMemory) { + Throw(MemLimit, "maxmemory used"); + } + if (AtomicCAS(UsedMemory, cur, next)) { + break; + } + } + } if (MaxMemory == 0 || UsedMemory <= MaxMemory) { - void* p = ::operator new(allocSize(), std::nothrow); + void* p = ::operator new(size, std::nothrow); if (p) { try { obj = new (p) T(args...); logVerb("alloc create object with new memory %d @%p", allocSize(), obj); return obj; } catch (...) { - UsedMemory -= allocSize(); + UsedMemory -= size; ::operator delete(p); throw; } } else { - UsedMemory -= allocSize(); + UsedMemory -= size; Throw(MemLimit, "system memory alloc fail"); } } else { - UsedMemory -= allocSize(); + UsedMemory -= size; Throw(MemLimit, "maxmemory used"); } return nullptr; @@ -122,7 +137,7 @@ thread_local T* Alloc::Free[CacheSize]; template thread_local int Alloc::Size = 0; -template +template class RefCntObj { public: @@ -134,7 +149,11 @@ public: RefCntObj& operator=(const RefCntObj&) = delete; int count() const { +#ifndef _PREDIXY_SINGLE_THREAD_ + return mCnt.load(); +#else return mCnt; +#endif } void ref() { @@ -154,7 +173,11 @@ public: protected: ~RefCntObj() { +#ifndef _PREDIXY_SINGLE_THREAD_ + mCnt.store(0); +#else mCnt = 0; +#endif } private: CntType mCnt; diff --git a/src/Buffer.cpp b/src/Buffer.cpp index d6d89b4..6128dbf 100644 --- a/src/Buffer.cpp +++ b/src/Buffer.cpp @@ -208,6 +208,11 @@ Buffer* Segment::vfset(Buffer* buf, const char* fmt, va_list ap) mBegin.buf = buf; mBegin.pos = pos; mCur = mBegin; + if (!nbuf) { + // Keep segment empty if formatting fails (e.g., oversized payload). + mEnd = mBegin; + return nullptr; + } mEnd.buf = nbuf; mEnd.pos = nbuf->length(); return nbuf; diff --git a/src/Logger.cpp b/src/Logger.cpp index 718d5e5..b819224 100644 --- a/src/Logger.cpp +++ b/src/Logger.cpp @@ -128,8 +128,12 @@ void Logger::run() mCond.wait(lck); } logs.swap(mLogs); +#ifndef _PREDIXY_SINGLE_THREAD_ + missLogs = mMissLogs.exchange(0); +#else missLogs = mMissLogs; mMissLogs = 0; +#endif } while (false); if (mFileSink) { mFileSink->checkRotate(); diff --git a/src/Logger.h b/src/Logger.h index 2b372f2..3b04e05 100644 --- a/src/Logger.h +++ b/src/Logger.h @@ -13,6 +13,7 @@ #include #include #include +#include "Sync.h" #include "Exception.h" class LogFileSink; @@ -98,7 +99,7 @@ private: private: bool mStop; bool mAllowMissLog; - long mMissLogs; + AtomicLong mMissLogs; int mLogSample[LogLevel::Sentinel]; unsigned mLogUnitCnt; std::vector mLogs; @@ -123,9 +124,12 @@ private: #define logMacroImpl(lvl, fmt, ...) \ do { \ - if (auto _lu_ = Logger::gInst->log(lvl)) { \ - _lu_->format(lvl, __FILE__, __LINE__, fmt, ##__VA_ARGS__);\ - Logger::gInst->put(_lu_); \ + Logger* _logger_ = Logger::gInst; \ + if (_logger_) { \ + if (auto _lu_ = _logger_->log(lvl)) { \ + _lu_->format(lvl, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \ + _logger_->put(_lu_); \ + } \ } \ } while(0) diff --git a/src/ResponseParser.cpp b/src/ResponseParser.cpp index 7a3bf38..bf7ac95 100644 --- a/src/ResponseParser.cpp +++ b/src/ResponseParser.cpp @@ -208,7 +208,12 @@ ResponseParser::Status ResponseParser::parse(Buffer* buf, int& pos) case SubStringBody: if (mStringCnt + (end - cursor) > mStringLen) { cursor += mStringLen - mStringCnt; - *cursor == '\r' ? mState = ElementLF : error = __LINE__; + mStringCnt = mStringLen; + if (cursor >= end) { + error = __LINE__; + } else { + *cursor == '\r' ? mState = ElementLF : error = __LINE__; + } } else { mStringCnt += end - cursor; cursor = end - 1; diff --git a/test/pubsub_large_message.py b/test/pubsub_large_message.py new file mode 100644 index 0000000..0ab16de --- /dev/null +++ b/test/pubsub_large_message.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# +# Ensure large pubsub payloads are parsed correctly. +# + +from test_util import parse_args, make_clients, exit_with_result + + +def normalize_bytes(value): + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + +def wait_for_type(ps, msg_type, attempts=20, timeout=1.0): + for _ in range(attempts): + msg = ps.get_message(timeout=timeout) + if msg and msg.get("type") == msg_type: + return msg + return None + + +def run_test(host, port): + c1, c2 = make_clients(host, port, count=2) + + ps = c1.pubsub() + ps.subscribe("big_payload") + msg = wait_for_type(ps, "subscribe") + if not msg: + print("FAIL: missing subscribe confirmation") + return False + + payload = "x" * 10000 + publish_result = c2.publish("big_payload", payload) + if publish_result < 1: + print("FAIL: publish did not reach subscribers:", publish_result) + return False + + msg = wait_for_type(ps, "message", attempts=30, timeout=1.0) + if not msg: + print("FAIL: missing message response") + return False + + data = normalize_bytes(msg.get("data")) + if data != payload: + print("FAIL: payload mismatch (len)", len(data) if data else 0) + return False + + return True + + +if __name__ == "__main__": + args = parse_args("Pubsub large message test") + success = run_test(args.host, args.port) + exit_with_result(success, "pubsub large message", + "pubsub large message") diff --git a/test/run.sh b/test/run.sh index 66911c1..62daf0f 100755 --- a/test/run.sh +++ b/test/run.sh @@ -141,6 +141,7 @@ TESTS=( "test/pubsub_parser_reset.py" "test/null_response_handling.py" "test/pubsub_long_name.py" + "test/pubsub_large_message.py" "test/transaction_forbid.py" "test/mget_wrong_type.py" "test/msetnx_atomicity.py"