Fix thread-safety and parser edge cases

This commit is contained in:
Julien Letessier 2026-01-15 11:18:51 +01:00
parent d49841dbe9
commit 4cbeaeda1b
7 changed files with 109 additions and 11 deletions

View File

@ -74,25 +74,40 @@ public:
logVerb("alloc create object with old memory %d @%p", allocSize<T>(), obj); logVerb("alloc create object with old memory %d @%p", allocSize<T>(), obj);
return obj; return obj;
} }
UsedMemory += allocSize<T>(); long size = allocSize<T>();
if (MaxMemory == 0) {
UsedMemory += size;
} else {
// Reserve memory atomically to avoid race with other threads.
long cur = UsedMemory;
while (true) {
long next = cur + size;
if (next > MaxMemory) {
Throw(MemLimit, "maxmemory used");
}
if (AtomicCAS(UsedMemory, cur, next)) {
break;
}
}
}
if (MaxMemory == 0 || UsedMemory <= MaxMemory) { if (MaxMemory == 0 || UsedMemory <= MaxMemory) {
void* p = ::operator new(allocSize<T>(), std::nothrow); void* p = ::operator new(size, std::nothrow);
if (p) { if (p) {
try { try {
obj = new (p) T(args...); obj = new (p) T(args...);
logVerb("alloc create object with new memory %d @%p", allocSize<T>(), obj); logVerb("alloc create object with new memory %d @%p", allocSize<T>(), obj);
return obj; return obj;
} catch (...) { } catch (...) {
UsedMemory -= allocSize<T>(); UsedMemory -= size;
::operator delete(p); ::operator delete(p);
throw; throw;
} }
} else { } else {
UsedMemory -= allocSize<T>(); UsedMemory -= size;
Throw(MemLimit, "system memory alloc fail"); Throw(MemLimit, "system memory alloc fail");
} }
} else { } else {
UsedMemory -= allocSize<T>(); UsedMemory -= size;
Throw(MemLimit, "maxmemory used"); Throw(MemLimit, "maxmemory used");
} }
return nullptr; return nullptr;
@ -122,7 +137,7 @@ thread_local T* Alloc<T, CacheSize>::Free[CacheSize];
template<class T, int CacheSize> template<class T, int CacheSize>
thread_local int Alloc<T, CacheSize>::Size = 0; thread_local int Alloc<T, CacheSize>::Size = 0;
template<class T, class CntType = int> template<class T, class CntType = AtomicInt>
class RefCntObj class RefCntObj
{ {
public: public:
@ -134,7 +149,11 @@ public:
RefCntObj& operator=(const RefCntObj&) = delete; RefCntObj& operator=(const RefCntObj&) = delete;
int count() const int count() const
{ {
#ifndef _PREDIXY_SINGLE_THREAD_
return mCnt.load();
#else
return mCnt; return mCnt;
#endif
} }
void ref() void ref()
{ {
@ -154,7 +173,11 @@ public:
protected: protected:
~RefCntObj() ~RefCntObj()
{ {
#ifndef _PREDIXY_SINGLE_THREAD_
mCnt.store(0);
#else
mCnt = 0; mCnt = 0;
#endif
} }
private: private:
CntType mCnt; CntType mCnt;

View File

@ -208,6 +208,11 @@ Buffer* Segment::vfset(Buffer* buf, const char* fmt, va_list ap)
mBegin.buf = buf; mBegin.buf = buf;
mBegin.pos = pos; mBegin.pos = pos;
mCur = mBegin; mCur = mBegin;
if (!nbuf) {
// Keep segment empty if formatting fails (e.g., oversized payload).
mEnd = mBegin;
return nullptr;
}
mEnd.buf = nbuf; mEnd.buf = nbuf;
mEnd.pos = nbuf->length(); mEnd.pos = nbuf->length();
return nbuf; return nbuf;

View File

@ -128,8 +128,12 @@ void Logger::run()
mCond.wait(lck); mCond.wait(lck);
} }
logs.swap(mLogs); logs.swap(mLogs);
#ifndef _PREDIXY_SINGLE_THREAD_
missLogs = mMissLogs.exchange(0);
#else
missLogs = mMissLogs; missLogs = mMissLogs;
mMissLogs = 0; mMissLogs = 0;
#endif
} while (false); } while (false);
if (mFileSink) { if (mFileSink) {
mFileSink->checkRotate(); mFileSink->checkRotate();

View File

@ -13,6 +13,7 @@
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <thread> #include <thread>
#include "Sync.h"
#include "Exception.h" #include "Exception.h"
class LogFileSink; class LogFileSink;
@ -98,7 +99,7 @@ private:
private: private:
bool mStop; bool mStop;
bool mAllowMissLog; bool mAllowMissLog;
long mMissLogs; AtomicLong mMissLogs;
int mLogSample[LogLevel::Sentinel]; int mLogSample[LogLevel::Sentinel];
unsigned mLogUnitCnt; unsigned mLogUnitCnt;
std::vector<LogUnit*> mLogs; std::vector<LogUnit*> mLogs;
@ -123,9 +124,12 @@ private:
#define logMacroImpl(lvl, fmt, ...) \ #define logMacroImpl(lvl, fmt, ...) \
do { \ do { \
if (auto _lu_ = Logger::gInst->log(lvl)) { \ Logger* _logger_ = Logger::gInst; \
_lu_->format(lvl, __FILE__, __LINE__, fmt, ##__VA_ARGS__);\ if (_logger_) { \
Logger::gInst->put(_lu_); \ if (auto _lu_ = _logger_->log(lvl)) { \
_lu_->format(lvl, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \
_logger_->put(_lu_); \
} \
} \ } \
} while(0) } while(0)

View File

@ -208,7 +208,12 @@ ResponseParser::Status ResponseParser::parse(Buffer* buf, int& pos)
case SubStringBody: case SubStringBody:
if (mStringCnt + (end - cursor) > mStringLen) { if (mStringCnt + (end - cursor) > mStringLen) {
cursor += mStringLen - mStringCnt; cursor += mStringLen - mStringCnt;
*cursor == '\r' ? mState = ElementLF : error = __LINE__; mStringCnt = mStringLen;
if (cursor >= end) {
error = __LINE__;
} else {
*cursor == '\r' ? mState = ElementLF : error = __LINE__;
}
} else { } else {
mStringCnt += end - cursor; mStringCnt += end - cursor;
cursor = end - 1; cursor = end - 1;

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
#
# Ensure large pubsub payloads are parsed correctly.
#
from test_util import parse_args, make_clients, exit_with_result
def normalize_bytes(value):
if isinstance(value, bytes):
return value.decode("utf-8")
return value
def wait_for_type(ps, msg_type, attempts=20, timeout=1.0):
for _ in range(attempts):
msg = ps.get_message(timeout=timeout)
if msg and msg.get("type") == msg_type:
return msg
return None
def run_test(host, port):
c1, c2 = make_clients(host, port, count=2)
ps = c1.pubsub()
ps.subscribe("big_payload")
msg = wait_for_type(ps, "subscribe")
if not msg:
print("FAIL: missing subscribe confirmation")
return False
payload = "x" * 10000
publish_result = c2.publish("big_payload", payload)
if publish_result < 1:
print("FAIL: publish did not reach subscribers:", publish_result)
return False
msg = wait_for_type(ps, "message", attempts=30, timeout=1.0)
if not msg:
print("FAIL: missing message response")
return False
data = normalize_bytes(msg.get("data"))
if data != payload:
print("FAIL: payload mismatch (len)", len(data) if data else 0)
return False
return True
if __name__ == "__main__":
args = parse_args("Pubsub large message test")
success = run_test(args.host, args.port)
exit_with_result(success, "pubsub large message",
"pubsub large message")

View File

@ -141,6 +141,7 @@ TESTS=(
"test/pubsub_parser_reset.py" "test/pubsub_parser_reset.py"
"test/null_response_handling.py" "test/null_response_handling.py"
"test/pubsub_long_name.py" "test/pubsub_long_name.py"
"test/pubsub_large_message.py"
"test/transaction_forbid.py" "test/transaction_forbid.py"
"test/mget_wrong_type.py" "test/mget_wrong_type.py"
"test/msetnx_atomicity.py" "test/msetnx_atomicity.py"