[None][fix] Release deferred ctx KV pages in V2 delay batching#13805

Merged
lancelly merged 1 commit into
NVIDIA:mainfrom
lancelly:fix/v2-revert-ctx-on-delay-batching
May 12, 2026

@lancelly
Collaborator

@lancelly lancelly commented May 6, 2026

When delay batching (_balance_adp_requests / _waiting_requests) defers context requests in the V2 scheduler path, those requests have already had their KV cache capacity grown by resize_context during scheduling. Without a revert, those freshly allocated pages sit idle for the entire wait window, blocking the rest of the pool — a real cost for long-context workloads where one deferred ctx can hold GBs of KV.

Adds revert_allocate_context, mirroring the existing revert_allocate_generation: resize_context snapshots the pre-resize cap on req.py_ctx_pre_resize_cap (None when no growth), and the executor calls revert_allocate_context for any ctx dropped by delay batching at the end of _schedule(). The shrink path on _KVCache.resize is the same one already exercised by spec-decoding rewind, so behavior and perf are known-good.
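
The snapshot-and-revert idea described above can be sketched in isolation. This is a minimal toy model, not the TensorRT-LLM implementation: `ToyPool`, `ToyKVCache`, and `ToyRequest` are hypothetical stand-ins, capacity is counted in tokens rather than pages, and only the names `resize_context`, `revert_allocate_context`, and `py_ctx_pre_resize_cap` come from this PR.

```python
from dataclasses import dataclass
from typing import Optional


class ToyPool:
    """Hypothetical shared pool; counts free capacity in tokens."""

    def __init__(self, free_tokens: int) -> None:
        self.free_tokens = free_tokens


@dataclass
class ToyKVCache:
    """Hypothetical per-request cache that draws from a ToyPool."""
    pool: ToyPool
    capacity: int = 0

    def resize(self, new_capacity: int) -> bool:
        delta = new_capacity - self.capacity
        if delta > self.pool.free_tokens:
            return False  # growth does not fit in the pool
        self.pool.free_tokens -= delta  # a negative delta gives capacity back
        self.capacity = new_capacity
        return True


@dataclass
class ToyRequest:
    kv_cache: ToyKVCache
    py_ctx_pre_resize_cap: Optional[int] = None


def resize_context(req: ToyRequest, target: int) -> bool:
    """Grow the cache and remember the pre-resize capacity if it grew."""
    pre_cap = req.kv_cache.capacity
    capacity = max(pre_cap, target)
    if not req.kv_cache.resize(capacity):
        return False
    # None means "no growth this iteration, nothing to revert".
    req.py_ctx_pre_resize_cap = pre_cap if capacity > pre_cap else None
    return True


def revert_allocate_context(req: ToyRequest) -> None:
    """Shrink back to the snapshot taken by resize_context, if any."""
    pre_cap = req.py_ctx_pre_resize_cap
    if pre_cap is None:
        return
    req.kv_cache.resize(pre_cap)
    req.py_ctx_pre_resize_cap = None


pool = ToyPool(free_tokens=1000)
req = ToyRequest(ToyKVCache(pool))
resize_context(req, 800)
print(pool.free_tokens)  # 200: the deferred request is holding 800 tokens
revert_allocate_context(req)
print(pool.free_tokens)  # 1000: the capacity is back in the pool
```

In this toy, the deferred request holds 800 of the 1000 tokens until the revert runs, which is the idle-capacity cost the PR removes for the wait window.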

Also rename _scheduler_manages_kv_suspend to _is_kv_manager_v2 since the flag is exactly an isinstance(KVCacheManagerV2) check and is now used by both the gen and ctx revert paths, beyond the original suspend path.

The added QA tests passed locally on GB200.

Alternative considered: gate delay batching before resume

A natural variant came up: do _create_kv_cache (cheap, no GPU work) for all candidates → run delay batching on the candidate set → resume + resize only the survivors. This would skip the resume cost on deferred reqs entirely. We discussed and rejected it, for two coupled reasons.

  1. Resume failure is part of admission. kv_cache.resume() can fail on the max_util_for_resume utilization gate or on batched_lock_to_gpu OOM. V2's per-req scheduling loop uses these per-req True/False returns to drive its token / block budget across reqs. Without per-req resume feedback, the budget loop cannot make consistent decisions for the next req.

  2. Delay batching's semantics couple compute and KV admission. Today, when delay batching "passes", the surviving set has already succeeded resume + resize, so KV admission is guaranteed. A pre-resume variant weakens this guarantee: the token-budget gate may pass 4 candidates while only 2 actually fit on KV (resume failures filter the rest). The iter still ends up running 2 reqs and the delay-batching decision gets effectively re-arbitrated by resume failures, defeating the optimization while losing the "decision == admission" property. Note that delay batching today has two tunables (waiting-iter count, current-batch token count), so it's already implicitly waiting on both enough tokens and enough KV slots; splitting these out into separate pre/post-resume phases would be a semantic change, not just a reordering.

So the current shape is the minimum-change form that preserves both invariants: prepare_context (resume) → resize_context (grow) → delay batching → revert resize on dropped, with conditional suspend() when pre_cap > 0. The wasted resume work on deferred reqs is bounded: near-zero for first-chunk fresh caches (empty _active_pages()), real but unavoidable for reuse / SSM / mid-chunk cases under the current admission coupling.
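
To make the "4 candidates pass, only 2 fit" argument concrete, here is a small sketch comparing a gate that only checks the token budget with a per-request loop that also consumes resume feedback. Everything here is hypothetical and simplified: `gate_before_resume`, `admit_with_resume_feedback`, and the `try_resume` callback are illustrative names, not functions from the scheduler.

```python
from typing import Callable, List, Tuple

Candidate = Tuple[str, int]  # (request name, context tokens)


def gate_before_resume(candidates: List[Candidate], token_budget: int) -> List[str]:
    """Pre-resume variant: only the token budget decides who passes."""
    passed = []
    for name, tokens in candidates:
        if tokens <= token_budget:
            token_budget -= tokens
            passed.append(name)
    return passed


def admit_with_resume_feedback(
    candidates: List[Candidate],
    token_budget: int,
    try_resume: Callable[[str], bool],
) -> List[str]:
    """Per-request loop: budget is consumed only when resume succeeds."""
    admitted = []
    for name, tokens in candidates:
        if tokens > token_budget:
            continue
        if not try_resume(name):
            continue  # resume failed, so the budget stays free for later reqs
        token_budget -= tokens
        admitted.append(name)
    return admitted


candidates = [("a", 100), ("b", 100), ("c", 100), ("d", 100)]
fits_on_kv = {"a", "c"}  # assumption: only these two can be resumed

print(gate_before_resume(candidates, 400))
# ['a', 'b', 'c', 'd']
print(admit_with_resume_feedback(candidates, 400, lambda n: n in fits_on_kv))
# ['a', 'c']
```

A delay-batching threshold of, say, 300 tokens would be satisfied by the first result (400 tokens) even though the iteration would actually run with 200, which is the re-arbitration problem described in point 2.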

@lancelly lancelly marked this pull request as ready for review May 7, 2026 06:45
@lancelly lancelly requested review from a team as code owners May 7, 2026 06:45
@lancelly lancelly requested a review from joyang-nv May 7, 2026 06:45
@lancelly
Collaborator Author

lancelly commented May 7, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47135 [ run ] triggered by Bot. Commit: 6b7a083 Link to invocation

@coderabbitai
Contributor

coderabbitai Bot commented May 7, 2026

Review Change Stack

📝 Walkthrough


PyExecutor now uses the _is_kv_manager_v2 flag to conditionally manage KV-cache capacity growth for deferred context requests. KVCacheManagerV2 adds snapshot-and-revert capability via revert_allocate_context. The scheduler detects context requests dropped during delay batching and reverts their allocations when V2 is in use. Tests are parameterized across both manager versions.

Changes

KV-Cache Context Revert for V2 Manager

| Layer / File(s) | Summary |
| --- | --- |
| KV-Cache Snapshot & Revert Methods (tensorrt_llm/_torch/pyexecutor/resource_manager.py) | KVCacheManagerV2 adds revert_allocate_context method to restore context request capacity from pre-resize snapshot (req.py_ctx_pre_resize_cap). resize_context captures pre-resize capacity when growth occurs and clears it otherwise. |
| KV-Manager-Version Detection (tensorrt_llm/_torch/pyexecutor/py_executor.py) | PyExecutor replaces _scheduler_manages_kv_suspend with _is_kv_manager_v2 boolean, computed from isinstance(kv_cache_manager, KVCacheManagerV2) at initialization. |
| Conditional Gating & Revert Helper (tensorrt_llm/_torch/pyexecutor/py_executor.py) | PyExecutor introduces _revert_ctx_alloc helper and gates _revert_gen_alloc on _is_kv_manager_v2. Executor-loop conditionals for pause/termination behavior are updated to use the new flag. |
| Scheduler Context Request Revert Logic (tensorrt_llm/_torch/pyexecutor/py_executor.py) | In _schedule, preserved original_ctx_requests baseline is used to detect requests dropped by delay-batching. When V2 manager is in use, _revert_ctx_alloc is called for dropped requests before constructing ScheduledRequests. |
| Test Parameterization & Coverage (tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/qa/llm_function_core.txt, tests/integration/test_lists/qa/llm_function_rtx6k.txt) | test_nvfp4_batch_waiting gains v2_kv_cache parameter. KV-cache config passes parameter to use_kv_cache_manager_v2. Test lists enumerate both V2 and non-V2 configurations. |
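
The dropped-request detection summarized for _schedule can be pictured as a diff between the pre-delay-batching baseline and the surviving set. The following is a rough sketch under assumptions: `find_dropped` and `revert_dropped_ctx` are hypothetical helper names, requests are compared by object identity, and the real executor code may be structured differently.

```python
from typing import Callable, List


def find_dropped(original_ctx_requests: List[object],
                 surviving_ctx_requests: List[object]) -> List[object]:
    """Return baseline requests that delay batching removed, in order."""
    surviving_ids = {id(req) for req in surviving_ctx_requests}
    return [req for req in original_ctx_requests if id(req) not in surviving_ids]


def revert_dropped_ctx(original_ctx_requests: List[object],
                       surviving_ctx_requests: List[object],
                       is_kv_manager_v2: bool,
                       revert_fn: Callable[[object], None]) -> List[object]:
    """Call revert_fn on every dropped request, only on the V2 path."""
    if not is_kv_manager_v2:
        return []
    dropped = find_dropped(original_ctx_requests, surviving_ctx_requests)
    for req in dropped:
        revert_fn(req)
    return dropped


# Usage: three scheduled context requests, delay batching keeps only the second.
a, b, c = object(), object(), object()
reverted: List[object] = []
dropped = revert_dropped_ctx([a, b, c], [b], True, reverted.append)
print(len(dropped))  # 2
```

Comparing by identity avoids requiring the request objects to be hashable or to define equality, which seemed the safest assumption for a sketch.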

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: fixing deferred context KV page release in V2 delay batching. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description provides detailed context about the issue, solution, and design rationale, addressing the KV cache capacity leak in delay batching. |


Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

2351-2381: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clear stale py_ctx_pre_resize_cap before attempting resize.

If kv_cache.resize(...) fails (Line 2373), req.py_ctx_pre_resize_cap is left untouched. A prior iteration’s snapshot can then be incorrectly consumed by revert_allocate_context, causing an outdated shrink target.

💡 Suggested fix
 def resize_context(self, req: LlmRequest, num_tokens: int) -> bool:
@@
         target = req.context_current_position + num_tokens + self.num_extra_kv_tokens
         capacity = max(kv_cache.capacity, target)
         pre_cap = kv_cache.capacity
+        # Clear stale snapshot from prior iterations before resize attempt.
+        req.py_ctx_pre_resize_cap = None
 
         if not kv_cache.resize(capacity):
             if req.is_first_context_chunk:
                 kv_cache.suspend()
             return False
 
-        # None means "no growth this iter, nothing to revert"; this also
-        # invalidates a stale snapshot from a prior iter on the same req.
-        req.py_ctx_pre_resize_cap = pre_cap if capacity > pre_cap else None
+        if capacity > pre_cap:
+            req.py_ctx_pre_resize_cap = pre_cap
         return True
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py` around lines 2351 - 2381,
In resize_context, clear any stale snapshot before attempting the resize by
setting req.py_ctx_pre_resize_cap = None prior to calling kv_cache.resize(...);
then compute pre_cap and on successful growth (capacity > pre_cap) set
req.py_ctx_pre_resize_cap = pre_cap, otherwise leave it None; keep the existing
kv_cache.suspend() behavior when kv_cache.resize(...) fails for
req.is_first_context_chunk and return False as before so revert_allocate_context
cannot consume an outdated snapshot.
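
The hazard this finding describes can be reproduced on a toy model. This is only an illustration of the reviewer's concern, not a confirmed failure in the real code path: `ToyCache`, `ToyReq`, the `limit` field, and the `clear_first` switch are all hypothetical, and whether the executor would ever call the revert on a request whose resize failed is not established here.

```python
from typing import Optional


class ToyCache:
    """Hypothetical cache whose resize fails above a fixed limit."""

    def __init__(self, capacity: int, limit: int) -> None:
        self.capacity = capacity
        self.limit = limit  # largest capacity the toy pool can grant

    def resize(self, new_capacity: int) -> bool:
        if new_capacity > self.limit:
            return False
        self.capacity = new_capacity
        return True


class ToyReq:
    def __init__(self, cache: ToyCache) -> None:
        self.kv_cache = cache
        self.py_ctx_pre_resize_cap: Optional[int] = None


def resize_context(req: ToyReq, target: int, clear_first: bool) -> bool:
    cache = req.kv_cache
    pre_cap = cache.capacity
    capacity = max(pre_cap, target)
    if clear_first:
        # Suggested fix: drop any snapshot left by a prior iteration.
        req.py_ctx_pre_resize_cap = None
    if not cache.resize(capacity):
        return False  # without clear_first, an old snapshot survives here
    req.py_ctx_pre_resize_cap = pre_cap if capacity > pre_cap else None
    return True


def revert_allocate_context(req: ToyReq) -> None:
    if req.py_ctx_pre_resize_cap is not None:
        req.kv_cache.resize(req.py_ctx_pre_resize_cap)
        req.py_ctx_pre_resize_cap = None


def capacity_after_failed_second_resize(clear_first: bool) -> int:
    req = ToyReq(ToyCache(capacity=0, limit=100))
    resize_context(req, 100, clear_first)  # iter 1: grows 0 -> 100, snapshot 0
    resize_context(req, 200, clear_first)  # iter 2: fails, capacity stays 100
    revert_allocate_context(req)           # consumes whatever snapshot is left
    return req.kv_cache.capacity


print(capacity_after_failed_second_resize(clear_first=False))  # 0
print(capacity_after_failed_second_resize(clear_first=True))   # 100
```

Without the early clear, the revert shrinks the toy cache to the stale snapshot (0) and discards the first chunk's capacity; with it, the revert is a no-op and the capacity stays at 100.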
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1566-1566: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Also revert context growth when can_queue skips the batch.

_schedule() now unwinds ctx growth for requests dropped by delay batching, but the later skip paths still only call _revert_gen_alloc(). On KV-manager V2, resize_context has already grown scheduled_batch.context_requests, so an attention-DP skip / post-connector recheck / PP queue skip can still leave those ctx pages pinned until a future iteration. That recreates the same pool-pressure leak for skipped batches.

Suggested direction
-    def _revert_gen_alloc(self, scheduled_batch):
-        """Revert KV cache capacity growth when the batch is skipped.
+    def _revert_skipped_batch_alloc(self, scheduled_batch):
+        """Revert KV cache capacity growth when a scheduled batch is skipped.
 
-        With attention DP, can_queue=False means another rank has an empty
-        batch so no forward pass will run.  The V2 scheduler already grew
-        each generation request's KV cache capacity during scheduling;
-        revert that growth so it does not accumulate across skipped
-        iterations and overflow the host page-index buffer.
+        With KV cache manager V2, scheduling may already have grown both
+        context and generation capacity before we discover that the batch
+        cannot run on this iteration.
         """
         if self._is_kv_manager_v2:
+            for req in scheduled_batch.context_requests:
+                self.kv_cache_manager.revert_allocate_context(req)
             for req in scheduled_batch.generation_requests:
                 self.kv_cache_manager.revert_allocate_generation(req)

Then switch the not can_queue branches in PP / overlap / non-overlap to call the new helper.

Also applies to: 1951-1983, 2257-2258, 2513-2514

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` at line 1566, When a batch is
skipped by a can_queue check, also undo any context growth (resize_context)
instead of only calling _revert_gen_alloc; replace the direct calls to
self._revert_gen_alloc(scheduled_batch) in _schedule and in the PP / overlap /
non-overlap "not can_queue" branches with the new helper that unrolls context
growth (the helper introduced alongside _revert_gen_alloc that also reverts
scheduled_batch.context_requests), so skipped batches do not leave ctx pages
pinned.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 216f91f5-f041-4231-9b0b-314f40cc5bd7

📥 Commits

Reviewing files that changed from the base of the PR and between 3e4a775 and 6b7a083.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt

@tensorrt-cicd
Collaborator

PR_Github #47135 [ run ] completed with state SUCCESS. Commit: 6b7a083
/LLM/main/L0_MergeRequest_PR pipeline #37099 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@lancelly lancelly force-pushed the fix/v2-revert-ctx-on-delay-batching branch from a750f62 to 4cce44f Compare May 8, 2026 08:19
@lancelly
Collaborator Author

lancelly commented May 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47363 [ run ] triggered by Bot. Commit: 4cce44f Link to invocation

@lancelly lancelly force-pushed the fix/v2-revert-ctx-on-delay-batching branch from 4cce44f to 150ddcb Compare May 8, 2026 08:35
@lancelly
Copy link
Copy Markdown
Collaborator Author

lancelly commented May 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47364 [ run ] triggered by Bot. Commit: 150ddcb Link to invocation

When delay batching (_balance_adp_requests / _waiting_requests) defers
context requests in the V2 scheduler path, those requests have already
had their KV cache capacity grown by resize_context during scheduling.
Without a revert, those freshly allocated pages sit idle for the entire
wait window, blocking the rest of the pool — a real cost for long-context
workloads where one deferred ctx can hold GBs of KV.

Adds revert_allocate_context, mirroring the existing
revert_allocate_generation: resize_context snapshots the pre-resize cap
on req.py_ctx_pre_resize_cap (None when no growth), and the executor
calls revert_allocate_context for any ctx dropped by delay batching at
the end of _schedule(). The shrink path on _KVCache.resize is the same
one already exercised by spec-decoding rewind, so behavior and perf are
known-good.

For non-first-chunk defers (pre_cap > 0), the revert additionally
suspends the cache so V2's pool-pressure machinery can evict the prior
chunks' pages if needed; the next iter's _resume_and_restore re-enters
through the _never_resumed=False short path, with batched_lock_to_gpu
re-locking _active_pages() only if eviction actually happened. First-
chunk defers (pre_cap == 0) skip the suspend since resize already
emptied _blocks.
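
The first-chunk versus non-first-chunk distinction in the paragraph above can be sketched as a single branch on the snapshot value. This is a simplified illustration under assumptions: `ToySuspendableCache` and `revert_with_conditional_suspend` are hypothetical names, and eviction and the resume path are not modeled.

```python
from typing import Optional


class ToySuspendableCache:
    """Hypothetical cache that records whether it has been suspended."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.suspended = False

    def resize(self, new_capacity: int) -> bool:
        self.capacity = new_capacity
        return True

    def suspend(self) -> None:
        self.suspended = True


def revert_with_conditional_suspend(cache: ToySuspendableCache,
                                    pre_cap: Optional[int]) -> None:
    """Shrink to pre_cap; suspend only if earlier chunks still hold capacity."""
    if pre_cap is None:
        return  # no growth this iteration, nothing to revert
    cache.resize(pre_cap)
    if pre_cap > 0:
        # Non-first-chunk defer: prior chunks remain, so allow eviction.
        cache.suspend()
    # pre_cap == 0: the shrink already emptied the cache, so no suspend.


first_chunk = ToySuspendableCache(capacity=512)
revert_with_conditional_suspend(first_chunk, pre_cap=0)
print(first_chunk.capacity, first_chunk.suspended)  # 0 False

mid_chunk = ToySuspendableCache(capacity=1024)
revert_with_conditional_suspend(mid_chunk, pre_cap=512)
print(mid_chunk.capacity, mid_chunk.suspended)  # 512 True
```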

Also rename _scheduler_manages_kv_suspend to _is_kv_manager_v2 since the
flag is exactly an isinstance(KVCacheManagerV2) check and is now used by
both the gen and ctx revert paths, beyond the original suspend path.

Tests:
- Parametrize test_nvfp4_batch_waiting with v2_kv_cache=[False, True] so
  the V2 path runs in CI; update qa/llm_function_core.txt and
  qa/llm_function_rtx6k.txt with the new -v2_kv_cache= test IDs.
- Add the V2 batch_waiting case to the B200 pre-merge stage.

Signed-off-by: Lanyu Liao <[email protected]>
@lancelly lancelly force-pushed the fix/v2-revert-ctx-on-delay-batching branch from 150ddcb to 1411498 Compare May 8, 2026 09:57
@lancelly
Collaborator Author

lancelly commented May 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47380 [ run ] triggered by Bot. Commit: 1411498 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47380 [ run ] completed with state SUCCESS. Commit: 1411498
/LLM/main/L0_MergeRequest_PR pipeline #37311 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@lancelly
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47643 [ run ] triggered by Bot. Commit: 1411498 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47643 [ run ] completed with state SUCCESS. Commit: 1411498
/LLM/main/L0_MergeRequest_PR pipeline #37548 completed with status: 'SUCCESS'

CI Report

Link to invocation

@lancelly lancelly requested a review from yizhang-nv May 11, 2026 07:35
@lancelly lancelly merged commit 7bc328f into NVIDIA:main May 12, 2026
6 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026

4 participants