yolov8-seg分割模型TensorRt部署,去掉torch

已完成的yolov8-seg分割模型TensorRt部署

  • 准备
  • 下载yolov8-seg模型
  • 转化为onnx和trt
  • 推理
    • 写好的推理接口

准备

https://github.com/songjiahao-wq/yolov8_seg_trtinference.git下载代码
安装TensorRt=8.6版本,以及pip install -r requirements.txt

下载yolov8-seg模型

转化为onnx和trt

转化方法如下:

# tensorRT==8.6
## yolov8-seg CLI指令
### 转化ONNX模型
`python export-seg.py --weights yolov8m-seg.pt --opset 14 --sim --input-shape 1 3 640 640 --device cuda:0`

`python export-seg.py --weights yolov8m-seg.pt --opset 14 --sim --input-shape 1 3 448 512 --device cuda:0`
### 导出trt模型
`python build.py --weights yolov8m-seg.onnx --fp16  --device cuda:0 --seg`
### 采用trtexec导出trt模型
`E:\Download\TensorRT-10.0.1.6\bin/trtexec --onnx=yolov8m-seg.onnx --saveEngine=yolov8s-seg.engine --fp16`
### 不需要torch环境推理
`python infer-seg-without-torch.py --engine yolov8m-seg.engine --imgs data --show --out-dir outputs --method cudart`
### 需要torch环境推理
`python infer-seg.py`


- [x] infer-seg-without-torch-port.py 调用接口,每次只保存mask.txt
- [x] infer-seg-without-torch.py 不需要torch调用,有cuda和pycuda

  1. 首先转化为onnx模型
python export-seg.py --weights yolov8m-seg.pt --opset 14 --sim --input-shape 1 3 640 640 --device cuda:0
  1. 然后转化为trt模型
    有两种转化方式:
    代码转化:python build.py --weights yolov8m-seg.onnx --fp16 --device cuda:0 --seg
    trtexec转化:trtexec --onnx=yolov8m-seg.onnx --saveEngine=yolov8s-seg.engine --fp16

推理

推理方法有两种:
cudart推理,不包含torch

  1. python infer-seg-without-torch.py --engine yolov8m-seg.engine --imgs data --show --out-dir outputs --method cudart
    pycuda推理,不包含torch
  2. `python infer-seg-without-torch.py --engine yolov8m-seg.engine --imgs data --show --out-dir outputs --method pycuda
    带torch的推理
  3. python infer-seg.py

写好的推理接口

import argparse
import time
from pathlib import Path

import cv2
import numpy as np

from config import ALPHA, CLASSES, COLORS, MASK_COLORS
from models.utils import blob, letterbox, path_to_list, seg_postprocess
import torch


def clip_segments(segments, shape):
    """Clips segment coordinates (xy1, xy2, ...) to an image's boundaries given its shape (height, width)."""
    if isinstance(segments, torch.Tensor):  # faster individually
        segments[:, 0].clamp_(0, shape[1])  # x
        segments[:, 1].clamp_(0, shape[0])  # y
    else:  # np.array (faster grouped)
        segments[:, 0] = segments[:, 0].clip(0, shape[1])  # x
        segments[:, 1] = segments[:, 1].clip(0, shape[0])  # y


def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    """Rescales segment coordinates from img1_shape to img0_shape, optionally normalizing them with custom padding."""
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    segments[:, 0] -= pad[0]  # x padding
    segments[:, 1] -= pad[1]  # y padding
    segments /= gain
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # width
        segments[:, 1] /= img0_shape[0]  # height
    return segments


def masks2segments(masks, strategy="largest"):
    """Converts binary (n,160,160) masks to polygon segments with options for concatenation or selecting the largest
    segment.
    """
    segments = []
    for x in masks.int().cpu().numpy().astype("uint8"):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == "concat":  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == "largest":  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype("float32"))
    return segments


def keep_highest_conf_per_class(bboxes, scores, labels, segments, classes=0):
    # 组合成新的检测结果数组
    det = np.hstack((bboxes, scores[:, np.newaxis], labels[:, np.newaxis], np.array(segments)[:, np.newaxis]))

    if det.shape[0] == 0:
        return det  # 如果没有检测到任何对象,直接返回

    unique_classes = np.unique(det[:, 5])  # 获取所有独特的类标签
    max_conf_indices = []

    # 对每一个类别找到最高置信度的检测框
    cls_mask = det[:, 5] == classes  # 找到所有该类别的检测框
    cls_detections = det[cls_mask]  # 提取该类别的所有检测框
    # 计算每个检测框的面积
    areas = (cls_detections[:, 2] - cls_detections[:, 0]) * (
            cls_detections[:, 3] - cls_detections[:, 1])
    # 合并置信度和面积为一个复合评分,这里用置信度 + 面积的小部分作为评分
    scores_combined = cls_detections[:, 4] * 0.1 + 1.0 * areas
    # 找到评分最高的检测框
    max_score_index = np.argmax(scores_combined)
    # 找到原始的索引
    original_max_conf_index = np.where(cls_mask)[0][max_score_index]
    max_conf_indices.append(original_max_conf_index)
    # 选取评分最高的检测框
    return det[max_conf_indices][:, :4], det[max_conf_indices][:, 4], det[max_conf_indices][:, 5], det[
                                                                                                       max_conf_indices][
                                                                                                   :,
                                                                                                   6], max_conf_indices


class YOLOv8_seg_main:
    def __init__(self, args: argparse.Namespace):
        if args.method == 'cudart':
            from models.cudart_api import TRTEngine
        elif args.method == 'pycuda':
            from models.pycuda_api import TRTEngine
        else:
            raise NotImplementedError
        self.Engine = TRTEngine(args.engine)
        self.H, self.W = self.Engine.inp_info[0].shape[-2:]
        self.args = args

    def main(self, bgr, imagename, outtxtdir) -> None:
        outtxtdir = Path(outtxtdir)
        save_path = Path(args.out_dir)

        if not self.args.show and not save_path.exists():
            save_path.mkdir(parents=True, exist_ok=True)
        draw = bgr.copy()
        bgr, ratio, dwdh = letterbox(bgr, (self.W, self.H))
        dw, dh = int(dwdh[0]), int(dwdh[1])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        tensor, seg_img = blob(rgb, return_seg=True)
        dwdh = np.array(dwdh * 2, dtype=np.float32)
        tensor = np.ascontiguousarray(tensor)
        # inference
        data = self.Engine(tensor)
        seg_img = seg_img[dh:self.H - dh, dw:self.W - dw, [2, 1, 0]]
        bboxes, scores, labels, masks = seg_postprocess(
            data, bgr.shape[:2], self.args.conf_thres, self.args.iou_thres)
        if bboxes.size == 0:
            # if no bounding box
            assert print(f'image: no object!')
        masks = masks[:, dh:self.H - dh, dw:self.W - dw, :]
        segments = [
            scale_segments(tensor.shape[2:], x, rgb.shape, normalize=True)
            for x in reversed(masks2segments(torch.from_numpy(masks)))
        ]

        bboxes -= dwdh
        bboxes /= ratio

        # 应用 keep_highest_conf_per_class 函数
        bboxes, scores, labels, segments, max_conf_indices = keep_highest_conf_per_class(bboxes, scores, labels, segments, classes=0)
        if args.show:
            masks = masks[max_conf_indices]
            mask_colors = MASK_COLORS[0]
            mask_colors = mask_colors.reshape(-1, 1, 1, 3) * ALPHA
            mask_colors = masks @ mask_colors
            inv_alph_masks = (1 - masks * 0.5).cumprod(0)
            mcs = (mask_colors * inv_alph_masks).sum(0) * 2
            seg_img = (seg_img * inv_alph_masks[-1] + mcs) * 255
            draw = cv2.resize(seg_img.astype(np.uint8), draw.shape[:2][::-1])

        if args.save_txt:
            seg = segments[0].reshape(-1)  # (n,2) to (n*2)
            line = (int(labels[0]), *seg)  # label format
            with open(outtxtdir / f"{Path(imagename).stem}.txt", "w") as f:
                f.write(("%g " * len(line)).rstrip() % line + "\n")

        if args.show:
            save_image = save_path / Path(imagename).name
            cv2.imwrite(str(save_image), draw)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, default="../yolov8l-seg.engine", help='Engine file')
    parser.add_argument('--imgs', type=str, default="data", help='Images file')
    parser.add_argument('--show',
                        action='store_true',
                        default=False,
                        help='Show the detection results')
    parser.add_argument('--save_txt',
                        action='store_true',
                        default=True,
                        help='save_txt the detection results')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output file')
    parser.add_argument('--conf-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    parser.add_argument('--iou-thres',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    parser.add_argument('--method',
                        type=str,
                        default='cudart',
                        help='CUDART pipeline')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    YOLOv8_seg_main = YOLOv8_seg_main(args)
    imgpath = './data/1.jpg'
    outtxtdir = './output'
    bgr_img = cv2.imread(imgpath)
    t1 = time.time()
    for i in range(100):
        YOLOv8_seg_main.main(bgr_img, imgpath, outtxtdir)
    print(time.time() - t1)

输入为brg图像,图像的路径和输出路径,最后会保存masktxt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/774336.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kafka系列之Kafka知识超强总结

一、Kafka简介 Kafka是什么 Kafka是一种高吞吐量的分布式发布订阅消息系统(消息引擎系统),它可以处理消费者在网站中的所有动作流数据。 这种动作(网页浏览, 搜索和其他用户的行动)是在现代网络上的许多社…

个人引导页+音乐炫酷播放器(附加源码)

个人引导页音乐炫酷播放器 效果图部分源码完整源码领取下期更新内容 效果图 部分源码 //网站动态标题开始 var OriginTitile document.title, titleTime; document.addEventListener("visibilitychange", function() {if (document.hidden) {document.title "…

为什么英智智能宝能让律师工作事半功倍

大语言模型能够极大提高人们的知识理解能力和知识服务能力,法律服务是典型的知识服务领域,据悉律师有38%的任务都是重复性工作,这些任务有潜力被大模型替代。 但在法律行业中的高度专业且复杂的问题时,通用型大模型的回答虽能提供…

Dungeonborne卡顿怎么办 快速解决Dungeonborne卡顿问题

随着Dungeonborne游戏剧情的深入,玩家将逐渐解锁更多的地图和副本,每个区域都有其独特的生态和敌人。在探索的过程中,玩家不仅可以获得强大的装备和道具,还能结识到志同道合的伙伴,共同面对更强大的敌人。不过也有玩家…

谷粒商城学习笔记-05-项目微服务划分图

文章目录 一,商城业务服务-前端服务二,商城业务服务-后端服务三,存储服务四,第三方服务五,服务治理六,日志七,监控预警系统1,Prometheus2,Grafana3,Prometheu…

奥能电源应邀参加2024年顺丰创π创新大会

企业动态|杭州奥能董事长陈虹先生和常务副总金晖女士受邀出席创π-产业科技创新大会,深入探讨“双碳”目标下的产业转型与技术创新 近日,杭州奥能董事长陈虹先生和常务副总金晖女士应邀出席了在杭州举办的创π-产业科技创新大会。本次大会以产…

嵌入式学习——硬件(UART)——day55

1. UART 1.1 定义 UART(Universal Asynchronous Receiver/Transmitter,通用异步收发器)是一种用于串行通信的硬件设备或模块。它的主要功能是将数据在串行和并行格式之间进行转换。UART通常用于计算机与外围设备或嵌入式系统之间的数据传输。…

Git仓库介绍

1. Github GitHub 本身是一个基于云端的代码托管平台,它提供的是远程服务,而不是一个可以安装在本地局域网的应用程序。因此,GitHub 不可以直接在本地局域网进行安装。 简介:GitHub是最流行的代码托管平台,提供了大量…

【Android源码】Gerrit上传Android源码

关于Gerrit的安装参考下面链接 【Android源码】Gerrit安装 要实现上传Android源码,需要经历以下几步: 下载Android代码创建源码仓库创建manifests仓库上传源码其他电脑下载源码 要证明Gerrit中的源码真实可用,肯定是以其他人能真正共享到代…

C++(第五天----多继承、虚继承、虚函数、虚表)

一、继承对象的内存空间 构造函数调用顺序&#xff0c;先调用父类&#xff0c;再调用子类 #include<iostream>using namespace std;//基类 父类 class Base{ public: //公有权限 类的外部 类的内部 Base(){cout<<"Base()"<<endl;}Base(int …

笔记本电脑升级实战手册[2]:清灰换硅脂

文章目录 前言&#xff1a;一、开盖拆卸二、清灰指南1. 电脑内部清灰2. 风扇清灰3. 清理散热铜管 三、更换硅脂总结&#xff1a; 前言&#xff1a; 这是笔记本电脑升级实战手册的第二篇文章&#xff0c;本篇主要是对电脑进行清灰换硅脂的处理的分享&#xff0c;使用电脑是华硕…

晨持绪电商:大学毕业生投资抖音网店怎么样

在这个数字化飞速发展的时代&#xff0c;传统的职业路径已不再是唯一的选择。对于充满激情和创意的大学毕业生来说&#xff0c;投资抖音网店或许是一个颇具前景的选择。 抖音作为一个流量巨大的社交媒体平台&#xff0c;为年轻人提供了一个展示自我、推广产品的绝佳舞台。与传统…

创新引领,构筑产业新高地

在数字经济的浪潮中&#xff0c;成都树莓集团以创新驱动为核心&#xff0c;通过整合行业资源、优化服务、培养数字产业人才等措施&#xff0c;致力于打造产业高地&#xff0c;推动地方经济的高质量发展。 一、创新驱动&#xff0c;引领产业发展 1、引入新技术、新模式&#xf…

平安养老险宿州中心支公司积极参与“78奋力前行”集体健步行活动

7月3日&#xff0c;平安养老保险股份有限公司&#xff08;以下简称“平安养老险”&#xff09;宿州中心支公司组织员工参加由宿州市保险行业协会2024年“78奋力前行”线下集体健步行活动。 平安养老险宿州中心支公司员工高举公司旗帜&#xff0c;与同业伙伴一起出发&#xff0…

maven设置阿里云镜像源(加速)

一、settings.xml介绍 settings.xml是maven的全局配置文件&#xff0c;maven的配置文件存在三个地方 项目中的pom.xml&#xff0c;这个是pom.xml所在项目的局部配置文件用户配置&#xff1a;${user.home}/.m2/settings.xml全局配置&#xff1a;${M2_HOME}/conf/settings.xml 优…

数据库国产化之路(一)

数据库国产化之路(一) 1、前言&#xff1a;适配海量数据库过程中的一些记录&#xff0c;备忘用 2、海量数据库基于的pg版本&#xff0c;查看PG_VERSION文件为9.2。 3、MySQL中的IF函数替代&#xff0c;一开始的方案是从网上找了个if函数&#xff0c;后来发现CASE WHEN其实能完成…

手把手教你生成一幅好看的AI图片

很多人看到别人用SD生成出来的图片感到非常的羡慕&#xff0c;因为即使给了他们最好的SD软件&#xff0c;他们也是词穷&#xff0c;不知道该如何去描述要生成的图片。 别急&#xff0c;这篇文章会一步步的教会你怎么才能生成一个好看的AI图片。 跟着我&#xff0c;别走丢。 …

iptables与firewalld

iptables Linux上常用的防火墙软件 1、 防火墙的策略 防火墙策略一般分为两种&#xff0c;一种叫通策略&#xff0c;一种叫堵策略&#xff0c;通策略&#xff0c;默认门是关着的&#xff0c;必须要定义谁能进。堵策略则是&#xff0c;大门是洞开的&#xff0c;但是你必须有身…

从数据到智能,英智私有大模型助力企业实现数智化发展

在数字化时代&#xff0c;数据已经成为企业最重要的资源。如何将这些数据转化为实际的业务价值&#xff0c;是每个企业面临的重要课题。英智利用业界领先的清洗、训练和微调技术&#xff0c;对企业数据进行深度挖掘和分析&#xff0c;定制符合企业业务场景的私有大模型&#xf…

2024年7月最佳免费天气API接口推荐

在我们的日常生活中&#xff0c;天气扮演着一个至关重要的角色&#xff0c;它影响着我们的情绪、健康、日常安排和商业决策。无论是计划一次户外活动、安排农作物种植&#xff0c;还是确保旅行安全&#xff0c;天气信息的准确性和及时性至关重要。随着技术的进步&#xff0c;天…