您现在所在的位置:> 主页 > 新闻中心 > 通知公告
【PyTorch】优化器 torch.optim.Optimizer

之前写过一篇 TensorFlow 的优化器 AdamOptimizer 的源码解读,这次来看一看 PyTorch 的优化器源码。

药师:【TensorFlow】优化器AdamOptimizer的源码分析

PyTorch 的优化器基本都继承于 "class Optimizer",这是所有 optimizer 的 base class,本文尝试对其中的源码进行解读。


总的来说,PyTorch 中 Optimizer 的代码相较于 TensorFlow 要更易读一些。下边先通过一个简单的例子看一下,PyTorch 中是如何使用优化器的。

Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

首先,在创建优化器对象的时候,要传入网络模型的参数,并设置学习率等优化方法的参数。然后使用函数zero_grad将梯度置为零。接着调用函数backward来进行反向传播计算梯度。最后使用优化器的step函数来更新参数。

以 PyTorch 中的 SGD Optimizer 为例,下边是 __init__ 函数。 网络模型的参数被传进来后,用params表示;其余参数被打包进字典中命名为defaults。再通过super(SGD, self).__init__(params, defaults)来将paramsdefaults传给父类Optimizer__init__函数。

def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

这样做的好处就是,我可以把子类中一些相同的处理操作集中写到父类的初始化函数中去,这样所有子类只需要调用就好了。例如 SGD 类的其他函数中所用到的self.param_groups 就是在父类的__init__函数中创建的。

 def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

其中第十行中的defaultdict的作用在于当字典里的 key 被查找但不存在时,返回的不是keyError而是一个默认值,此处defaultdict(dict)返回的默认值会是个空字典。最后一行调用的self.add_param_group(param_group),其中param_group是个字典,Key 就是params,Value 就是param_groups=list(params)

函数add_param_group的主要作用是将param_group放进self.param_groups中。首先要将网络参数重新组织放到列表中param_group['params']=list(params)。接着将self.defaults中的键值对遍历放到字典param_group中。之后对self.param_groupsparam_group中的元素进行判断,确保没有重复的参数。最后将字典param_group放进列表self.param_groups中。( 注:self.param_groups=[]是在__init__函数中创建的 )

def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Arguments:
            param_group (dict): Specifies what Tensors should be optimized along with group
            specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in '
                            'ordered collections, but the ordering of tensors in sets '
                            'will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required "
                                 "optimization parameter " + name)
            else:
                param_group.setdefault(name, default)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group) 

接下来看一下函数zero_grad。优化器 SGD 中的 zero_grad 函数如下所示。 可以看到,操作很简单,就是将所有参数的梯度置为零p.grad.zero_()detach_()的作用是Detaches the Tensor from the graph that created it, making it a leaf. self.param_groups是列表,其中的元素是字典。

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

Optimizer 更新参数主要是靠 step 函数,在父类 Optimizerstep 函数中只有一行代码raise NotImplementedError 。SGD 中的实现如下所示。正如前边介绍的,网络模型参数和优化器的参数都保存在列表 self.param_groups 的元素中,该元素以字典形式存储和访问具体的网络模型参数和优化器的参数。所以,可以通过两层循环访问网络模型的每一个参数 p 。获取到梯度d_p=p.grad.data之后,根据优化器参数设置是否使用 momentum或者nesterov再对参数进行调整。最后一行 p.data.add_(-group['lr'], d_p)的作用是对参数进行更新。

def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad.data
            if weight_decay != 0:
                d_p.add_(weight_decay, p.data)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(1 - dampening, d_p)
                if nesterov:
                    d_p = d_p.add(momentum, buf)
                else:
                    d_p = buf

            p.data.add_(-group['lr'], d_p)

    return loss

PyTorch 中的 Adam Optimizer 和 SGD Optimizer 的主要区别也是 step 函数不同。Adam Optimizer 中的 step 函数如下所示。其中,对于每个网络模型参数都使用state['exp_avg']state['exp_avg_sq']来保存 梯度 和 梯度的平方 的移动平均值。第一次更新的时候没有state,即len(state)==0,所以两个数值都需要使用torch.zeros_like(p.data)来初始化为 0 ,之后每次都只需要从state中取出该值使用和更新即可。state['step']用于保存本次更新是优化器第几轮迭代更新参数。最后使用p.data.addcdiv_(-step_size, exp_avg, denom)来更新网络模型参数 p

def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError('Adam does not support sparse gradients, '
                                    'please consider SparseAdam instead')
            amsgrad = group['amsgrad']

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']
            beta1, beta2 = group['betas']

            state['step'] += 1
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            if group['weight_decay'] != 0:
                grad.add_(group['weight_decay'], p.data)

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(1 - beta1, grad)
            exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
            else:
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

            step_size = group['lr'] / bias_correction1

            p.data.addcdiv_(-step_size, exp_avg, denom)

    return loss

值得注意的是,Adam Optimizer 只能处理 dense gradient,要想处理 sparse gradient 需要使用 SparseAdam Optimizer


另外,我收集了一些 PyTorch 实现的 Optimizer,欢迎大家来一起维护。

github.com/201419/Optim

平台注册入口