Fix Conv2() fusing (#2885)

This commit is contained in:
Glenn Jocher 2023-05-29 02:08:45 +02:00 committed by GitHub
parent 882dbe62ad
commit a2bb42dfe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -59,7 +59,7 @@ class Conv2(Conv):
"""Fuse parallel convolutions.""" """Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data) w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]] i = [x // 2 for x in w.shape[2:]]
w[:, :, i[0] - 1:i[0], i[1] - 1:i[1]] = self.cv2.weight.data.clone() w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
self.conv.weight.data += w self.conv.weight.data += w
self.__delattr__('cv2') self.__delattr__('cv2')