YOLOv8 物体关键点检测(文具、螺栓)
[toc]
1. 视频/截图
https://www.bilibili.com/video/BV1Pa4y1f7kL/?spm_id_from=888.80997.embed_other.whitelist&t=1&vd_source=39b1662212679b11469d17d3bee8df4e
2. 核心代码
# 测试视频
from ultralytics import YOLO
import cv2
import numpy as np
import sys
import time
# Read command-line arguments:
#   argv[1] weight_path   path to the YOLO weights file
#   argv[2] media_path    path to the input video
#   argv[3] need_record   "1" writes the annotated frames to output.mp4
#   argv[4] need_display  "1" shows the annotated frames in a window
#   argv[5] thresh        optional confidence threshold (default 0.3)
if len(sys.argv) < 5:
    # Fail with a usage message instead of a bare IndexError.
    sys.exit(
        f"usage: python {sys.argv[0]} <weight_path> <media_path> "
        "<need_record:0|1> <need_display:0|1> [thresh=0.3]"
    )
weight_path = sys.argv[1]
media_path = sys.argv[2]
need_record = sys.argv[3]
need_display = sys.argv[4]
thresh = float(sys.argv[5]) if len(sys.argv) >= 6 else 0.3

# Load the model
model = YOLO(weight_path)
# Class labels as exposed by the model
objs_labels = model.names
print(objs_labels)

# Open the video
cap = cv2.VideoCapture(media_path)
if not cap.isOpened():
    sys.exit(f"cannot open media: {media_path}")
cap_h, cap_w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
cap_fps = int(cap.get(cv2.CAP_PROP_FPS))
print(cap_h, cap_w, cap_fps)

if need_record == "1":
    # Fall back to a sane rate when the source reports no FPS.
    record_fps = cap_fps if cap_fps > 0 else 25
    # Half-size output, because frames are resized to half size before being written.
    record_size = (cap_w//2, cap_h//2)
    # Write mp4
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    out = cv2.VideoWriter('output.mp4', fourcc, record_fps, record_size)
    if not out.isOpened():
        # The H264 encoder may be unavailable in this OpenCV build; retry with mp4v
        # so that later out.write() calls are not silently dropped.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter('output.mp4', fourcc, record_fps, record_size)
    if not out.isOpened():
        cap.release()
        sys.exit("cannot open output.mp4 for writing")

# Per-class colours (OpenCV colour tuples are in BGR order)
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# Keypoint names, in the order the model outputs them
keypoint_list = ["head", "tail"]
# Per-keypoint colours (BGR): head is blue, tail is green
keypoint_color = [(255, 0, 0), (0, 255, 0)]
# Timestamp of the previous frame, used for the FPS estimate
c_time = time.time()
# Main loop: read a frame, run inference, draw boxes/keypoints/FPS, then
# optionally record and display the annotated frame.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            print("video end")
            break
        # Resize to half size (must match the size given to the VideoWriter)
        frame = cv2.resize(frame, (cap_w//2, cap_h//2))

        # Inference on a single frame; with the default stream=False the call
        # returns a list of results, one per input image.
        result = model(frame, conf=thresh)[0]

        # Draw every detected box
        boxes = result.boxes.cpu().numpy()
        for box in boxes.data:
            l, t, r, b = (int(v) for v in box[:4].astype(np.int32))  # left, top, right, bottom
            conf, cls_id = box[4:]  # confidence, class
            cls_id = int(cls_id)
            # Wrap around so models with more classes than colours do not crash.
            color = class_color[cls_id % len(class_color)]
            cv2.rectangle(frame, (l, t), (r, b), color, 2)
            # Class name + confidence (format: 98.1%)
            cv2.putText(frame, f"{objs_labels[cls_id]} {conf*100:.1f}%", (l, t-10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

        # Draw keypoints; result.keypoints is None when the model has no pose head.
        if result.keypoints is not None:
            keypoints = result.keypoints.cpu().numpy()
            for keypoint in keypoints.data:
                for i, point in enumerate(keypoint):
                    # Only x and y are needed; a trailing confidence column is ignored.
                    x, y = int(point[0]), int(point[1])
                    kp_color = keypoint_color[i % len(keypoint_color)]
                    # Fall back to the index when the model has more keypoints than names.
                    kp_name = keypoint_list[i] if i < len(keypoint_list) else str(i)
                    cv2.circle(frame, (x, y), 10, kp_color, -1)
                    cv2.putText(frame, f"{kp_name}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 1, kp_color, 2)
                if len(keypoint) >= 2:
                    # Arrow from the tail to the midpoint between head and tail
                    x1, y1 = keypoint[0][:2]
                    x2, y2 = keypoint[1][:2]
                    center_x, center_y = (x1+x2)/2, (y1+y2)/2
                    cv2.arrowedLine(frame, (int(x2), int(y2)), (int(center_x), int(center_y)), (255,0,255), 4, line_type=cv2.LINE_AA, tipLength=0.1)

        # FPS estimate from the time between consecutive frames
        n_time = time.time()
        elapsed = n_time - c_time
        # Guard against a zero interval on coarse system clocks.
        fps = 1/elapsed if elapsed > 0 else 0.0
        c_time = n_time
        cv2.putText(frame, f"fps: {fps:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        if need_record == "1":
            out.write(frame)
        if need_display == "1":
            cv2.imshow("result", frame)
            # Mask to the low byte so the comparison works regardless of modifier bits.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always release the capture and writer so output.mp4 is finalised,
    # even when the loop ends through an exception or Ctrl+C.
    cap.release()
    if need_record == "1":
        out.release()
    if need_display == "1":
        cv2.destroyAllWindows()
用法:
- 下载代码:点击下载
- 运行(参数依次为:权重文件路径、视频路径、是否录制结果到 output.mp4(1 为是)、是否弹窗显示(1 为是)、置信度阈值;阈值可省略,默认 0.3):
python 2.test_video.py weights/best.pt media/test.MOV 1 1 0.5