Skip to content

[None][feat] Add hit-rate gate and fair-share cap to KV-aware ADP router#13198

Merged
lancelly merged 1 commit into
NVIDIA:mainfrom
lancelly:kv_aware_adp_router_tune
Apr 29, 2026
Merged

[None][feat] Add hit-rate gate and fair-share cap to KV-aware ADP router#13198
lancelly merged 1 commit into
NVIDIA:mainfrom
lancelly:kv_aware_adp_router_tune

Conversation

@lancelly
Copy link
Copy Markdown
Collaborator

@lancelly lancelly commented Apr 20, 2026

Description

Tunes KVCacheAwareADPRouter so cache affinity and load balance work together instead of fighting each other. Three behaviour changes, one config surface change, and a small amount of cleanup.

1. Hit-rate gate — kv_cache_routing_match_rate_threshold (default 0.1)

For each request, match_len contributes to scoring only when max(match_len) / request_tokens across eligible ranks is strictly above the threshold; below it, match_len is forced to 0 and routing is driven purely by load. This prevents a small universal prefix (e.g. a shared system prompt) from pinning all traffic to the first warm ranks. Set to 0.0 to honour any nonzero match.

2. Fair-share cap — kv_cache_routing_fair_share_multiplier (default 2.0)

Per-rank active-request cap expressed as a multiplier of the ceil fair-share, i.e. fair_share_multiplier * ceil((total_active + new) / tp_size). Once a rank hits the cap within a scheduling batch it is dropped from the eligible set for the rest of that batch. 2.0 leaves enough slack for affinity to win in the common case while preventing runaway concentration; set to 1.0 for strict fair share.

3. Transfer-in-progress load accounting

Requests mid-KV-transfer to GEN are no longer visible in active_requests, so the router used to under-count load on the rank that is sending. The router now pulls requests_in_transfer() from the PyExecutor's AsyncTransferManager and folds those requests into both num_active_requests and num_active_tokens when building each rank's RankState.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added two configuration parameters to control KV cache-aware routing behavior: kv_cache_routing_match_rate_threshold and kv_cache_routing_fair_share_multiplier.
  • Improvements

    • Optimized request routing algorithm to better utilize KV cache affinity while maintaining balanced load distribution across distributed ranks.
    • Improved handling of asynchronous request transfers during routing initialization.

lancelly pushed a commit to lancelly/TensorRT-LLM that referenced this pull request Apr 21, 2026
…ounting to KV-aware ADP router

Three env-gated additions to the KV-cache-aware ADP router that address
cold-start load imbalance without touching existing code paths.  All
default off; behaviour is unchanged unless env vars are set.

1. TLLM_ADP_ROUTER_MATCH_RATE_THRESHOLD (float, default 0.0):
   Gate cache affinity in scoring when the best available hit rate
   (max match_len / req_tokens across eligible ranks) is at or below the
   threshold.  With threshold=0.10, first turns of new conversations
   (where match_len is only the shared thinking-template prefix) fall
   through to pure load-balanced routing and seed cold ranks with fresh
   trajectories instead of piling on ranks that happened to cache the
   template first.  Subsequent multi-turn requests still honour cache
   affinity because their hit rate far exceeds the threshold.

2. TLLM_ADP_ROUTER_RANDOMIZE_TIEBREAK (0/1, default 0):
   Iterate eligible ranks in a per-decision random order so score/
   active_tokens ties are resolved by uniform random pick instead of the
   default "lowest-index wins".  Seeded deterministically with req_id so
   every TP rank produces the same shuffle and routing decisions stay
   consistent across ranks.  Primarily helps gate_off events during cold
   start where multiple ranks have zero load.

3. TLLM_ADP_ROUTER_INCLUDE_TRANSFER_LOAD (0/1, default 0):
   Include KV-transfer-in-progress requests (held by
   AsyncTransferManager after prefill completes) in per-rank load
   accounting.  Without this, disaggregated CTX workers appear idle
   between prefills and the router concentrates subsequent requests on
   already-busy ranks.  PyExecutor now injects its async_transfer_manager
   into the router after both are constructed.

Also extends the diagnostic log already added in PR NVIDIA#13198:
- cache_affinity_active, max_match_for_req in adp_router_v2_decision
- match_rate_threshold, randomize_tiebreak in adp_router_v2_batch
- adp_router_v2_rank_state (per-rank state snapshot)
- adp_router_v2_pyexec_snapshot (pyexec-global state snapshot)

On DSV3.2 1P1D at concurrency 24 with threshold=0.10 + randomize=1, rank
coverage improves from baseline min/mean = 0.18 % (rank 6/7 serving 1
request each out of ~570 average) to 52 % (rank 6/7 serving 312 and 779
respectively), with TTFT p90 within run-to-run noise and cache hit rate
unchanged at 96 %.

Signed-off-by: Lance Liao <[email protected]>
@lancelly lancelly force-pushed the kv_aware_adp_router_tune branch from 7731e41 to 640bdae Compare April 22, 2026 10:18
@lancelly lancelly changed the title [None][fix] Kv aware adp router opitimization [None][feat] Add hit-rate gate and fair-share cap to KV-aware ADP router Apr 22, 2026
@lancelly lancelly force-pushed the kv_aware_adp_router_tune branch from 640bdae to 07df92b Compare April 22, 2026 10:40
@lancelly lancelly marked this pull request as ready for review April 22, 2026 10:41
@lancelly lancelly requested review from a team as code owners April 22, 2026 10:41
@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #44959 [ run ] triggered by Bot. Commit: 07df92b Link to invocation

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough

Walkthrough

The pull request enhances the ADP routing system with cache-affinity awareness and configurable routing thresholds. Changes include reordering PyExecutor initialization to pass the AsyncTransferManager to the ADPRouter, extending KVCacheAwareADPRouter with match-rate and fair-share controls, updating active-request computation to account for remaining-to-compute tokens and in-flight transfers, and implementing per-request rank shuffling with cache-affinity gating. New routing configuration parameters and corresponding tests are introduced.

Changes

Cohort / File(s) Summary
Configuration
tensorrt_llm/llmapi/llm_args.py
Added two new AttentionDpConfig parameters: kv_cache_routing_match_rate_threshold (default 0.1) and kv_cache_routing_fair_share_multiplier (default 2.0) to control cache-aware routing behavior.
Core Routing Logic
tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
Extended ADPRouter.create() to accept and forward async_transfer_manager. Updated KVCacheAwareADPRouter.__init__ to store routing thresholds and transfer manager. Modified create_rank_state() to compute active requests/tokens using remaining-to-compute tokens and in-flight requests. Reworked route_requests() with per-request rank shuffling, cache-affinity gating, and fair-share multiplier constraints. Added helper methods for token counting and prefix-match lookup.
Initialization Integration
tensorrt_llm/_torch/pyexecutor/py_executor.py
Reordered ADPRouter construction to occur after AsyncTransferManager initialization; extended router creation call to pass async_transfer_manager parameter.
Testing
tests/unittest/_torch/executor/test_adp_router.py
Added TestKVCacheAwareADPRouterRouting test class with three test scenarios validating cache-affinity routing, match-rate-threshold suppression, and fair-share-multiplier request capping behavior.

Sequence Diagram

sequenceDiagram
    participant PyExec as PyExecutor
    participant ATM as AsyncTransferManager
    participant Router as ADPRouter
    participant KVRouter as KVCacheAwareADPRouter
    participant KVMgr as KVCacheManager
    
    PyExec->>ATM: Initialize
    PyExec->>Router: create(async_transfer_manager=ATM)
    Router->>KVRouter: new(match_rate_threshold, fair_share_multiplier, async_transfer_manager)
    
    Note over KVRouter: Routing Request Received
    KVRouter->>KVMgr: Query prefix matches & cached tokens
    KVRouter->>ATM: Count requests in transfer
    KVRouter->>KVRouter: Compute remaining-to-compute tokens
    KVRouter->>KVRouter: Apply cache-affinity gate (match_rate_threshold)
    KVRouter->>KVRouter: Shuffle eligible ranks per request
    KVRouter->>KVRouter: Score ranks (affinity + load balance)
    KVRouter->>KVRouter: Apply fair_share_multiplier cap
    KVRouter-->>PyExec: Route decision (rank assignment)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely summarizes the main changes: adding hit-rate gate and fair-share cap to KV-aware ADP router.
Description check ✅ Passed The PR description adequately covers the three main behavior changes, one config surface change, and provides clear explanations of each feature with default values and use cases.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py`:
- Around line 541-547: The code computes expected_num_active_requests using
int(self.fair_share_multiplier * fair_share) which truncates float multipliers;
replace the truncation with ceiling to implement the intended "round up"
behavior: use math.ceil(self.fair_share_multiplier * fair_share) (add an import
for math if needed) in the calculation for expected_num_active_requests in
adp_router where fair_share and self.fair_share_multiplier are used, and add a
regression test that constructs a scenario with fair_share==1 and
fair_share_multiplier==1.5 (or another non-integer) to assert the cap allows 2
requests rather than 1 and that routing/eviction behavior matches the
multiplier.

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 624-646: The two new routing knobs lack validation: constrain
kv_cache_routing_match_rate_threshold to be between 0.0 and 1.0 (use Field(...,
ge=0.0, le=1.0)) and constrain kv_cache_routing_fair_share_multiplier to be at
least 1.0 (use Field(..., ge=1.0)); update the Field definitions for the
variables named kv_cache_routing_match_rate_threshold and
kv_cache_routing_fair_share_multiplier in llm_args.py to include these bounds so
Pydantic will reject semantically invalid configs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 0b6cbfb8-e4b5-4b2d-9a06-130fd0d7b504

📥 Commits

Reviewing files that changed from the base of the PR and between 62ce575 and 07df92b.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/executor/test_adp_router.py

Comment thread tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
Comment thread tensorrt_llm/llmapi/llm_args.py
@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45142 [ run ] triggered by Bot. Commit: fa796af Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45142 [ run ] completed with state FAILURE. Commit: fa796af

Link to invocation

@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45187 [ run ] triggered by Bot. Commit: fa796af Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45187 [ run ] completed with state SUCCESS. Commit: fa796af
/LLM/main/L0_MergeRequest_PR pipeline #35461 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly lancelly force-pushed the kv_aware_adp_router_tune branch from fa796af to 3660223 Compare April 24, 2026 11:34
@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45388 [ run ] triggered by Bot. Commit: 3660223 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45388 [ run ] completed with state SUCCESS. Commit: 3660223
/LLM/main/L0_MergeRequest_PR pipeline #35629 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45575 [ run ] triggered by Bot. Commit: 3660223 Link to invocation

@lancelly lancelly force-pushed the kv_aware_adp_router_tune branch from 3660223 to a73ad8a Compare April 26, 2026 13:44
Expose match_rate_threshold and fair_share_multiplier as AttentionDpConfig
fields. Add a hit-rate gate with random tiebreak, always account
KV-transfer-in-progress requests in router load, and enforce a 2x
fair-share cap on per-rank token load. Clean up debug logging to a single
per-batch line, simplify rank-state handling, and extend unit tests to
cover the new config fields and direct routing behavior.

Signed-off-by: Lanyu Liao <[email protected]>
@lancelly lancelly force-pushed the kv_aware_adp_router_tune branch from a73ad8a to f324316 Compare April 26, 2026 15:08
@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45588 [ run ] triggered by Bot. Commit: f324316 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45588 [ run ] completed with state SUCCESS. Commit: f324316
/LLM/main/L0_MergeRequest_PR pipeline #35804 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45606 [ run ] triggered by Bot. Commit: f324316 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45606 [ run ] completed with state SUCCESS. Commit: f324316
/LLM/main/L0_MergeRequest_PR pipeline #35821 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45646 [ run ] triggered by Bot. Commit: f324316 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45646 [ run ] completed with state SUCCESS. Commit: f324316
/LLM/main/L0_MergeRequest_PR pipeline #35859 completed with status: 'SUCCESS'

CI Report

Link to invocation

Copy link
Copy Markdown
Collaborator

@lishicheng1996-nv lishicheng1996-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, it works great!

Comment thread tensorrt_llm/llmapi/llm_args.py
Copy link
Copy Markdown
Collaborator

@QiJune QiJune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lancelly lancelly merged commit d8bb9e6 into NVIDIA:main Apr 29, 2026
5 checks passed
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request May 4, 2026
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants