Skip to content

Commit 3bb5f03

Browse files
authored
[MOD-14732] fix race condition on hybrid (#9029)
* fix race condition on hybrid * remove unrelated file * use Strong/WeakRef instead of refCount * fix the real issue, count proper numShards * fix by setting numShards from IO thread * rephrase * fix potential problem if shards can be 0 * use initialized flag * update comment * remove unused numShards
1 parent d141b63 commit 3bb5f03

4 files changed

Lines changed: 31 additions & 15 deletions

File tree

src/coord/hybrid/dist_hybrid.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ void HybridRequest_buildMRCommand(RedisModuleString **argv, int argc,
255255
MRCommand_appendVsim(xcmd, argv, argc, vsimOffset, &kArgIndex);
256256

257257
// Calculate and apply effective K for KNN queries if SHARD_K_RATIO is set
258+
// TODO: Potentially edit in IO thread where numShards is actually known.
259+
// There is a risk that, by the time the I/O thread sends the command, the number of shards has changed, making the effective K inaccurate.
258260
if (vq && vq->type == VECSIM_QT_KNN) {
259261
double shardWindowRatio = vq->knn.shardWindowRatio;
260262
if (shardWindowRatio < MAX_SHARD_WINDOW_RATIO && numShards > 1) {
@@ -713,11 +715,10 @@ static int HybridRequest_executePlan(HybridRequest *hreq, struct ConcurrentCmdCt
713715

714716
// Get the command from the RPNet (it was set during prepareForExecution)
715717
MRCommand *cmd = &searchRPNet->cmd;
716-
int numShards = ConcurrentCmdCtx_GetNumShards(cmdCtx);
717718
cmd->coordStartTime = hreq->profileClocks.coordStartTime;
718719

719720
const RSOomPolicy oomPolicy = hreq->reqConfig.oomPolicy;
720-
if (!ProcessHybridCursorMappings(cmd, numShards, searchMappingsRef, vsimMappingsRef, hreq->tailPipeline->qctx.err, oomPolicy)) {
721+
if (!ProcessHybridCursorMappings(cmd, searchMappingsRef, vsimMappingsRef, hreq->tailPipeline->qctx.err, oomPolicy)) {
721722
// Handle error
722723
StrongRef_Release(searchMappingsRef);
723724
StrongRef_Release(vsimMappingsRef);
@@ -852,7 +853,7 @@ void RSExecDistHybrid(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
852853
// Store coordinator start time for dispatch time tracking
853854
hreq->profileClocks.coordStartTime = ConcurrentCmdCtx_GetCoordStartTime(cmdCtx);
854855

855-
// Get numShards captured from main thread for thread-safe access
856+
// Get numShards captured from main thread for thread-safe access and to compute effective K
856857
size_t numShards = ConcurrentCmdCtx_GetNumShards(cmdCtx);
857858

858859
if (HybridRequest_prepareForExecution(hreq, ctx, argv, argc, sp, numShards, &status) != REDISMODULE_OK) {

src/coord/hybrid/hybrid_cursor_mappings.c

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ typedef struct {
2626
pthread_mutex_t *mutex; // Mutex for array access and completion tracking
2727
pthread_cond_t *completionCond; // Condition variable for completion signaling
2828
int numShards; // Total number of expected shards
29+
bool initialized; // Whether numShards has been set by the IO thread
2930
} processCursorMappingCallbackContext;
3031

3132
void CursorMapping_Release(CursorMapping *mapping) {
@@ -191,6 +192,19 @@ static void processCursorMappingCallback(MRIteratorCallbackCtx *ctx, MRReply *re
191192
MRReply_Free(rep);
192193
}
193194

195+
// Init callback for the private data: sets numShards to the actual number of shards in the cluster, which is also the number of expected responses.
196+
static void processCursorMappingInit(void *privateData, MRIterator *it) {
197+
processCursorMappingCallbackContext *ctx = (processCursorMappingCallbackContext *)privateData;
198+
int actualNumShards = (int)MRIterator_GetNumShards(it);
199+
pthread_mutex_lock(ctx->mutex);
200+
ctx->numShards = actualNumShards;
201+
ctx->initialized = true;
202+
ctx->errors = array_new(QueryError, actualNumShards);
203+
// Signal so the coordinator can re-check the wait condition.
204+
pthread_cond_signal(ctx->completionCond);
205+
pthread_mutex_unlock(ctx->mutex);
206+
}
207+
194208
static inline void cleanupCtx(processCursorMappingCallbackContext *ctx) {
195209
pthread_mutex_destroy(ctx->mutex);
196210
pthread_cond_destroy(ctx->completionCond);
@@ -202,7 +216,7 @@ static inline void cleanupCtx(processCursorMappingCallbackContext *ctx) {
202216
rm_free(ctx);
203217
}
204218

205-
bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef searchMappingsRef, StrongRef vsimMappingsRef, QueryError *status, const RSOomPolicy oomPolicy) {
219+
bool ProcessHybridCursorMappings(const MRCommand *cmd, StrongRef searchMappingsRef, StrongRef vsimMappingsRef, QueryError *status, const RSOomPolicy oomPolicy) {
206220
CursorMappings *searchMappings = StrongRef_Get(searchMappingsRef);
207221
CursorMappings *vsimMappings = StrongRef_Get(vsimMappingsRef);
208222
RS_ASSERT(array_len(searchMappings->mappings) == 0 && array_len(vsimMappings->mappings) == 0);
@@ -217,18 +231,22 @@ bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef
217231
pthread_cond_init(ctx->completionCond, NULL);
218232

219233
// Setup callback context
220-
*ctx = (processCursorMappingCallbackContext){
234+
*ctx = (processCursorMappingCallbackContext) {
221235
.searchMappings = StrongRef_Clone(searchMappingsRef),
222236
.vsimMappings = StrongRef_Clone(vsimMappingsRef),
223-
.errors = array_new(QueryError, numShards),
237+
.errors = NULL,
224238
.responseCount = 0,
225239
.mutex = ctx->mutex,
226240
.completionCond = ctx->completionCond,
227-
.numShards = numShards
228-
};
241+
.numShards = 0,
242+
.initialized = false
243+
};
229244

230245
// Start iteration (ctx is cleaned up manually in cleanupCtx, no destructor needed)
231-
MRIterator *it = MR_IterateWithPrivateData(cmd, processCursorMappingCallback, ctx, NULL, NULL, iterStartCb, NULL);
246+
// processCursorMappingInit is called from iterStartCb to update ctx->numShards
247+
// with the actual shard count from the live topology, preventing use-after-free
248+
// when topology changes during shard migration.
249+
MRIterator *it = MR_IterateWithPrivateData(cmd, processCursorMappingCallback, ctx, NULL, processCursorMappingInit, iterStartCb, NULL);
232250
if (!it) {
233251
// Cleanup on error
234252
QueryError_SetWithoutUserDataFmt(status, QUERY_ERROR_CODE_GENERIC, "Failed to communicate with shards");
@@ -237,8 +255,8 @@ bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef
237255
}
238256
// Wait for all callbacks to complete
239257
pthread_mutex_lock(ctx->mutex);
240-
// initialize count with response counts in case some shards already sent a response
241-
for (size_t count = ctx->responseCount; count < numShards; count = ctx->responseCount) {
258+
// Wait until the IO thread has initialized numShards and all responses arrive.
259+
while (!ctx->initialized || ctx->responseCount < ctx->numShards) {
242260
pthread_cond_wait(ctx->completionCond, ctx->mutex);
243261
}
244262
pthread_mutex_unlock(ctx->mutex);

src/coord/hybrid/hybrid_cursor_mappings.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@ typedef struct QueryError QueryError;
4242
* Handles shard errors by recording them in the status parameter while continuing to process all shards.
4343
* Returns true even if all shards fail with warnings (e.g., OOM), resulting in empty mapping arrays and allowing the caller to handle the warnings.
4444
* @param cmd The MRCommand to execute
45-
* @param numShards Expected number of shards (determines expected callbacks)
4645
* @param searchMappings Empty array to populate with search cursor mappings
4746
* @param vsimMappings Empty array to populate with vector similarity cursor mappings
4847
* @param status QueryError pointer to store warning/error information
4948
* @param oomPolicy OOM policy to determine error handling behavior
5049
* @return true if processing completed (even with warnings), false on fatal errors; status will contain error/warning information
5150
*/
52-
bool ProcessHybridCursorMappings(const MRCommand *cmd,int numShards, StrongRef searchMappings, StrongRef vsimMappings, QueryError *status, RSOomPolicy oomPolicy);
51+
bool ProcessHybridCursorMappings(const MRCommand *cmd, StrongRef searchMappings, StrongRef vsimMappings, QueryError *status, RSOomPolicy oomPolicy);
5352

5453
/**
5554
* Release resources associated with a cursor mapping

tests/pytests/test_asm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,6 @@ def test_add_shard_and_migrate_hybrid():
764764

765765
@skip(cluster=False, min_shards=2)
766766
def test_add_shard_and_migrate_hybrid_BG():
767-
# TODO: MOD-14732 - Skipped due to flaky crash (SIGSEGV) during hybrid cursor migration.
768-
raise SkipTest()
769767
env = Env(clusterNodeTimeout=cluster_node_timeout, moduleArgs='WORKERS 2')
770768
add_shard_and_migrate_test(env, 'FT.HYBRID')
771769

0 commit comments

Comments
 (0)