如何在PyTorch中实现模型参数的动态缩放并对缩放系数计算梯度?-灵析社区

满脑子智慧溢出

我希望将模型A的参数乘以一个标量lambda得到模型B(即模型B的架构和模型A的一样,但每个参数都是模型A的lambda倍)。现在我希望将一个tensor输入模型B,对输出进行反向传播,然后优化参数lambda,但梯度无法反传到lambda上。具体代码如下:

```python
import torch
import torch.nn as nn


class MyBaseModel(nn.Module):
    def __init__(self):
        super(MyBaseModel, self).__init__()
        self.linear1 = nn.Linear(3, 8)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(8, 4)
        self.act2 = nn.Sigmoid()
        self.linear3 = nn.Linear(4, 5)

    def forward(self, x):
        return self.linear3(self.act2(self.linear2(self.act1(self.linear1(x)))))


class WeightedSumModel(nn.Module):
    def __init__(self):
        super(WeightedSumModel, self).__init__()
        self.lambda_ = nn.Parameter(torch.tensor(2.0))
        self.a = MyBaseModel()
        self.b = MyBaseModel()

    def forward(self, x):
        for para_b, para_a in zip(self.a.parameters(), self.b.parameters()):
            para_b.data = para_a.data * self.lambda_
        return self.b(x).sum()


input_tensor = torch.ones((2, 3))
weighted_sum_model = WeightedSumModel()
output_tensor = weighted_sum_model(input_tensor)
output_tensor.backward()
print(weighted_sum_model.lambda_.grad)
```

输出为None。现在我希望有代码能实现相同的功能,并能成功计算lambda的梯度。

用tensorboard可视化计算图发现,weighted_sum_model的参数中只有b出现在了计算图上,a和lambda_都没有出现在计算图中。

阅读量:137

点赞量:0

问AI
"""Scale every parameter of a base model by a learnable scalar.

The scaled parameters are built as ordinary differentiable tensors
(``param * lambda_``) and the base model is evaluated with them, so the
gradient of the output with respect to ``lambda_`` is available after
``backward()``.
"""
from typing import Optional

import torch
import torch.nn as nn
from torch.func import functional_call


class MyBaseModel(nn.Module):
    """MLP: Linear(3, 8) -> ReLU -> Linear(8, 4) -> Sigmoid -> Linear(4, 5)."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 8)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(8, 4)
        self.act2 = nn.Sigmoid()
        self.linear3 = nn.Linear(4, 5)

    def forward(self, x):
        """Apply the five layers in sequence to ``x`` and return the result."""
        hidden = self.act1(self.linear1(x))
        hidden = self.act2(self.linear2(hidden))
        return self.linear3(hidden)


class WeightedSumModel(nn.Module):
    """Evaluate a base model with all of its parameters multiplied by ``lambda_``.

    Args:
        base_model: Module whose parameters are scaled. When omitted, a fresh
            ``MyBaseModel`` is created.

    Attributes are ``lambda_`` (the learnable scalar, initialised to 2.0) and
    ``a`` (the base model holding the unscaled parameters).
    """

    def __init__(self, base_model: Optional[nn.Module] = None):
        super().__init__()
        self.lambda_ = nn.Parameter(torch.tensor(2.0))
        self.a = MyBaseModel() if base_model is None else base_model

    def forward(self, x):
        """Return the base model's output computed with ``lambda_``-scaled parameters.

        The stored parameters of ``self.a`` are not modified.
        """
        # Multiplying the Parameter objects themselves (not ``.data``) keeps the
        # product inside the autograd graph, which is what lets the gradient
        # reach ``lambda_``.
        scaled_params = {
            name: param * self.lambda_
            for name, param in self.a.named_parameters()
        }
        # Run the base model's own forward() with the scaled tensors swapped in,
        # rather than re-implementing it layer by layer; this works for any
        # layer types and any execution order the base model defines.
        return functional_call(self.a, scaled_params, (x,))


input_tensor = torch.ones((2, 3))
weighted_sum_model = WeightedSumModel()
output_tensor = weighted_sum_model(input_tensor)
output_tensor.sum().backward()
print(weighted_sum_model.lambda_.grad)  # should no longer be None