我的训练循环中出现以下错误,我不太明白问题出在哪里。我目前正在编写这段代码,所以还没有最终确定,但我无法弄清楚这个问题是什么。
我尝试用谷歌搜索错误并阅读了一些答案,但似乎仍然无法理解问题的症结所在。
数据集和数据加载器 (X和Y已经给我了,都是[2000, 40, 1]张量)
class TrainingDataset(data.Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return Nf
# returns corresponding input/output pairs
def __getitem__(self, t):
X = self.X[t]
y = self.y[t]
#print(X.shape, y.shape)
return X, y
# prints torch.Size([2000, 40, 1]) torch.Size([2000, 40, 1])
print(x.size(), y.size())
dataset = TrainingDataset(x,y)
batchSize = 20
dataIter = data.DataLoader(dataset, batchSize)
型号:
class Encoder(nn.Module):
def __init__(self, num_inputs = 40, num_outputs = 40):
super(Encoder, self).__init__()
self.num_inputs = num_inputs
self.num_hidden = num_hidden
self.num_outputs = num_outputs
self.layers = nn.Sequential(
nn.Linear(num_inputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs)
)
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
训练循环:
for epoch in range(epochs):
for batch in dataIter:
optimiser.zero_grad()
l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
l.backward()
optimiser.step()
错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-aa1c60616d82> in <module>()
6 for batch in dataIter:
7 optimiser.zero_grad()
----> 8 l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
9 l.backward()
10 optimiser.step()
2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
TypeError: forward() takes 2 positional arguments but 3 were given
谁能指出我正确的方向?我刚刚开始学习和使用 pytorch,所以我还不擅长这些。
最佳答案
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
你的错误就在这里,除了self
之外,这个函数应该只有一个参数。
关于python - 类型错误 : forward() takes 2 positional arguments but 3 were given in pytorch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67039926/