mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
Failing CUDA tests fixes (#4682)
This commit is contained in:
parent
263bfd1e93
commit
2bc6e647c7
@ -71,7 +71,7 @@ def test_predict_sam():
|
||||
predictor = SAMPredictor(overrides=overrides)
|
||||
|
||||
# Set image
|
||||
predictor.set_image('ultralytics/assets/zidane.jpg') # set with image file
|
||||
predictor.set_image(ASSETS / 'zidane.jpg') # set with image file
|
||||
# predictor(bboxes=[439, 437, 524, 709])
|
||||
# predictor(points=[900, 370], labels=[1])
|
||||
|
||||
|
@ -140,11 +140,11 @@ def test_track_stream():
|
||||
|
||||
# Test Global Motion Compensation (GMC) methods
|
||||
for gmc in 'orb', 'sift', 'ecc':
|
||||
with open(ROOT / 'cfg/trackers/botsort.yaml') as f:
|
||||
with open(ROOT / 'cfg/trackers/botsort.yaml', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
tracker = TMP / f'botsort-{gmc}.yaml'
|
||||
data['gmc_method'] = gmc
|
||||
with open(tracker, 'w') as f:
|
||||
with open(tracker, 'w', encoding='utf-8') as f:
|
||||
yaml.safe_dump(data, f)
|
||||
model.track('https://ultralytics.com/assets/decelera_portrait_min.mov', imgsz=160, tracker=tracker)
|
||||
|
||||
@ -166,7 +166,7 @@ def test_train_pretrained():
|
||||
|
||||
|
||||
def test_export_torchscript():
|
||||
f = YOLO(MODEL).export(format='torchscript', optimize=True)
|
||||
f = YOLO(MODEL).export(format='torchscript', optimize=False)
|
||||
YOLO(f)(SOURCE) # exported model inference
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ class FastSAMPredictor(DetectionPredictor):
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes)
|
||||
full_box = torch.zeros(p[0].shape[1])
|
||||
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
|
||||
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
|
||||
full_box = full_box.view(1, -1)
|
||||
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
|
||||
|
@ -311,7 +311,7 @@ def yaml_save(file='data.yaml', data=None, header=''):
|
||||
data[k] = str(v)
|
||||
|
||||
# Dump data to file in YAML format
|
||||
with open(file, 'w') as f:
|
||||
with open(file, 'w', errors='ignore', encoding='utf-8') as f:
|
||||
if header:
|
||||
f.write(header)
|
||||
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user