Skip to content

Commit 845231d

Browse files
JoanFM authored and github-actions[bot] committed
[MOD-14732] fix race condition on hybrid (#9029)
* fix race condition on hybrid
* remove unrelated file
* use Strong/WeakRef instead of refCount
* fix the real issue, count proper numShards
* fix by setting numShards from IO thread
* rephrase
* fix potential problem if shards can be 0
* use initialized flag
* update comment
* remove unused numShards

(cherry picked from commit 3bb5f03)
1 parent 2b7a93a commit 845231d

3 files changed

Lines changed: 31 additions & 13 deletions

File tree

src/coord/hybrid/dist_hybrid.c

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -253,6 +253,8 @@ void HybridRequest_buildMRCommand(RedisModuleString **argv, int argc,
253253
MRCommand_appendVsim(xcmd, argv, argc, vsimOffset, &kArgIndex);
254254

255255
// Calculate and apply effective K for KNN queries if SHARD_K_RATIO is set
256+
// TODO: Potentially edit in IO thread where numShards is actually known.
257+
// Now we have a risk that by the time I/O thread sends the command, the number of shards changed, making the effective K inaccurate.
256258
if (vq && vq->type == VECSIM_QT_KNN) {
257259
double shardWindowRatio = vq->knn.shardWindowRatio;
258260
if (shardWindowRatio < MAX_SHARD_WINDOW_RATIO && numShards > 1) {
@@ -702,11 +704,10 @@ static int HybridRequest_executePlan(HybridRequest *hreq, struct ConcurrentCmdCt
702704

703705
// Get the command from the RPNet (it was set during prepareForExecution)
704706
MRCommand *cmd = &searchRPNet->cmd;
705-
int numShards = ConcurrentCmdCtx_GetNumShards(cmdCtx);
706707
cmd->coordStartTime = hreq->profileClocks.coordStartTime;
707708

708709
const RSOomPolicy oomPolicy = hreq->reqConfig.oomPolicy;
709-
if (!ProcessHybridCursorMappings(cmd, numShards, searchMappingsRef, vsimMappingsRef, hreq->tailPipeline->qctx.err, oomPolicy)) {
710+
if (!ProcessHybridCursorMappings(cmd, searchMappingsRef, vsimMappingsRef, hreq->tailPipeline->qctx.err, oomPolicy)) {
710711
// Handle error
711712
StrongRef_Release(searchMappingsRef);
712713
StrongRef_Release(vsimMappingsRef);
@@ -800,7 +801,7 @@ void RSExecDistHybrid(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
800801
// Store coordinator start time for dispatch time tracking
801802
hreq->profileClocks.coordStartTime = ConcurrentCmdCtx_GetCoordStartTime(cmdCtx);
802803

803-
// Get numShards captured from main thread for thread-safe access
804+
// Get numShards captured from main thread for thread-safe access and to compute effective K
804805
size_t numShards = ConcurrentCmdCtx_GetNumShards(cmdCtx);
805806

806807
if (HybridRequest_prepareForExecution(hreq, ctx, argv, argc, sp, numShards, &status) != REDISMODULE_OK) {

src/coord/hybrid/hybrid_cursor_mappings.c

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ typedef struct {
2626
pthread_mutex_t *mutex; // Mutex for array access and completion tracking
2727
pthread_cond_t *completionCond; // Condition variable for completion signaling
2828
int numShards; // Total number of expected shards
29+
bool initialized; // Whether numShards has been set by the IO thread
2930
} processCursorMappingCallbackContext;
3031

3132
void CursorMapping_Release(CursorMapping *mapping) {
@@ -189,6 +190,19 @@ static void processCursorMappingCallback(MRIteratorCallbackCtx *ctx, MRReply *re
189190
MRReply_Free(rep);
190191
}
191192

193+
// Init callback for the private data, so that numShards is set to the actual number of shards in the cluster, and the expected responses.
194+
static void processCursorMappingInit(void *privateData, MRIterator *it) {
195+
processCursorMappingCallbackContext *ctx = (processCursorMappingCallbackContext *)privateData;
196+
int actualNumShards = (int)MRIterator_GetNumShards(it);
197+
pthread_mutex_lock(ctx->mutex);
198+
ctx->numShards = actualNumShards;
199+
ctx->initialized = true;
200+
ctx->errors = array_new(QueryError, actualNumShards);
201+
// Signal so the coordinator can re-check the wait condition.
202+
pthread_cond_signal(ctx->completionCond);
203+
pthread_mutex_unlock(ctx->mutex);
204+
}
205+
192206
static inline void cleanupCtx(processCursorMappingCallbackContext *ctx) {
193207
pthread_mutex_destroy(ctx->mutex);
194208
pthread_cond_destroy(ctx->completionCond);
@@ -200,7 +214,7 @@ static inline void cleanupCtx(processCursorMappingCallbackContext *ctx) {
200214
rm_free(ctx);
201215
}
202216

203-
bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef searchMappingsRef, StrongRef vsimMappingsRef, QueryError *status, const RSOomPolicy oomPolicy) {
217+
bool ProcessHybridCursorMappings(const MRCommand *cmd, StrongRef searchMappingsRef, StrongRef vsimMappingsRef, QueryError *status, const RSOomPolicy oomPolicy) {
204218
CursorMappings *searchMappings = StrongRef_Get(searchMappingsRef);
205219
CursorMappings *vsimMappings = StrongRef_Get(vsimMappingsRef);
206220
RS_ASSERT(array_len(searchMappings->mappings) == 0 && array_len(vsimMappings->mappings) == 0);
@@ -215,18 +229,22 @@ bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef
215229
pthread_cond_init(ctx->completionCond, NULL);
216230

217231
// Setup callback context
218-
*ctx = (processCursorMappingCallbackContext){
232+
*ctx = (processCursorMappingCallbackContext) {
219233
.searchMappings = StrongRef_Clone(searchMappingsRef),
220234
.vsimMappings = StrongRef_Clone(vsimMappingsRef),
221-
.errors = array_new(QueryError, numShards),
235+
.errors = NULL,
222236
.responseCount = 0,
223237
.mutex = ctx->mutex,
224238
.completionCond = ctx->completionCond,
225-
.numShards = numShards
226-
};
239+
.numShards = 0,
240+
.initialized = false
241+
};
227242

228243
// Start iteration (ctx is cleaned up manually in cleanupCtx, no destructor needed)
229-
MRIterator *it = MR_IterateWithPrivateData(cmd, processCursorMappingCallback, ctx, NULL, NULL, iterStartCb, NULL);
244+
// processCursorMappingInit is called from iterStartCb to update ctx->numShards
245+
// with the actual shard count from the live topology, preventing use-after-free
246+
// when topology changes during shard migration.
247+
MRIterator *it = MR_IterateWithPrivateData(cmd, processCursorMappingCallback, ctx, NULL, processCursorMappingInit, iterStartCb, NULL);
230248
if (!it) {
231249
// Cleanup on error
232250
QueryError_SetWithoutUserDataFmt(status, QUERY_ERROR_CODE_GENERIC, "Failed to communicate with shards");
@@ -235,8 +253,8 @@ bool ProcessHybridCursorMappings(const MRCommand *cmd, int numShards, StrongRef
235253
}
236254
// Wait for all callbacks to complete
237255
pthread_mutex_lock(ctx->mutex);
238-
// initialize count with response counts in case some shards already sent a response
239-
for (size_t count = ctx->responseCount; count < numShards; count = ctx->responseCount) {
256+
// Wait until the IO thread has initialized numShards and all responses arrive.
257+
while (!ctx->initialized || ctx->responseCount < ctx->numShards) {
240258
pthread_cond_wait(ctx->completionCond, ctx->mutex);
241259
}
242260
pthread_mutex_unlock(ctx->mutex);

src/coord/hybrid/hybrid_cursor_mappings.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@ typedef struct QueryError QueryError;
4242
* Handles shard errors by recording them in the status parameter while continuing to process all shards.
4343
* Returns true even if all shards fail with warnings (e.g., OOM), resulting in empty mapping arrays and allowing the caller to handle the warnings.
4444
* @param cmd The MRCommand to execute
45-
* @param numShards Expected number of shards (determines expected callbacks)
4645
* @param searchMappings Empty array to populate with search cursor mappings
4746
* @param vsimMappings Empty array to populate with vector similarity cursor mappings
4847
* @param status QueryError pointer to store warning/error information
4948
* @param oomPolicy OOM policy to determine error handling behavior
5049
* @return true if processing completed (even with warnings), false on fatal errors; status will contain error/warning information
5150
*/
52-
bool ProcessHybridCursorMappings(const MRCommand *cmd,int numShards, StrongRef searchMappings, StrongRef vsimMappings, QueryError *status, RSOomPolicy oomPolicy);
51+
bool ProcessHybridCursorMappings(const MRCommand *cmd, StrongRef searchMappings, StrongRef vsimMappings, QueryError *status, RSOomPolicy oomPolicy);
5352

5453
/**
5554
* Release resources associated with a cursor mapping

0 commit comments

Comments
 (0)