恩培 (Enpei) - Computer Vision

4. YOLOv8 Object Keypoints (Stationery, Bolts)

Posted on 2023-11-29 15:08:57

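Item | Content
Updated by | @恩培
Updated on | 2023-11-29
Title | YOLOv8 Object Keypoints (Stationery, Bolts)
Methods | Annotation, SAM (Segment Anything), data augmentation, semi-automatic annotation, training
Related paid content | 8. Project 8: yolov8 pose training for object keypoints (bolts, stationery)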

1. Video / Screenshots

https://www.bilibili.com/video/BV1Pa4y1f7kL/?spm_id_from=888.80997.embed_other.whitelist&t=1&vd_source=39b1662212679b11469d17d3bee8df4e

2. Core Code

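# Test on a video
from ultralytics import YOLO
import cv2
import numpy as np
import sys
import time

# Read command-line arguments:
#   weight_path media_path need_record need_display [thresh]
if len(sys.argv) < 5:
    sys.exit(f"usage: python {sys.argv[0]} <weights> <media> <record:0|1> <display:0|1> [thresh]")
weight_path = sys.argv[1]
media_path = sys.argv[2]
need_record = sys.argv[3]
need_display = sys.argv[4]
# Confidence threshold (optional, defaults to 0.3)
thresh = float(sys.argv[5]) if len(sys.argv) >= 6 else 0.3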

# Load the model
model = YOLO(weight_path)

# Class labels: a dict mapping class id to class name
objs_labels = model.names
print(objs_labels)

# Open the video
cap = cv2.VideoCapture(media_path)
cap_h, cap_w = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
cap_fps = int(cap.get(cv2.CAP_PROP_FPS))
print(cap_h, cap_w, cap_fps)
if need_record == "1":
    # Write an mp4 file. 'H264' needs an OpenCV build with an H.264 encoder;
    # if the output file stays empty, 'mp4v' is a common fallback.
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    # The output size is halved because each frame is resized to half before it is written
    out = cv2.VideoWriter('output.mp4', fourcc, cap_fps, (cap_w//2, cap_h//2))

# Per-class colors (BGR)
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# Keypoint order
keypoint_list = ["head", "tail"]
# Keypoint colors (BGR): head is blue, tail is green
keypoint_color = [(255, 0, 0), (0, 255, 0)]

c_time = time.time()
while True:
    ret, frame = cap.read()
    if not ret:
        print("video end")
        break
    # Resize to half resolution
    frame = cv2.resize(frame, (cap_w//2, cap_h//2))
    # cv2.imwrite("temp.jpg", frame)
    # Inference. With stream=False the call returns a list of results, with
    # stream=True a generator. One frame gives one result, so take the first.
    result = model(frame, conf=thresh)[0]
    boxes = result.boxes  # Boxes object for bbox outputs
    boxes = boxes.cpu().numpy()  # convert to numpy array

    # Iterate over each detected box
    for box in boxes.data:
        l, t, r, b = box[:4].astype(np.int32)  # left, top, right, bottom
        conf, cls_id = box[-2], int(box[-1])  # confidence, class id
        # Wrap the index so models with more classes than colors do not crash
        color = class_color[cls_id % len(class_color)]
        # Draw the box
        cv2.rectangle(frame, (l, t), (r, b), color, 2)
        # Draw the class name and confidence (format: 98.1%)
        cv2.putText(frame, f"{objs_labels[cls_id]} {conf*100:.1f}%", (l, t-10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

    # Keypoints object for pose outputs (None if the model is not a pose model)
    keypoints = result.keypoints
    kpts = keypoints.cpu().numpy().data if keypoints is not None else []

    # Draw keypoints: the first (head) is blue, the second (tail) is green
    for keypoint in kpts:
        for i in range(min(len(keypoint), len(keypoint_list))):
            # Each row is (x, y) or (x, y, confidence); only x and y are used
            x, y = int(keypoint[i][0]), int(keypoint[i][1])
            cv2.circle(frame, (x, y), 10, keypoint_color[i], -1)
            cv2.putText(frame, f"{keypoint_list[i]}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[i], 2)

        if len(keypoint) >= 2:
            # Draw an arrow from the tail to the midpoint between head and tail
            x1, y1 = keypoint[0][:2]
            x2, y2 = keypoint[1][:2]
            center_x, center_y = (x1+x2)/2, (y1+y2)/2
            cv2.arrowedLine(frame, (int(x2), int(y2)), (int(center_x), int(center_y)), (255,0,255), 4, line_type=cv2.LINE_AA, tipLength=0.1)
    # Compute fps (guard against a zero time difference)
    n_time = time.time()
    fps = 1 / max(n_time - c_time, 1e-6)
    c_time = n_time
    # Draw fps
    cv2.putText(frame, f"fps: {fps:.1f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    if need_record == "1":
        out.write(frame)

    if need_display == "1":
        cv2.imshow("result", frame)
        if cv2.waitKey(1) == ord('q'):
            break

# Release the video handles and close the preview window
cap.release()
if need_record == "1":
    out.release()
cv2.destroyAllWindows()
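
The script draws an arrow from the tail keypoint toward the head keypoint. If a number is needed instead of a drawing, the same two points can be reduced to an angle and a length. The sketch below is only an illustration: the helper name keypoint_direction and the sample coordinates are hypothetical and not part of the original code, and it assumes image coordinates in which y grows downward.

```python
import math


def keypoint_direction(head, tail):
    """Return (angle_deg, length_px) of the vector from tail to head.

    Assumes image coordinates (x to the right, y downward), so the y
    difference is flipped to get a counter-clockwise angle as seen on screen.
    """
    dx = head[0] - tail[0]
    dy = tail[1] - head[1]  # flip y because the image y axis points down
    angle_deg = math.degrees(math.atan2(dy, dx)) % 360.0
    return angle_deg, math.hypot(dx, dy)


# Hypothetical pixel coordinates for one detected object
angle, length = keypoint_direction(head=(200, 100), tail=(100, 100))
print(f"{angle:.1f} deg, {length:.1f} px")
```

In the loop above, head and tail would correspond to keypoint[0][:2] and keypoint[1][:2] of each detected object.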

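Usage:

  • Download the code: click to download
  • Run: python 2.test_video.py weights/best.pt media/test.MOV 1 1 0.5
    (arguments in order: weights path, video path, record flag where 1 writes output.mp4, display flag where 1 shows a preview window, optional confidence threshold with default 0.3)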
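
The script reads its five arguments positionally from sys.argv. As a possible alternative, and not the author's method, the same interface could be declared with the standard argparse module, which adds a help message and validates the two flags. The function name parse_args is hypothetical; the argument names and the 0.3 default mirror the script.

```python
import argparse


def parse_args(argv=None):
    """Parse the same positional arguments the script takes from sys.argv."""
    parser = argparse.ArgumentParser(description="Run a YOLOv8 pose model on a video")
    parser.add_argument("weight_path", help="trained weights, e.g. weights/best.pt")
    parser.add_argument("media_path", help="input video, e.g. media/test.MOV")
    parser.add_argument("need_record", choices=["0", "1"], help="1 writes output.mp4")
    parser.add_argument("need_display", choices=["0", "1"], help="1 shows a preview window")
    # Optional trailing positional: confidence threshold, default 0.3
    parser.add_argument("thresh", nargs="?", type=float, default=0.3, help="confidence threshold")
    return parser.parse_args(argv)


# Same values as the run command in the usage list
args = parse_args(["weights/best.pt", "media/test.MOV", "1", "1", "0.5"])
print(args.weight_path, args.need_record, args.need_display, args.thresh)
```

Passing argv=None makes argparse fall back to the real command line, so the function could replace the sys.argv block without changing how the script is invoked.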