prediction augment (TTA) fix (#288)

This commit is contained in:
leonnil 2024-06-23 08:21:40 +00:00
parent 2c36ab0f10
commit aad320dd80

View File

@ -324,7 +324,11 @@ class DetectionModel(BaseModel):
y = [] # outputs
for si, fi in zip(s, f):
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
yi = super().predict(xi)[0] # forward
yi = super().predict(xi) # forward
if isinstance(yi, dict):
yi = yi["one2one"] # yolov10 outputs
if isinstance(yi, (list, tuple)):
yi = yi[0]
yi = self._descale_pred(yi, fi, si, img_size)
y.append(yi)
y = self._clip_augmented(y) # clip augmented tails