Mamba: 使用选择性状态空间的线性时间序列建模
研究背景
- 当前深度学习的基础模型几乎都基于Transformer架构及其核心注意力模块
- 现有的次二次时间复杂度架构(如线性注意力、门控卷积、循环模型和结构化状态空间模型SSM)在处理长序列时虽能提高计算效率,但在语言等重要模态上表现不如注意力机制
关键创新点
- 识别核心问题: 发现这些模型的主要弱点是无法进行基于内容的推理
- 选择性SSM机制: 让SSM参数成为输入的函数,使模型能够根据当前token选择性地传播或遗忘序列信息
- 硬件感知算法: 设计了循环模式下的硬件感知并行算法,克服了无法使用高效卷积的限制
- 简化架构: 提出了不含注意力机制或MLP块的端到端神经网络架构Mamba
性能优势
- 推理速度: 吞吐量比Transformers高5倍
- 可扩展性: 序列长度呈线性扩展,可处理百万级长度序列
- 模型表现: Mamba-3B模型在预训练和下游评估中优于同等规模的Transformers,并可匹敌两倍规模的Transformers
- 跨模态能力: 在语言、音频和基因组学等多个模态上达到最先进性能
应用价值
作为通用序列模型骨干网络,Mamba为处理长序列数据提供了高效且性能优越的解决方案
https://arxiv.org/abs/2312.00752
Mamba-RS: Rust实现的Mamba选择性状态空间模型
项目概述
这是一个用Rust语言实现的Mamba SSM(选择性状态空间模型),支持可选的CUDA GPU加速。项目支持推理和训练,包括通过递归SSM状态的完整反向传播(BPTT),并提供自定义CUDA核心用于GPU加速的前向和反向传播。
主要特性
- 推理优化:零内存分配的单步递归前向传播(CPU上约200微秒)
- 完整训练支持:通过SSM隐藏状态的完整BPTT反向传播
- 预热机制:支持从历史上下文预热递归状态
- CUDA加速:为SSM递归、conv1d和融合激活函数提供自定义核心
- 独立运行:无需PyTorch、Burn或Candle等框架依赖
- 单精度浮点:原生f32,在Ampere/Hopper架构上支持TF32张量核心
架构设计
模型采用多层结构,每层包含:
- 输入投影和残差连接
- RMS归一化
- 门控机制和卷积层
- SSM递归计算(h = Ah + Bx, y = Ch + Dx)
- 输出投影和残差相加
性能表现
CPU推理(GH200 ARM):
- 小型配置(64维,2层):61微秒
- 默认配置(128维,3层):375微秒
- 大型配置(512维,6层):13.6毫秒
GPU性能(H100,TF32):批量推理延迟约10-25微秒
核心优势
- 手动推导的解析梯度,无需自动微分依赖
- 跨时间步的BPTT支持
- 零分配推理路径
- 扁平连续权重缓冲区,便于优化器融合
- 兼容CUDA Graph捕获
与Python版本的差异
| 特性 | Python版本 | Rust版本 |
|---|---|---|
| 反向传播 | PyTorch自动微分 | 手动BPTT |
| 核心实现 | Triton + CUDA C++ | CUDA NVRTC运行时编译 |
| 框架依赖 | PyTorch | 独立运行 |
| 精度 | fp16/bf16/fp32 | f32(GPU上TF32) |
许可证
采用MIT或Apache-2.0双重许可。
https://github.com/silvermpx/mamba-rs
--
From 日报小组 Mike
社区学习交流平台订阅:
1
共 0 条评论, 1 页
评论区
写评论还没有评论