Schedule-Level Shared-Prefix Reuse for LLM RL Training
arXiv:2606.01143v3
Abstract: GRPO-based LLM post-training commonly samples multiple trajectories from the same prompt and then trains on the resulting group. In long-context GRPO workloads, this shared prompt-side prefix can contain retrieved passages, visual tokens, tool schemas, system instructions, or task context, while the full rollout group is still too large to pack into one training microbatch. Standard dense trainers therefore recompute the same prefix forward and backward for every trajectory. We present a schedule-level reuse mechanism that decouples prefix and suffix computation. The schedule runs prefix forward once, executes suffixes as ordinary microbatches while reading prefix K/V and accumulating prefix-side gK/gV, and then runs prefix backward once on the accumulated gradient cache. This reordered schedule is equivalent to baseline training over real arithmetic and aligns numerically within finite-precision tolerance. Because only K/V and gK/gV are hot during suffix computation, the approach offloads dormant prefix activations, integrates with TP/EP/CP/PP and DP-style placement at the execution level, and preserves aux-loss-based MoE router semantics through logical prefix-token accounting. On dense Llama3-8B, Qwen3-8B, and MoE Qwen3-MoE-30B-A3B configurations, the schedule matches optimizer updates across TP/CP/PP/EP combinations, aligns on a 100-step real GRPO actor-update trace replay, reaches up to 4.395x speedup (2.930x under a conservative compile-on comparison) as prefix ratio and GRPO group size grow, and reduces Phase-B peak HBM by up to 59.1%, extending the Llama3-8B capacity frontier from 17,920 to 29,696 total tokens.
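The reordering described in the abstract can be illustrated with a toy example. The NumPy sketch below is not the paper's implementation: it assumes a single attention head, a single layer, and suffix tokens that attend only to the shared prefix, and every name in it (`suffix_grads`, `baseline`, `reuse`, the shapes) is hypothetical. It compares a baseline that recomputes prefix K/V and runs the prefix-side backward once per trajectory against a schedule that computes K/V once, accumulates gK/gV across suffixes, and runs the prefix-side backward once on the accumulated gradients.

```python
import numpy as np

def suffix_grads(Q, T, K, V):
    """Gradients of loss = sum(softmax(Q K^T / sqrt(d)) V * T) w.r.t. K and V."""
    scale = 1.0 / np.sqrt(K.shape[1])
    S = Q @ K.T * scale                      # (suffix_len, prefix_len) scores
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)       # row-wise softmax
    gV = A.T @ T                             # dL/dV, since dL/dO = T
    gA = T @ V.T                             # dL/dA
    gS = A * (gA - (gA * A).sum(axis=-1, keepdims=True))  # softmax backward
    gK = gS.T @ Q * scale                    # dL/dK
    return gK, gV

# Toy sizes (assumed): prefix length, suffix length, width, GRPO group size.
rng = np.random.default_rng(0)
P, S_LEN, D, G = 16, 4, 8, 5
Xp = rng.standard_normal((P, D))             # shared prefix hidden states
Wk = rng.standard_normal((D, D))             # key projection
Wv = rng.standard_normal((D, D))             # value projection
Qs = rng.standard_normal((G, S_LEN, D))      # per-trajectory suffix queries
Ts = rng.standard_normal((G, S_LEN, D))      # per-trajectory loss weights

def baseline():
    """Recompute prefix K/V and run prefix backward for every trajectory."""
    gWk, gWv = np.zeros_like(Wk), np.zeros_like(Wv)
    for Q, T in zip(Qs, Ts):
        K, V = Xp @ Wk, Xp @ Wv              # prefix forward, repeated G times
        gK, gV = suffix_grads(Q, T, K, V)
        gWk += Xp.T @ gK                     # prefix backward, repeated G times
        gWv += Xp.T @ gV
    return gWk, gWv

def reuse():
    """Prefix forward once, accumulate gK/gV, prefix backward once."""
    K, V = Xp @ Wk, Xp @ Wv                  # prefix forward, once
    gK_acc, gV_acc = np.zeros_like(K), np.zeros_like(V)
    for Q, T in zip(Qs, Ts):
        gK, gV = suffix_grads(Q, T, K, V)    # suffix reads cached K/V
        gK_acc += gK
        gV_acc += gV
    return Xp.T @ gK_acc, Xp.T @ gV_acc      # prefix backward, once

base_gWk, base_gWv = baseline()
reuse_gWk, reuse_gWv = reuse()
print("max |diff| gWk:", np.abs(base_gWk - reuse_gWk).max())
print("max |diff| gWv:", np.abs(base_gWv - reuse_gWv).max())
```

Because the prefix-side backward is linear in gK and gV, summing the per-trajectory gradients before the final matrix product gives the same result over real arithmetic; in floating point the two orderings can differ by rounding, which is why the comparison is made with a tolerance rather than exact equality.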