Skip to content

Commit 580c7a6

Browse files
authored
[TRTLLM-11421][feat] Support better kv cache statistics monitoring (#12413)
Signed-off-by: Yueh-Ting Chen <[email protected]>
1 parent 4811704 commit 580c7a6

21 files changed

Lines changed: 2009 additions & 45 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,42 @@ struct KvCacheStats
259259
std::size_t allocatedBytes{};
260260
};
261261

262+
/// @brief Per-iteration KV cache statistics. All delta counters represent changes since the last call to
263+
/// getIterationStats(). Gauges are instantaneous snapshots.
264+
struct KvCacheIterationStats
265+
{
266+
// --- Instantaneous gauges ---
267+
// Primary (GPU) pool
268+
SizeType32 primaryMaxNumBlocks{0};
269+
SizeType32 primaryFreeNumBlocks{0};
270+
SizeType32 primaryUsedNumBlocks{0};
271+
// Secondary (host) pool
272+
SizeType32 secondaryMaxNumBlocks{0};
273+
SizeType32 secondaryFreeNumBlocks{0};
274+
SizeType32 secondaryUsedNumBlocks{0};
275+
276+
// --- Per-iteration deltas (reset on each read) ---
277+
// Context phase: block allocation and reuse
278+
SizeType32 iterAllocTotalBlocks{0};
279+
SizeType32 iterAllocNewBlocks{0};
280+
SizeType32 iterReusedBlocks{0}; // = iterFullReusedBlocks + iterPartialReusedBlocks
281+
SizeType32 iterFullReusedBlocks{0}; // blocks fully matched in radix tree
282+
SizeType32 iterPartialReusedBlocks{0}; // blocks partially matched in radix tree
283+
SizeType32 iterMissedBlocks{0};
284+
float iterCacheHitRate{0.0f};
285+
// Generation phase: block allocation
286+
SizeType32 iterGenAllocBlocks{0};
287+
288+
// Transfer traffic deltas — host ↔ GPU
289+
SizeType32 iterOnboardBlocks{0};
290+
std::size_t iterOnboardBytes{0};
291+
SizeType32 iterOffloadBlocks{0};
292+
std::size_t iterOffloadBytes{0};
293+
// Intra-device (GPU → GPU) block copies (e.g. partial reuse when source block has refs)
294+
SizeType32 iterIntraDeviceCopyBlocks{0};
295+
std::size_t iterIntraDeviceCopyBytes{0};
296+
};
297+
262298
// Basic building block of a paged KV cache - a single
263299
// cache block. This class just holds metadata, no pointers
264300
// since it is reused across all layers.
@@ -815,6 +851,12 @@ class WindowBlockManager
815851
return mMissedBlocks;
816852
}
817853

854+
// Get num free blocks in the secondary (host) memory pool
855+
[[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const noexcept;
856+
857+
//! \brief Get iteration stats (deltas since last call) for this window. Resets internal delta snapshots.
858+
[[nodiscard]] KvCacheIterationStats getAndResetIterationStats();
859+
818860
[[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const
819861
{
820862
return getNumFreeBlocks() >= numRequired;
@@ -1128,16 +1170,22 @@ class WindowBlockManager
11281170
std::shared_ptr<KVCacheTransferManager> mTransferManager;
11291171

11301172
// Statistics for block allocations/reuse
1131-
// Total number of blocks allocated by all requests
1173+
// Total number of blocks allocated by all requests (context phase)
11321174
SizeType32 mAllocTotalBlocks;
1133-
// Number of new blocks that were allocated
1175+
// Number of new blocks that were allocated (context phase)
11341176
SizeType32 mAllocNewBlocks;
1135-
// Number of blocks that were reused
1177+
// Number of blocks that were fully reused (context phase)
1178+
SizeType32 mFullReusedBlocks;
1179+
// Number of blocks that were partially reused (context phase)
1180+
SizeType32 mPartialReusedBlocks;
1181+
// Number of blocks that were reused (full + partial, context phase)
11361182
SizeType32 mReusedBlocks;
11371183
// Number of unique blocks that were reused
11381184
SizeType32 mReusedUniqueBlocks;
1139-
// Number of blocks that were not reused
1185+
// Number of blocks that were not reused (context phase)
11401186
SizeType32 mMissedBlocks;
1187+
// Number of blocks allocated during generation phase
1188+
SizeType32 mGenAllocBlocks;
11411189
// Can only be 1 or 2. If 2: general KV stored. If 1: K == V for any token, so only K is stored to optimize the
11421190
// max_num_tokens(For DeepSeek). Controlled by mCacheType
11431191
SizeType32 mKVFactor;
@@ -1154,6 +1202,15 @@ class WindowBlockManager
11541202
// The kv cache connector manager
11551203
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
11561204

1205+
// Snapshot of cumulative counters at last iteration stats read (for delta computation)
1206+
SizeType32 mPrevAllocTotalBlocks{0};
1207+
SizeType32 mPrevAllocNewBlocks{0};
1208+
SizeType32 mPrevReusedBlocks{0};
1209+
SizeType32 mPrevFullReusedBlocks{0};
1210+
SizeType32 mPrevPartialReusedBlocks{0};
1211+
SizeType32 mPrevMissedBlocks{0};
1212+
SizeType32 mPrevGenAllocBlocks{0};
1213+
11571214
// Mutex for the cached blocks root
11581215
mutable std::mutex mCachedBlocksRootMutex;
11591216

@@ -1359,6 +1416,19 @@ class BlockManager
13591416
return sumWindows([](auto const& manager) { return manager.getNumMissedBlocks(); });
13601417
}
13611418

1419+
//! \brief Secondary-pool block count, combined over the per-window managers through sumWindows().
[[nodiscard]] SizeType32 getNumSecondaryBlocks() const
{
    auto const secondaryBlocksOf = [](auto const& windowManager) { return windowManager.getNumSecondaryBlocks(); };
    return sumWindows(secondaryBlocksOf);
}
1423+
1424+
//! \brief Free secondary-pool block count, combined over the per-window managers through sumWindows().
[[nodiscard]] SizeType32 getNumFreeSecondaryBlocks() const
{
    auto const freeSecondaryBlocksOf
        = [](auto const& windowManager) { return windowManager.getNumFreeSecondaryBlocks(); };
    return sumWindows(freeSecondaryBlocksOf);
}
1428+
1429+
//! \brief Get per-window-size iteration stats. Resets delta snapshots for each window.
1430+
[[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getAndResetIterationStats();
1431+
13621432
[[nodiscard]] SizeType32 getNumLayers() const
13631433
{
13641434
return mNumLayers;
@@ -1688,6 +1758,10 @@ class BaseKVCacheManager
16881758

16891759
[[nodiscard]] virtual KvCacheStats getKvCacheStats() const = 0;
16901760

1761+
//! \brief Get per-iteration stats with delta counters, keyed by window size.
1762+
//! Resets delta snapshots on each call.
1763+
[[nodiscard]] virtual std::map<SizeType32, KvCacheIterationStats> getIterationStats() = 0;
1764+
16911765
[[nodiscard]] virtual OffsetTableDimensions getOffsetTableDimensions() const = 0;
16921766

16931767
[[nodiscard]] virtual std::deque<executor::KVCacheEvent> getLatestEvents(
@@ -2046,6 +2120,11 @@ class KVCacheManager : public BaseKVCacheManager
20462120
return kvCacheStats;
20472121
}
20482122

2123+
[[nodiscard]] std::map<SizeType32, KvCacheIterationStats> getIterationStats() override
2124+
{
2125+
return mBlockManager.getAndResetIterationStats();
2126+
}
2127+
20492128
[[nodiscard]] OffsetTableDimensions getOffsetTableDimensions() const override
20502129
{
20512130
OffsetTableDimensions dims;

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ namespace kvc = tensorrt_llm::executor::kv_cache;
2727
namespace tensorrt_llm::batch_manager::kv_cache_manager
2828
{
2929

30+
/// @brief Statistics for block transfers. Returned by KVCacheTransferManager::getAndResetTransferStats().
31+
/// All counters are reset on read.
32+
/// - onboard/offload: transfers between secondary (host) and primary (GPU) memory.
33+
/// - intraDeviceCopy: GPU-to-GPU block copies (e.g. partial reuse when source block has refs).
34+
struct KvCacheTransferStats
35+
{
36+
SizeType32 onboardBlocks{0};
37+
std::size_t onboardBytes{0};
38+
SizeType32 offloadBlocks{0};
39+
std::size_t offloadBytes{0};
40+
SizeType32 intraDeviceCopyBlocks{0};
41+
std::size_t intraDeviceCopyBytes{0};
42+
};
43+
3044
// The TransferManager accelerates transfers to/from the GPU by overlapping HtoD and DtoH transfers, and tracks ongoing
3145
// transfers in order to avoid race conditions. It is functionally equivalent to the prior approach of putting all
3246
// transfers into the forward pass stream. This is only ever used as a component of a KVCacheManager.
@@ -57,6 +71,9 @@ class KVCacheTransferManager
5771
//! must be called after last call to KVCacheManager::addSequence in every step.
5872
void syncTransfers();
5973

74+
//! \brief Get transfer stats accumulated since last call, and reset the counters.
75+
[[nodiscard]] KvCacheTransferStats getAndResetTransferStats();
76+
6077
private:
6178
//! \brief Get pointer to pool specified by cache block.
6279
static tr::ITensor::SharedPtr computeBlockPointer(
@@ -79,6 +96,12 @@ class KVCacheTransferManager
7996
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
8097
std::string const& directory = "");
8198

99+
//! \brief Compute total bytes actually transferred for a block copy across all pools.
100+
//! \param pools The pool descriptors.
101+
//! \param numTokensToCopy Number of tokens for partial copy (0 means full block).
102+
[[nodiscard]] std::size_t computeBlockTransferBytes(
103+
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy) const;
104+
82105
runtime::BufferManager mBufferManager;
83106
runtime::BufferManager mOnboardManager;
84107
runtime::BufferManager mOffloadManager;
@@ -90,6 +113,16 @@ class KVCacheTransferManager
90113
// Reference to parent loopback agent
91114
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
92115
int mDeviceId;
116+
117+
// Cumulative transfer statistics, reset on each call to getAndResetTransferStats().
118+
// Protected by mStatsMutex for thread-safe access.
119+
mutable std::mutex mStatsMutex;
120+
SizeType32 mOnboardBlockCount{0};
121+
std::size_t mOnboardByteCount{0};
122+
SizeType32 mOffloadBlockCount{0};
123+
std::size_t mOffloadByteCount{0};
124+
SizeType32 mIntraDeviceCopyBlockCount{0};
125+
std::size_t mIntraDeviceCopyByteCount{0};
93126
};
94127

95128
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,9 +760,12 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
760760
, mTransferManager{std::make_shared<KVCacheTransferManager>(mBufferManager, mLoopbackAgent)}
761761
, mAllocTotalBlocks{0}
762762
, mAllocNewBlocks{0}
763+
, mFullReusedBlocks{0}
764+
, mPartialReusedBlocks{0}
763765
, mReusedBlocks{0}
764766
, mReusedUniqueBlocks{0}
765767
, mMissedBlocks{0}
768+
, mGenAllocBlocks{0}
766769
, mKVFactor{(mCacheType == CacheType::kSELFKONLY
767770
|| (linearAttentionMetadata.has_value() && linearAttentionMetadata->hasRecurrentStatesCache()))
768771
? 1
@@ -1518,6 +1521,14 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
15181521
reusedBlockIds.insert(matchingBlockId);
15191522
++mReusedUniqueBlocks;
15201523
}
1524+
if (partialMatch)
1525+
{
1526+
++mPartialReusedBlocks;
1527+
}
1528+
else
1529+
{
1530+
++mFullReusedBlocks;
1531+
}
15211532
}
15221533
++blockItr;
15231534
}
@@ -1726,6 +1737,7 @@ void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence)
17261737
{
17271738
// Allocating a new block when the last token is a block boundary
17281739
allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1);
1740+
++mGenAllocBlocks;
17291741
updateLastCacheBlockOffsets(sequence);
17301742
}
17311743
}
@@ -2226,6 +2238,73 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence)
22262238
return numFree;
22272239
}
22282240

2241+
// Free block count for the secondary level, as reported by the eviction policy.
[[nodiscard]] SizeType32 WindowBlockManager::getNumFreeSecondaryBlocks() const noexcept
{
    auto const& evictionPolicy = *mEvictionPolicy;
    return evictionPolicy.getNumFreeBlocks(kSecondaryLevel);
}
2245+
2246+
//! \brief Assemble the per-iteration statistics for this window and advance the delta baselines.
//!
//! Runs under mCachedBlocksRootMutex. Gauges are read from the current pool state. Each delta is the
//! cumulative counter minus the value remembered from the previous call; the remembered value is then
//! moved up to the current counter. Transfer counters come from the transfer manager's get-and-reset call.
//! \return Gauges plus the deltas accumulated since the previous call.
KvCacheIterationStats WindowBlockManager::getAndResetIterationStats()
{
    std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex);

    // Yields (cumulative - baseline) and then moves the baseline up to the cumulative value.
    auto const consumeDelta = [](SizeType32 const cumulative, SizeType32& baseline) -> SizeType32
    {
        SizeType32 const delta = cumulative - baseline;
        baseline = cumulative;
        return delta;
    };

    KvCacheIterationStats result;

    // Gauges: current pool occupancy; "used" is derived as max minus free.
    result.primaryMaxNumBlocks = getNumPrimaryBlocks();
    result.primaryFreeNumBlocks = getNumFreeBlocks();
    result.primaryUsedNumBlocks = result.primaryMaxNumBlocks - result.primaryFreeNumBlocks;
    result.secondaryMaxNumBlocks = getNumSecondaryBlocks();
    result.secondaryFreeNumBlocks = getNumFreeSecondaryBlocks();
    result.secondaryUsedNumBlocks = result.secondaryMaxNumBlocks - result.secondaryFreeNumBlocks;

    // Context-phase deltas
    result.iterAllocTotalBlocks = consumeDelta(mAllocTotalBlocks, mPrevAllocTotalBlocks);
    result.iterAllocNewBlocks = consumeDelta(mAllocNewBlocks, mPrevAllocNewBlocks);
    result.iterReusedBlocks = consumeDelta(mReusedBlocks, mPrevReusedBlocks);
    result.iterFullReusedBlocks = consumeDelta(mFullReusedBlocks, mPrevFullReusedBlocks);
    result.iterPartialReusedBlocks = consumeDelta(mPartialReusedBlocks, mPrevPartialReusedBlocks);
    result.iterMissedBlocks = consumeDelta(mMissedBlocks, mPrevMissedBlocks);

    // Hit rate = reused / (reused + missed); stays at its 0.0f default when nothing was looked up.
    auto const lookedUpBlocks = result.iterReusedBlocks + result.iterMissedBlocks;
    if (lookedUpBlocks != 0)
    {
        result.iterCacheHitRate
            = static_cast<float>(result.iterReusedBlocks) / static_cast<float>(lookedUpBlocks);
    }

    // Generation-phase delta
    result.iterGenAllocBlocks = consumeDelta(mGenAllocBlocks, mPrevGenAllocBlocks);

    // Transfer counters are owned by the transfer manager; skipped if no manager is attached.
    if (mTransferManager != nullptr)
    {
        auto const transfers = mTransferManager->getAndResetTransferStats();
        result.iterOnboardBlocks = transfers.onboardBlocks;
        result.iterOnboardBytes = transfers.onboardBytes;
        result.iterOffloadBlocks = transfers.offloadBlocks;
        result.iterOffloadBytes = transfers.offloadBytes;
        result.iterIntraDeviceCopyBlocks = transfers.intraDeviceCopyBlocks;
        result.iterIntraDeviceCopyBytes = transfers.intraDeviceCopyBytes;
    }

    return result;
}
2297+
2298+
std::map<SizeType32, KvCacheIterationStats> BlockManager::getAndResetIterationStats()
2299+
{
2300+
std::map<SizeType32, KvCacheIterationStats> perWindowStats;
2301+
for (auto& [windowSize, manager] : mWindowBlockManagers)
2302+
{
2303+
perWindowStats[windowSize] = manager.getAndResetIterationStats();
2304+
}
2305+
return perWindowStats;
2306+
}
2307+
22292308
std::deque<tle::KVCacheEvent> BlockManager::getLatestEvents(std::optional<std::chrono::milliseconds> timeout) const
22302309
{
22312310
return mEventManager ? mEventManager->getEvents(timeout) : std::deque<tle::KVCacheEvent>{};

0 commit comments

Comments
 (0)