训练小参数LLM将优化器从AdamW换成Muon的尝试

2026年5月31日 8点热度 0人点赞 0条评论

最近社区都在说 Muon 用在 LLM 上的训练效果要比 AdamW 好很多,这里根据 Kimi 的论文(http://arxiv.org/abs/2502.16982)和仓库做了一些尝试。

选用模型:minimind,Github 链接:https://github.com/jingyaogong/minimind

硬件:在 AutoDL 上自己租一张 NVIDIA 的显卡就行

常见的 AdamW 优化器就是在 Adam 的基础上,在参数更新时加上解耦的权重衰减(weight decay),这样可以避免参数在更新过程中变得过大。

而 Muon 并没有走 Adam 的那条路,而是在 Momentum 动量的基础上进行改进:对二维参数矩阵的动量更新量先用 Newton-Schulz 迭代做近似正交化,再用它来更新参数;其余的参数(如 embedding、lm_head、偏置和归一化层)仍然用 AdamW。

再多了就是一堆社区做过的论文工作了,所以这里放上仓库改好的代码:

import math

import torch
import torch.optim


def _newtonschulz5(x, steps):
    """5th-order Newton-Schulz orthogonalization for 2D matrices.

    From Moonlight / KellerJordan: returns US'V^T (not exact UV^T) where
    S'_{ii} ~ Uniform(0.5, 1.5). This approximation does not hurt
    performance and is cheaper to compute.
    """
    assert x.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = x.bfloat16()
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    X = X / (X.norm() + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X.float()


# Orthogonalization backend used by zeropower_via_newtonschulz5: the eager
# implementation by default, upgraded to a torch.compile'd version only when
# CUDA is present and a warm-up call on a tiny matrix succeeds.
_zeropower_fn = _newtonschulz5
if torch.cuda.is_available():
    try:
        _compiled = torch.compile(_newtonschulz5, fullgraph=True)
        # Warm-up call: forces compilation now so failures surface here.
        _compiled(torch.randn(4, 4, device="cuda"), 1)
    except Exception:
        # Deliberate best effort: any compile/runtime failure keeps the
        # eager fallback selected above.
        pass
    else:
        _zeropower_fn = _compiled


def zeropower_via_newtonschulz5(G, steps):
    """Run the module's selected Newton-Schulz backend on ``G``.

    Dispatches to whatever ``_zeropower_fn`` currently points at (the compiled
    variant if it was set up successfully, otherwise the eager one).

    Args:
        G: 2D tensor to orthogonalize.
        steps: number of Newton-Schulz iterations.

    Returns:
        Whatever the selected backend returns for ``(G, steps)``.
    """
    backend = _zeropower_fn
    return backend(G, steps)


def _adjust_lr_for_muon(lr, param_shape):
    """Scale learning rate by sqrt of the larger matrix dimension.

    From the Moonlight paper: adjusted_lr = lr * 0.2 * sqrt(max(rows, cols)).
    """
    return lr * 0.2 * math.sqrt(max(param_shape[:2]))


class Muon(torch.optim.Optimizer):
    """Muon — MomentUm Orthogonalized by Newton-Schulz.

    Parameters marked via :meth:`mark_muon_params` receive an orthogonalized
    momentum update; every other parameter (including unmarked ones) is
    updated with an AdamW-style rule. Both branches apply decoupled weight
    decay *before* the gradient-based update.

    Reference: https://github.com/MoonshotAI/Moonlight

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: base learning rate (the Muon branch rescales it per matrix).
        weight_decay: decoupled weight-decay coefficient.
        momentum: momentum factor for the Muon branch.
        nesterov: use the Nesterov-style momentum update in the Muon branch.
        ns_steps: number of Newton-Schulz iterations.
        adamw_betas: (beta1, beta2) for the AdamW branch.
        adamw_eps: epsilon added to the AdamW denominator.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        weight_decay=0.1,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        adamw_betas=(0.9, 0.95),
        adamw_eps=1e-8,
    ):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )
        super().__init__(params, defaults)
        # Book-keeping of which parameter objects were marked for each branch.
        self._muon_param_ids = set()
        self._adamw_param_ids = set()

    def mark_muon_params(self, muon_params):
        """Route the given parameters through the Muon update.

        Raises:
            ValueError: if a parameter has fewer than 2 dimensions, since the
                Newton-Schulz step only accepts matrices.
        """
        for p in muon_params:
            if p.ndim < 2:
                raise ValueError(
                    "Muon requires parameters with at least 2 dimensions, "
                    f"got shape {tuple(p.shape)}"
                )
            self._muon_param_ids.add(id(p))
            # Keep the two id sets mutually exclusive when a param is re-marked.
            self._adamw_param_ids.discard(id(p))
            self.state[p]["use_muon"] = True

    def mark_adamw_params(self, adamw_params):
        """Route the given parameters through the AdamW update."""
        for p in adamw_params:
            self._adamw_param_ids.add(id(p))
            # Keep the two id sets mutually exclusive when a param is re-marked.
            self._muon_param_ids.discard(id(p))
            self.state[p]["use_muon"] = False

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            betas = group["adamw_betas"]
            eps = group["adamw_eps"]

            # Split parameters that have a gradient by update rule; anything
            # not explicitly marked for Muon falls back to AdamW.
            muon_params = []
            adamw_params = []
            for p in group["params"]:
                if p.grad is None:
                    continue
                if self.state[p].get("use_muon", False):
                    muon_params.append(p)
                else:
                    adamw_params.append(p)

            # --- Muon update ---
            for p in muon_params:
                g = p.grad
                if g.ndim > 2:
                    # Flatten trailing dims into one matrix; reshape (not
                    # view) so non-contiguous gradients are accepted too.
                    g = g.reshape(g.size(0), -1)
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                g = g.add(buf, alpha=momentum) if nesterov else buf
                u = zeropower_via_newtonschulz5(g, ns_steps)
                # Scale by the shape of the matrix that was orthogonalized.
                adjusted_lr = _adjust_lr_for_muon(lr, g.shape)
                p.mul_(1 - lr * wd)
                # Restore the parameter's own shape before applying the
                # update (a no-op for 2-D parameters).
                p.add_(u.reshape(p.shape).to(dtype=p.dtype), alpha=-adjusted_lr)

            # --- AdamW update ---
            for p in adamw_params:
                g = p.grad
                state = self.state[p]
                if "step" not in state:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                step = state["step"]
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                bias1 = 1 - betas[0] ** step
                bias2 = 1 - betas[1] ** step
                exp_avg.lerp_(g, 1 - betas[0])
                exp_avg_sq.lerp_(g.square(), 1 - betas[1])
                denom = exp_avg_sq.sqrt().div_(math.sqrt(bias2)).add_(eps)
                step_size = lr / bias1
                # Decoupled weight decay acts on the pre-update weights, the
                # same order the Muon branch uses; decaying afterwards would
                # also shrink the step that was just applied.
                if wd != 0:
                    p.mul_(1 - lr * wd)
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


def build_optimizer(model, args):
    """Create the optimizer selected by ``args.optimizer``.

    ``"adamw"`` returns a plain ``torch.optim.AdamW`` over all model
    parameters. ``"muon"`` returns a ``Muon`` instance in which parameters
    with at least two dimensions whose name contains neither ``embed_tokens``
    nor ``lm_head`` are marked for the Muon update, and all remaining
    parameters are marked for AdamW.

    Args:
        model: module exposing ``parameters()`` / ``named_parameters()``.
        args: namespace providing ``optimizer`` and ``learning_rate``;
            ``weight_decay`` (default 0.1), ``muon_momentum`` (default 0.95)
            and ``muon_ns_steps`` (default 5) are optional.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``args.optimizer`` is neither "adamw" nor "muon".
    """
    decay = getattr(args, "weight_decay", 0.1)
    choice = args.optimizer

    if choice == "adamw":
        return torch.optim.AdamW(
            model.parameters(), lr=args.learning_rate, weight_decay=decay
        )
    if choice != "muon":
        raise ValueError(f"Unknown optimizer: {args.optimizer}")

    # Name fragments of 2-D+ tensors that must stay on AdamW.
    excluded_tags = ("embed_tokens", "lm_head")
    matrix_weights, fallback_weights = [], []
    for pname, tensor in model.named_parameters():
        is_excluded = any(tag in pname for tag in excluded_tags)
        bucket = (
            matrix_weights
            if tensor.ndim >= 2 and not is_excluded
            else fallback_weights
        )
        bucket.append(tensor)

    optimizer = Muon(
        matrix_weights + fallback_weights,
        lr=args.learning_rate,
        weight_decay=decay,
        momentum=getattr(args, "muon_momentum", 0.95),
        nesterov=True,
        ns_steps=getattr(args, "muon_ns_steps", 5),
    )
    optimizer.mark_muon_params(matrix_weights)
    optimizer.mark_adamw_params(fallback_weights)
    return optimizer

从代码中也能明显看出来,AdamW 需要为每个参数保存两个状态(一阶矩 exp_avg 和二阶矩 exp_avg_sq),而 Muon 只需要保存一个动量缓冲(momentum_buffer)就可以,能节省不少的显存占用。

这里放一下 AdamW 和 Muon 在 500 个 step 迭代之后的 Loss:

在 Pretrain 阶段,Muon 的训练速度要比 AdamW 慢一点;但是到了 SFT 阶段,可以根据情况做一个早停,这样的话整体训练耗时会短一些。

MuWinds

这个人很懒,什么都没留下

文章评论