Transformer实战-系列教程16:DETR 源码解读3(DETR类)

news/2024/2/21 2:57:20

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

4、DETR类

位置:models/detr.py/DETR类

4.1 构造函数

class DETR(nn.Module):def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):super().__init__()self.num_queries = num_queriesself.transformer = transformerhidden_dim = transformer.d_modelself.class_embed = nn.Linear(hidden_dim, num_classes + 1)self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)self.query_embed = nn.Embedding(num_queries, hidden_dim)self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)self.backbone = backboneself.aux_loss = aux_loss
  1. DETR类继承torch nn.Module
  2. 构造函数,传入5个参数:
    • backbone:CNN骨架网络,用于特征提取
    • transformer:Transformer模型,用于处理序列数据
    • num_classes:目标类别的数量
    • num_queries:解码器初始化生成的100个向量的个数,num_queries=100
    • aux_loss:一个布尔值,指示是否使用辅助损失来帮助训练
  3. 初始化
  4. num_queries
  5. transformer
  6. hidden_dim ,Transformer中的隐藏层维度
  7. class_embed ,类别预测的输出层,这个全连接层是接Transformer的输出,类别加1是额外的无类别对象
  8. bbox_embed,一个MLP,也是接Transformer的输出,边界框的四个坐标的回归
  9. query_embed ,解码器的初始100个向量
  10. input_proj ,一个1x1的二维卷积,使得backbone的输出通道数映射到与Transformer隐藏层维度相同
  11. backbone,一个预训练的卷积神经网络,主要作用是提取图像的特征,它的输出经过input_proj 处理后作为Transformer的输入
  12. aux_loss,保存是否使用辅助损失的标志

这里包含了几个自定义函数和类:
nested_tensor_from_tensor_list函数:将不同尺寸处理的图像Tensor转换为一个嵌套Tensor
MLP类:边界框的四个坐标的回归
transformer类:构建transformer架构
backbone:用于提取图像特征的CNN

4.2 前向传播

    def forward(self, samples: NestedTensor):if isinstance(samples, (list, torch.Tensor)):samples = nested_tensor_from_tensor_list(samples)features, pos = self.backbone(samples)src, mask = features[-1].decompose()assert mask is not Nonehs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]outputs_class = self.class_embed(hs)outputs_coord = self.bbox_embed(hs).sigmoid()out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}if self.aux_loss:out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)return out    
  1. 前向传播函数,输入为samples=NestedTensor{mask={Tensor(2,771,911)},tensors={Tensor(2,3,771,911)}}
  2. 检查samples是否为列表或Tensor类型
  3. samples ,如果,使用nested_tensor_from_tensor_list函数转换为NestedTensor
  4. features, pos,图像特征图列表和对应的位置编码列表,backbone实际上一个现在的resnet
  5. src, mask,解构最后一层的特征,获取源数据和掩码,src:torch.Size([2, 2048, 21, 18]),mask torch.Size([2, 21, 18]),2是batch,2048是特征维度,后面两个是图像长宽
  6. 确保掩码不为空
  7. 将数据通过Transformer处理,获取序列输出,torch.Size([6, 2, 100, 256]),6是Transformer的堆叠层数,2是batch,100是生成100个目标预测,256是每个目标预测的维度
  8. outputs_class ,获取类别预测
  9. outputs_coord ,获取边界框坐标预测,并使用sigmoid函数将输出值限制在0到1之间
  10. out ,将类别预测结果和 边界框坐标预测结果做成一个字典
  11. 如果启用了辅助损失
  12. 通过辅助函数_set_aux_loss计算辅助损失
  13. 返回out

4.3 辅助函数_set_aux_loss()

@torch.jit.unuseddef _set_aux_loss(self, outputs_class, outputs_coord):return [{'pred_logits': a, 'pred_boxes': b}for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  1. @torch.jit.unused:一个装饰器,指示当使用TorchScript编译模型时,该方法不应被编译。这是因为辅助损失的计算可能不兼容TorchScript的静态图特性
  2. 定义函数,接收类别预测和边界框坐标作为输入
  3. 返回一个列表,将每一个类别预测和边界框坐标都封装成一个字典,这样,训练过程中可以计算每一层的损失,从而实现辅助损失的目的

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)


https://www.xjx100.cn/news/3271416.html

相关文章

单页404源码

<!doctype html> <html> <head> <meta charset"utf-8"> <title>简约 404错误页</title><link rel"shortcut icon" href"./favicon.png"><style> import url("https://fonts.googleapis.co…

c++恶魔轮盘制造第1期输赢

小常识&#xff0c;恶魔叫DEALER&#xff0c;上帝叫God. 赢了很简单 void sheng() { cout<<"你获胜了&#xff01;";MessageBox(NULL,TEXT("你的钱~~~~~~给你"),TEXT("DEALER"),MB_OK);system("pause");system("cls"…

powershell 雅地关闭UDP监听器

在PowerShell中优雅地关闭UDP监听器意味着你需要一种机制来安全地停止正在运行的UdpClient实例。由于UdpClient类本身没有提供直接的停止或关闭方法&#xff0c;你需要通过其他方式来实现这一点。通常&#xff0c;这涉及到在监听循环中添加一个检查点&#xff0c;以便在接收到停…

使用 C++23 从零实现 RISC-V 模拟器(1):最简CPU

&#x1f449;&#x1f3fb; 文章汇总「从零实现模拟器、操作系统、数据库、编译器…」&#xff1a;https://okaitserrj.feishu.cn/docx/R4tCdkEbsoFGnuxbho4cgW2Yntc 本节实现一个最简的 CPU &#xff0c;最终能够解析 add 和 addi 两个指令。如果对计算机组成原理已经有所了…

MySQL篇----第十四篇

系列文章目录 文章目录 系列文章目录前言一、MySQL 数据库作发布系统的存储,一天五万条以上的增量,预计运维三年,怎么优化?二、锁的优化策略三、索引的底层实现原理和优化四、什么情况下设置了索引但无法使用前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽…

模态、模式和真实发生

模态和模式均是用来描述某一对象或系统可能出现的特性、状态或行为&#xff0c;它们既包括逻辑上的抽象可能性&#xff0c;也涵盖现实中具体的现象和事件结构。模态更多地关联于逻辑可能性和必然性&#xff0c;而模式则侧重于现象的重复性和规律性&#xff0c;两者都可以反映真…

【Django】Django项目部署

项目部署 1 基本概念 项目部署是指在软件开发完毕后&#xff0c;将开发机器上运行的软件实际安装到服务器上进行长期运行。 在安装机器上安装和配置同版本的环境[python&#xff0c;数据库等] django项目迁移 scp /home/euansu/Code/Python/website euansuxx.xx.xx.xx:/home…

Java并发之ThreadLocal理解

Java并发之ThreadLocal理解 介绍使用场景 介绍 ThreadLocal是为实现对资源对象的线程隔离&#xff0c;使每个线程拥有自己的资源&#xff0c;避免并发时争用引发线程安全问题 实现原理&#xff1a; 主要是其内部存在一个ThreadLocalMap存储资源&#xff0c;将ThreadLocal对象自…