[None][feat] Add MegaMoEDeepGemmFusedMoE backend wrapping DeepGEMM fp8_fp4_mega_moe#13384
124d4f0 to
b652f19
Compare
|
Discussed with @Barry-Delaney. Since Barry has tested the functionality locally and we are pressed for time on performance, we suggest renaming the newly added backend to MEGAMOE_DEEPGEMM, because we are also developing our own MEGA kernels. |
b652f19 to
5112801
Compare
70c4946 to
4ce2c3c
Compare
|
/bot run --disable-fail-fast |
📝 Walkthrough
This PR introduces a new MegaMoE backend for fused MoE operations powered by DeepGEMM, updates the deepgemm dependency to a newer commit, and integrates the backend into the configurable MoE framework.
Sequence Diagram
sequenceDiagram
participant User as User
participant CMoE as ConfigurableMoE
participant Router as Router
participant Backend as MegaMoEDeepGemmFusedMoE
participant Quant as Quantizer
participant DG as DeepGEMM Kernel
User->>CMoE: forward(x, router_logits)
CMoE->>CMoE: _forward_chunk_mega_impl()
CMoE->>Backend: count tokens per rank
CMoE->>CMoE: slice x, router_logits to real tokens
CMoE->>Router: apply routing
CMoE->>Router: topk casting
CMoE->>Backend: quantize_input(x)
Backend->>Quant: mxfp8_quantize(BF16→FP8)
Quant-->>Backend: x_fp8, x_sf
CMoE->>Backend: copy to DeepGEMM SymmBuffer
CMoE->>DG: fp8_fp4_mega_moe (collective kernel)
DG-->>CMoE: output (FP32)
CMoE-->>User: return BF16 combined output
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 5
🧹 Nitpick comments (5)
tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py (3)
209-209: Mutable class attribute should use `frozenset`.
Using a mutable `set` as a class attribute can lead to unexpected behavior if modified. Since this is a constant set of supported dtypes, use `frozenset` instead.
♻️ Proposed fix
- _SUPPORTED_ACTIVATION_DTYPES = {torch.bfloat16}
+ _SUPPORTED_ACTIVATION_DTYPES = frozenset({torch.bfloat16})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py` at line 209, The class-level constant _SUPPORTED_ACTIVATION_DTYPES is defined as a mutable set; change it to an immutable frozenset to avoid accidental mutation. Locate the _SUPPORTED_ACTIVATION_DTYPES symbol in the Mega MoE backend module and replace its set literal with a frozenset containing the same element(s) (e.g., use frozenset(...) around torch.bfloat16) so the attribute is immutable at class scope.
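To illustrate the concern, here is a minimal sketch that uses plain strings as stand-ins for dtypes. The class names and string values are hypothetical and not taken from the PR. It shows that a class-level `set` can be mutated in place through any reference, while a `frozenset` rejects the same attempt:

```python
class MutableBackend:
    # Class-level set: shared by every user of the attribute and mutable in place.
    SUPPORTED = {"bfloat16"}


class FrozenBackend:
    # frozenset has no add/remove methods, so the constant stays constant.
    SUPPORTED = frozenset({"bfloat16"})


# Accidental mutation through the class leaks to everyone reading the attribute.
MutableBackend.SUPPORTED.add("float16")
mutated = sorted(MutableBackend.SUPPORTED)

# The same attempt on a frozenset raises AttributeError.
try:
    FrozenBackend.SUPPORTED.add("float16")
    frozen_rejected = False
except AttributeError:
    frozen_rejected = True
```

Membership tests such as `dtype in SUPPORTED` behave the same for both types, so switching to `frozenset` should not require changes at call sites that only read the attribute.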
182-192: Add defensive validation for hidden dimension alignment.
The reshape at line 192 assumes `n % 128 == 0` (since `n // 32` must be divisible by 4 for the `int32` view). While `can_implement` enforces this at factory time, direct construction or external calls to `quantize_input` could hit a cryptic reshape error.
🛡️ Proposed fix
 def _quantize_bf16_to_fp8_ue8m0(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """Return (x_fp8, x_sf) in DG mega_moe's expected layout (packed int32)."""
     if _trtllm_mxfp8_quantize_available():
         m, n = x.shape
+        if n % 128 != 0:
+            raise ValueError(
+                f"MegaMoE quantize_input requires hidden_size % 128 == 0 for "
+                f"packed-UE8M0 int32 SF layout; got n={n}"
+            )
         # ``is_sf_swizzled_layout=False`` → flat row-major uint8 SF, one
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py` around lines 182 - 192, In _quantize_bf16_to_fp8_ue8m0 validate the input's hidden dim before reshaping: check the tensor shape (m, n) and assert that n is aligned so (n // 32) is divisible by 4 (equivalently n % 128 == 0); if not, raise a clear ValueError describing required alignment and the offending n. Place this check just after extracting m, n and before calling x_sf_u8.view(m, n // 32). This prevents the cryptic reshape/view failure when quantize_input or external callers pass misaligned tensors.
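The alignment requirement can be seen in a small pure-Python sketch. It assumes, as the comment states, one uint8 scale factor per 32 elements and four of those bytes viewed as one int32. The helper names `packed_sf_columns` and `pack_row` are hypothetical, and little-endian packing via `struct` is used only for illustration; it is not the PR's code.

```python
import struct


def packed_sf_columns(n: int) -> int:
    """Number of int32 scale-factor columns for a row of n elements."""
    if n % 128 != 0:
        raise ValueError(
            f"hidden size must be a multiple of 128 for packed int32 SF; got n={n}"
        )
    sf_bytes = n // 32      # one uint8 scale factor per 32-element block
    return sf_bytes // 4    # four uint8 scale factors per int32


def pack_row(sf_bytes: bytes) -> list[int]:
    """Reinterpret a row of uint8 scale factors as little-endian int32 values."""
    count = len(sf_bytes) // 4
    return list(struct.unpack(f"<{count}i", sf_bytes))


# H=7168 is the hidden size quoted for the DSV3 benchmark in this PR.
cols = packed_sf_columns(7168)
row = pack_row(bytes([127] * (7168 // 32)))
```

With `n = 96`, for example, there would be 3 scale-factor bytes per row, which cannot be viewed as whole int32 values; the early `ValueError` reports that directly instead of leaving it to a reshape failure.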
273-273: Function call in default argument creates shared instance.
`ModelConfig()` is evaluated once at function definition time, not per-call. If `ModelConfig` is mutable, this could lead to shared state issues. Consider using `None` as default and creating the instance inside the function.
♻️ Proposed fix
- model_config: ModelConfig = ModelConfig(),
+ model_config: ModelConfig | None = None,
Then at the start of `__init__`:
if model_config is None:
    model_config = ModelConfig()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py` at line 273, The constructor currently uses a mutable default ModelConfig() which is instantiated at definition time; change the parameter default to model_config: Optional[ModelConfig] = None (or just None) in the __init__ signature and then inside __init__ (for the class that defines this constructor) add a guard like "if model_config is None: model_config = ModelConfig()" so each call gets a fresh ModelConfig instance; update any type hints/imports accordingly and ensure all references to the parameter keep the name model_config.
tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (2)
635-644: Unused parameter `use_dp_padding`.
The `use_dp_padding` parameter is accepted but never used in the method body. If MegaMoE intentionally ignores padding (since it gets raw token counts), consider either:
- Documenting this in the docstring, or
- Removing the parameter if callers don't need to pass it
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py` around lines 635 - 644, The parameter use_dp_padding on _forward_chunk_mega_impl is never used; either remove it from the method signature and all callers (update any invocations in this module/class) to avoid dead API surface, or keep it but explicitly mark it as intentionally unused by adding a short docstring note and a sentinel usage (e.g., assign to _ = use_dp_padding or rename to _use_dp_padding) so linters/readers know MegaMoE ignores dp padding because it uses raw token counts; choose one approach and apply consistently across the class (update _forward_chunk_mega_impl signature, callers, and the method docstring).
689-691: Nit: Ambiguous Unicode character in comment.
Line 691 uses `×` (Unicode multiplication sign) instead of `x`. While readable, this can cause issues with some tools/editors.
✏️ Suggested fix
- # contract exactly (4 × buf.copy_ + kernel).
+ # contract exactly (4 x buf.copy_ + kernel).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py` around lines 689 - 691, Replace the Unicode multiplication sign in the comment that references the backend's "run_with_prequant" and DG's "run_fused" shape contract (the phrase "4 × buf.copy_ + kernel") with the ASCII letter "x" (change "4 × buf.copy_ + kernel" to "4 x buf.copy_ + kernel") so tools/editors won't choke on the ambiguous character; locate the comment near the mentions of run_with_prequant and run_fused in configurable_moe.py and update the comment text accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py`:
- Around line 700-711: The zero-token placeholder creates x_sf with shape (0, 0)
which violates the quantize_input contract used by _quantize_bf16_to_fp8_ue8m0
and may break downstream validation; change the x_sf creation in the else branch
of configurable_moe (next to x_fp8, topk_idx, topk_weights) to use shape (0,
self.hidden_size // 32) with dtype torch.int32 on the same device so its shape
matches expectations used by run_with_prequant and the quantization helper.
In `@tensorrt_llm/_torch/modules/fused_moe/create_moe.py`:
- Around line 351-371: create_moe currently returns MegaMoEDeepGemmFusedMoE
directly, bypassing the ConfigurableMoE path (so _forward_chunk_mega_impl and
override_quant_config are skipped); change the logic so when
ENABLE_CONFIGURABLE_MOE is active the MegaMoE backend is included in the
ConfigurableMoE routing (i.e., treat MegaMoEDeepGemmFusedMoE as one of the
classes handled by ConfigurableMoE so it flows through the same construction
path that calls _forward_chunk_mega_impl and accepts override_quant_config), and
keep the direct import/return branch only as a legacy fallback executed when
configurable mode is disabled; update create_moe, the ConfigurableMoE
dispatch/tuple, and any factory mapping that selects ConfigurableMoE to ensure
override_quant_config and _forward_chunk_mega_impl are preserved for
MegaMoEDeepGemmFusedMoE.
- Around line 104-111: The call to MegaMoEDeepGemmFusedMoE.can_implement is
using a hardcoded activation dtype and the dense intermediate_size; update it to
use the actual pretrained config values from model_config.pretrained_config (use
pretrained.torch_dtype for dtype_activation, and use
pretrained.moe_intermediate_size if present otherwise
pretrained.intermediate_size for the FFN size) while keeping hidden_size and
other flags the same so the capability check matches the later backend creation.
In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py`:
- Around line 1-14: The file backend.py has formatting differences flagged by
CI; run the ruff formatter on this file (e.g., ruff format backend.py) to apply
the project's style rules and commit the reformatted file so the SPDX header and
surrounding code match ruff-format output.
- Around line 726-728: The code indexes all_rank_num_tokens with
self.mapping.tp_rank which is always 0 in EP-only Phase 1; replace that index
with self.mapping.moe_ep_rank so each EP rank uses its correct token count
(change the statement that sets num_tokens from
all_rank_num_tokens[self.mapping.tp_rank] to use self.mapping.moe_ep_rank), and
update the nearby comment that currently references “[tp_rank]” to reference the
correct EP rank dimension (moe_ep_rank); apply the same replacement wherever the
pattern appears in configurable_moe.py (the analogous uses around num_tokens
retrieval).
📒 Files selected for processing (8)
- 3rdparty/fetch_content.json
- cpp/tensorrt_llm/deep_gemm/CMakeLists.txt
- scripts/attribution/data/dependency_metadata.yml
- scripts/attribution/data/files_to_dependency.yml
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tensorrt_llm/_torch/modules/fused_moe/create_moe.py
- tensorrt_llm/_torch/modules/fused_moe/mega_moe/__init__.py
- tensorrt_llm/_torch/modules/fused_moe/mega_moe/backend.py
|
PR_Github #45701 [ run ] triggered by Bot. Commit: |
|
PR_Github #45701 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #45728 [ run ] triggered by Bot. Commit: |
juney-nvidia
left a comment
Approved from an OSS compliance perspective.
|
PR_Github #45728 [ run ] completed with state |
6d13f9b to
edb7e59
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #45945 [ run ] triggered by Bot. Commit: |
|
PR_Github #45945 [ run ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #45954 [ run ] triggered by Bot. Commit: |
…od path + tests
Builds on NVIDIA#13384 (Barry's MegaMoEDeepGemmFusedMoE backend) and refactors it to share the standard ConfigurableMoE construction / weight-loading pipeline used by CutlassFusedMoE / TRTLLMGenFusedMoE.
Refactor:
* Move weight lifecycle (DG-native MXFP4 + UE8M0 SF tensors, checkpoint loading, scale conversion, SymmBuffer allocation, DG weight transform) out of the MegaMoE backend file and into a dedicated ``W4A8MXFP4MXFP8MegaMoEDeepGemmMethod(FusedMoEMethodBase)`` in ``tensorrt_llm/_torch/modules/fused_moe/quantization.py``.
* Rename ``mega_moe/backend.py`` to ``mega_moe/mega_moe_deepgemm.py`` and shrink it to capability checks, routing/activation quantization, and the fused kernel entry point.
* Wire the new method through ``ConfigurableMoE`` / ``create_moe`` so MegaMoEDeepGemm flows through the same construction and load_weights pipeline as the other backends.
Fixes:
* ``W4A8MXFP4MXFP8MegaMoEDeepGemmMethod.create_weights`` asserts ``hidden_size % 128 == 0`` and ``intermediate_size % 128 == 0`` up front (DG packs UE8M0 SF as int32 over 32-element blocks; misaligned configs hit a cryptic reshape downstream).
* ``w3_w1_weight`` is stored as ``[w1 | w3] = [gate | up]`` (matches DG's ``_interleave_l1_weights`` and TRT-LLM's gate_proj=w1, up_proj=w3 convention; same semantic as NVIDIA#13384 commit edb7e59 applied in the new quant-method path).
Tests:
* Generic MoE module tests cover MegaMoEDeepGemm via ConfigurableMoE.
* Pure-PyTorch QDQ reference for MegaMoEDeepGemm in ``tests/unittest/_torch/modules/moe/quantize_utils.py``.
* Multi-GPU module-level coverage (TP/EP/DEP, NVLink one/two-sided) for MegaMoEDeepGemm, plus extended parametric coverage on the module side (DeepSeek-V4 / Kimi-K2 expert/hidden/intermediate combinations).
* Wire matching integration test stages for B200 / B300.
Squashes prior local development commits (Wire MegaMoEDeepGemm backend path, fix MegaMoEDeepGemm bugs, Add MegaMoE generic MoE tests, Add MegaMoE DeepGEMM reference, Gate MegaMoE DeepGEMM SF alignment, Add MegaMoE module multi-GPU coverage, Use MegaMoE module reference, Extend MegaMoE module coverage x2) into a single commit.
Signed-off-by: xxi <[email protected]>
|
PR_Github #45954 [ run ] completed with state |
660acf2 to
e76ddc4
Compare
Update scripts/attribution/data/dependency_metadata.yml and files_to_dependency.yml to reflect the deepgemm upgrade in 3rdparty/fetch_content.json (4ff3f54d... -> c491439e...). Aligned with the upstream attribution refresh for the same bump (PR NVIDIA#13384) so this branch only carries the deepgemm-related entries; no cutlass / cuda / nccl / torch entries are introduced and no new cas/ blobs are added. Signed-off-by: Fanrong Li <[email protected]>
|
/bot run --disable-fail-fast |
|
/bot run --disable-fail-fast |
|
PR_Github #46755 [ run ] triggered by Bot. Commit: |
|
PR_Github #46755 [ run ] completed with state
|
Introduces ``MegaMoEFusedMoE``, a new MoE backend wrapping DeepGEMM's fused ``fp8_fp4_mega_moe`` kernel (dispatch + GEMM1 + SwiGLU + GEMM2 + combine into a single launch via NVLink SymmBuffer). Accepts the same W4A8_MXFP4_MXFP8 weight layout as ``TRTLLMGenFusedMoE`` so VANILLA / FUSED_GATE_UP_PROJ loaders work unchanged.
Integration follows the ConfigurableMoE pattern per @xingfei:
- ``fused_moe/mega_moe/__init__.py`` / ``backend.py``: new ``MegaMoEFusedMoE`` that (a) creates DG-native uint8 MXFP4 + UE8M0 weight tensors, (b) resolves the EP ProcessGroup at construction (no collective at forward time), (c) allocates the DG ``SymmBuffer`` via a process-level cache on ``post_load_weights``, and (d) dispatches ``deep_gemm.fp8_fp4_mega_moe`` in ``forward_impl``.
- ``fused_moe/create_moe.py``: add ``MEGAMOE`` backend type; ``get_moe_cls`` routes W4A8_MXFP4_MXFP8 to ``MegaMoEFusedMoE`` and falls back to ``CutlassFusedMoE`` for every other quant (mirrors the TRTLLM / CUTEDSL fallback pattern). ``create_moe_backend`` gets a ``MegaMoEFusedMoE`` constructor branch.
- ``fused_moe/configurable_moe.py``: add a fast-path guard at the top of ``_forward_chunk_impl`` and a dedicated ``_forward_chunk_mega_impl`` that forwards the ADP shape contract (``all_rank_num_tokens``, ``use_dp_padding``) directly to ``backend.forward_impl`` and skips the EPLB / quant / combine orchestration.
Unit tests in ``tests/unittest/_torch/modules/moe/test_mega_moe.py`` cover:
* ``can_implement`` capability matrix: accepts ``W4A8_MXFP4_MXFP8 + bfloat16 + SM100``, rejects every other combination with a descriptive reason.
* ``get_moe_cls("MEGAMOE")`` dispatch: returns ``MegaMoEFusedMoE`` for the supported quant, falls back to ``CutlassFusedMoE`` for anything else (and on non-SM100 / no-DG runners so CI on non-Blackwell machines passes without skipping).
* ``apply_router_weight_on_input`` rejection at construction time (the fused kernel applies routing weights on the MoE output, not the input; the two paths are not equivalent under SwiGLU).
* ADP topology guard: ``use_dp and parallel_size > 1`` requires ``ep_size == parallel_size``.
* Weight-loader shape contract for both ``VANILLA`` and ``FUSED_GATE_UP_PROJ`` loading modes: verifies that the expected MXFP4/UE8M0 tensor shapes are produced after loading.
Hot-path validation is left to the multi-GPU harness under ``tmp_test_scripts`` (requires 4+ GPUs + bundled DeepGEMM with ``fp8_fp4_mega_moe``); unit tests skip cleanly when those prerequisites are missing.
Phase 1 constraints: EPLB disabled (DG's SymmBuffer dispatch is incompatible with ``prepare_dispatch`` / NVLinkTwoSided; will revisit in a follow-up), moe_tp_size=1 (EP-only), shapes must be divisible by 128 (packed-UE8M0 SF int32 stride).
Signed-off-by: Barry Kang <[email protected]>
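The capability matrix and the Phase 1 shape constraint described in this commit message could be expressed as a check that returns a reason on rejection. The following is only a sketch under stated assumptions: the function name, argument names, and string-valued quant/dtype identifiers are hypothetical stand-ins, the SM check is treated as an exact match on 100, and the real `can_implement` signature is not shown in this conversation.

```python
from typing import Optional, Tuple


def can_implement_sketch(
    quant_algo: str,
    activation_dtype: str,
    sm_version: int,
    hidden_size: int,
    intermediate_size: int,
) -> Tuple[bool, Optional[str]]:
    """Return (True, None) if supported, else (False, reason)."""
    # Assumption: only SM100 is accepted, mirroring the wording of the message.
    if sm_version != 100:
        return False, f"requires SM100, got SM{sm_version}"
    if quant_algo != "W4A8_MXFP4_MXFP8":
        return False, f"unsupported quant algo: {quant_algo}"
    if activation_dtype != "bfloat16":
        return False, f"unsupported activation dtype: {activation_dtype}"
    # Packed-UE8M0 int32 scale factors need 128-element alignment.
    for name, size in (("hidden_size", hidden_size),
                       ("intermediate_size", intermediate_size)):
        if size % 128 != 0:
            return False, f"{name}={size} is not divisible by 128"
    return True, None


ok, reason = can_implement_sketch("W4A8_MXFP4_MXFP8", "bfloat16", 100, 7168, 2048)
```

Returning the reason alongside the boolean lets a factory log why it fell back to another backend rather than failing silently.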
Collapses the pre-kernel overhead in the MegaMoE path from ~460 us down
to ~85 us (DSV3 ep=4, 4 × GB200, uniform routing) via two complementary
changes:
1. **Hoist routing + BF16→FP8 quant out of the backend.** The backend
used to own its own ``routing_method.apply`` + ``per_token_cast_to_fp8``
+ buffer copies + kernel launch inside ``forward_impl``. That buried
the pre-processing behind an extra two Python frames (ConfigurableMoE
→ _forward_chunk_impl → forward_impl → backend.forward_impl) and
recomputed what the outer pipeline already knows how to do for the
CUTLASS / CUTEDSL backends.
``_forward_chunk_mega_impl`` now mirrors the standard separated-
routing contract: it slices ``x`` / ``router_logits`` to the unpadded
ADP count, runs ``self.routing_method.apply`` once, runs
``self.backend.quantize_input(x_real)`` once, then calls a new
kernel-only backend entry
``run_with_prequant(x_fp8, x_sf, topk_idx, topk_weights, num_tokens,
output_dtype)`` that just does the 4 × ``buf.copy_()`` + fused kernel
— matching DG's own ``run_fused`` shape contract, so the GPU work
inside the backend call is now what DG's benchmarks measure.
Zero-token ranks still enter ``run_with_prequant`` with fabricated
empty tensors so the SymmBuffer collective doesn't hang peers.
The existing ``backend.forward_impl`` path stays as a stand-alone
fallback and now also routes through ``self.quantize_input``.
2. **Use TRT-LLM's C++ MXFP8 quant kernel.** DG's Python
``per_token_cast_to_fp8(..., gran_k=32, use_packed_ue8m0=True)``
decomposes into ~8 elementwise/reduction ops (empty+fill, copy, abs,
cast, amax, clamp, div, ue8m0 round, mul, cast to fp8, pack int32).
Wrapping it with ``torch.compile`` fuses these to ~1-2 Triton kernels
and gets to ~60-260 us depending on seq.
``torch.ops.trtllm.mxfp8_quantize(x, False, alignment=32)`` is an
existing C++ CUDA kernel in this tree (``thop/mxFp8Quantize.cpp``)
— CUTLASS's MoE path already uses it. Byte-identical output to DG's
helper (roundtrip-verified on random BF16: fp8 bytes and SF int32
both ``torch.equal`` after reshape), and 5-25× faster — consistently
~11 us independent of seq.
``_quantize_bf16_to_fp8_ue8m0`` prefers the TRT-LLM op when
``libth_common.so`` has registered it (always the case inside
ConfigurableMoE because ``create_moe.py`` imports CutlassFusedMoE
at module top, which triggers the library load). Falls back to the
``torch.compile``'d DG helper for slim builds.
Perf deltas (us/iter, DSV3 E=256 k=8 H=7168 I=2048, ep=4):
seq    initial backend   +hoist   +C++ quant   final vs CUTLASS
----   ---------------   ------   ----------   ----------------
   1               895      527          391              0.70×
  32               890      454          369              0.74×
 128               863      453          432              0.87×
 512               866      462          437              0.88×
2048              1426     1007          939              1.88×
MegaMoE now beats CUTLASS by 10-30% across seq 1-512 (uniform
routing). Large-seq (≥ 1024) remains limited by DG kernel scaling
itself — in-kernel tuning, not something the wrapper can address.
Unit tests (``test_mega_moe.py``, 15 cases) remain green.
Signed-off-by: Barry Kang <[email protected]>
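For readers unfamiliar with the scale-factor format named in this commit message, here is a pure-Python sketch of the per-block scale computation. It assumes the common MXFP8 recipe: take the absolute maximum of each 32-element block, divide by 448 (the largest finite FP8 E4M3 value), round up to a power of two, and store only the biased exponent as one UE8M0 byte. The function name and the clamp value are illustrative assumptions. This is not the TRT-LLM or DeepGEMM kernel code, and it omits the actual FP8 cast.

```python
import math

FP8_E4M3_MAX = 448.0   # largest finite value representable in FP8 E4M3
UE8M0_BIAS = 127       # UE8M0 stores an 8-bit biased power-of-two exponent
BLOCK = 32             # elements sharing one scale factor


def ue8m0_scales(row: list[float]) -> list[int]:
    """Return one UE8M0 exponent byte per 32-element block of `row`."""
    if len(row) % BLOCK != 0:
        raise ValueError("row length must be a multiple of 32")
    scales = []
    for start in range(0, len(row), BLOCK):
        block = row[start:start + BLOCK]
        # Clamp so an all-zero block does not produce log2(0).
        amax = max(max(abs(v) for v in block), 1e-4)
        # Smallest power of two >= amax / 448, so scaled values fit in FP8.
        exponent = math.ceil(math.log2(amax / FP8_E4M3_MAX))
        scales.append(exponent + UE8M0_BIAS)
    return scales


row = [448.0] * 32 + [1.0] * 32
scales = ue8m0_scales(row)   # one byte per block: [127, 119]
```

In the second block the scale is 2^-8, so the value 1.0 becomes 256 after division, which is within the FP8 E4M3 range. Four such bytes per int32 give the packed layout that motivates the `% 128` alignment rule.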
R1: configurable_moe._forward_chunk_mega_impl: zero-token x_sf placeholder used shape (0, 0), which violates the contract returned by quantize_input (packed-UE8M0 int32 over 32-element blocks, 4 u8 per int32). Use (0, hidden_size // 128) int32 to match.
R2: create_moe: MegaMoEDeepGemmFusedMoE was bypassing ConfigurableMoE because it wasn't in the supported tuple, leaving ConfigurableMoE._forward_chunk_mega_impl (the perf hoist from 5112801) unreachable. Add it lazily so non-MegaMoE callers don't import DeepGEMM at module load.
R3: get_moe_cls: capability check for MegaMoE used hardcoded torch.bfloat16 and dense intermediate_size; resolve from pretrained_config (torch_dtype + moe_intermediate_size first) so can_implement matches the values used at construction.
R5: backend.forward_impl + ConfigurableMoE._forward_chunk_mega_impl indexed all_rank_num_tokens with mapping.tp_rank, which is the outer TP rank rather than the EP rank. Phase 1 asserts ep_size == parallel_size, so use mapping.moe_ep_rank for clarity and topology robustness.
Nits:
* drop unused use_dp_padding from _forward_chunk_mega_impl signature (and caller): MegaMoE uses raw token counts, not DP padding.
* replace Unicode "×" multiplication sign with ASCII "x" in two comments/docstrings (configurable_moe.py + backend.py).
* _SUPPORTED_ACTIVATION_DTYPES set -> frozenset.
* _quantize_bf16_to_fp8_ue8m0 raises ValueError early when n % 128 != 0 instead of failing at the int32 reshape.
* reshape five docstrings in mega_moe/backend.py to satisfy ruff D205 (blank line between summary and description).
ruff format and ruff check now pass with the project config.
Signed-off-by: Barry Kang <[email protected]>
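The lazy registration described in R2 follows a common Python pattern: defer importing an optional dependency until the backend is actually requested. A minimal sketch follows, with a hypothetical helper name and stdlib modules standing in for the real backend module; the PR's actual code is not shown in this conversation.

```python
import importlib
from typing import Optional


def load_optional_class(module_name: str, class_name: str) -> Optional[type]:
    """Import `module_name` on demand and return `class_name`, or None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Optional dependency is missing; callers can fall back.
        return None
    return getattr(module, class_name, None)


# Stand-ins: `collections.OrderedDict` plays the optional backend class,
# and a made-up module name plays a build without the dependency.
available = load_optional_class("collections", "OrderedDict")
missing = load_optional_class("no_such_backend_module_xyz", "Backend")
```

Because the import happens inside the function, modules that never ask for the optional backend pay no import cost and do not fail when the dependency is absent.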
Pre-commit's yapf hook prefers wrapping the single-element tuple extension across two lines. No semantic change. Signed-off-by: Barry Kang <[email protected]>
DeepGEMM's ``fp8_fp4_mega_moe`` kernel interprets the first half of the L1 weight tensor as the SwiGLU gate side and the second half as the up side (deep_gemm/mega/__init__.py:78 ``_interleave_l1_weights``: ``gate = t[:, :half]; up = t[:, half:]``). TRT-LLM's MoE convention -- consistent across ``modeling_gpt_oss.py:743-746`` and the ``FUSED_GATE_UP_PROJ`` loader at ``quantization.py:362-365`` -- maps ``w1 = gate_proj`` and ``w3 = up_proj`` (HF's ``gate_up_proj`` is laid out as ``[gate | up]`` along the output dim, ``chunk(2)[0]`` -> w1, ``chunk(2)[1]`` -> w3). The previous ``cat([w3, w1])`` order silently swapped which side the ``silu`` was applied to, computing ``silu(up_proj @ x) * (gate_proj @ x)`` instead of ``silu(gate_proj @ x) * (up_proj @ x)``. Verified against a pure-PyTorch QDQ reference (no DG / TRT-LLM kernel ops): with this fix, MegaMoE output matches the reference bit-exact. Without it, the per-element mismatch rate is ~94% vs the reference. Signed-off-by: Barry Kang <[email protected]>
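The effect of the operand order can be checked with scalar arithmetic. The sketch below uses plain Python floats in place of the projected tensors; `gate` and `up` stand for `gate_proj @ x` and `up_proj @ x`, the helper names are illustrative, and the numeric values are arbitrary.

```python
import math


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(first_half: float, second_half: float) -> float:
    """Convention from the message: first half is gate, second half is up."""
    return silu(first_half) * second_half


gate, up = 2.0, -0.5

correct = swiglu(gate, up)   # [w1 | w3] = [gate | up]
swapped = swiglu(up, gate)   # [w3 | w1] feeds up_proj into silu instead
```

Since `silu` is nonlinear, `silu(gate) * up` and `silu(up) * gate` generally differ, which is consistent with the high mismatch rate reported for the swapped order.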
994b5a3 to
e7d80e9
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #47186 [ run ] triggered by Bot. Commit: |
|
/bot kill |
|
PR_Github #47243 [ kill ] triggered by Bot. Commit: |
|
PR_Github #47243 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #47245 [ run ] triggered by Bot. Commit: |
|
PR_Github #47245 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47302 [ run ] triggered by Bot. Commit: |
|
PR_Github #47302 [ run ] completed with state |
…8_fp4_mega_moe (NVIDIA#13384) Signed-off-by: Barry Kang <[email protected]>
This PR enables the mega-MoE kernel from DeepGEMM and adds the related backend to
ConfigurableMoE.