Deepseek 宣布连续5天发布一些最新的开源成果,今天开始第一天,今天发布的开源成果为:FlashMLA。 它通过结合先进的 MLA 算法和 GPU 优化技术,为大模型推理提供了高性能、低延迟的解码方案。
FlashMLA 是一个针对 Hopper GPU(例如 H800 SXM5)优化的高效 MLA(多头潜在注意力,Multi-head Latent Attention)解码内核。
FlashMLA 用来让GPU(特别是 NVIDIA 的 H800 这种高端显卡)更快地处理大模型的计算任务。它主要是用来加速一种叫做“多头潜在注意力”(MLA)的计算过程,这是大模型生成文字时的重要步骤。
FlashMLA 通过优化显卡的使用,让这个模型的推理过程更快、更省资源。比如,它能在一秒钟内处理海量数据(高达 3000 GB/s),或者达到很高的计算能力(580 TFLOPS,相当于每秒做几百亿亿次计算)。这对实时聊天、翻译或者其他需要快速响应的场景特别有用。
- 内存带宽提升了 2-3 倍(3000 GB/s vs. 1000-1500 GB/s)
- 计算性能提升了约 2 倍(580 TFLOPS vs. 200-300 TFLOPS)
- 推理速度提升 30%-50%,特别是在长序列和大规模推理场景中
- 显存利用率提升 20%-30%
它解决了什么问题?
在大模型(比如 ChatGPT)进行 推理(即生成文本)时,每一步的计算都会涉及到大量的 注意力计算,这个计算的 速度和显存占用 是影响推理效率的关键因素。
- 传统方法的 计算效率较低,会浪费 GPU 计算资源。
- 计算过程中 显存占用过高,导致在同样的硬件上,能处理的序列长度有限。
- 大部分 Transformer 计算库不适用于变长序列,而 FlashMLA 针对这个问题进行了优化。
FlashMLA 通过优化计算方式(Paged KV 缓存 + 高效 MLA 计算),提升推理速度,并减少显存占用。
它怎么做到快的?
- 聪明地用内存:它减少了显卡在计算时来回读写数据的时间(参考了 FlashAttention 的方法)。
- 灵活处理任务:不管输入的文字长短,它都能高效应对。
- 省力设计:通过预先准备好一些“计算计划”(元数据),让显卡直接按最优方案干活,不浪费时间。
FlashMLA 提升效果
FlashMLA 的优化为大模型推理提供了 显著的性能提升,特别是在 内存带宽 和 计算性能 上。以下是具体的提升量:
1. 内存带宽
- 传统方法: 在内存带宽受限时,大多数推理方法的带宽可能在 数百 GB/s 到 1500 GB/s 之间,具体取决于硬件配置。
- FlashMLA: 通过优化 分页 KV 缓存 和 高效内存访问模式,FlashMLA 能够在 内存受限配置下 实现 3000 GB/s 的带宽。这是 传统方法的 2 到 3 倍。
2. 计算性能
- 传统方法: 通常,在使用 NVIDIA H800 GPU 时,推理性能可能会限制在 200-300 TFLOPS,尤其是在计算受限的情况下。
- FlashMLA: 在 H800 SXM5 GPU 上,FlashMLA 的计算性能可以高达 580 TFLOPS,几乎是传统方法的两倍以上。
3. 推理速度与延迟
FlashMLA 通过 优化多头注意力计算,显著减少了每次推理所需的计算时间。对于 长序列 的推理任务,FlashMLA 的 延迟 要比传统方法低得多。这意味着:
- 在处理 变长序列 时,FlashMLA 能够以 更少的计算资源 完成任务。
- 响应时间 得到了大幅度缩短,尤其是在 大规模模型推理(如 ChatGPT 类型的模型)中,推理延迟 减少了 30%-50%。
4. 显存利用率
- FlashMLA 还通过 优化显存的使用,减少了每个推理步骤所需的显存占用,使得即便在相同的硬件配置下,能够处理 更长的输入序列。相比传统方法,显存利用率提升了 20%-30%,从而支持更大规模的推理任务。
代码里有什么?
项目里有两块核心内容:
- 一个功能帮你“规划任务”,告诉显卡怎么分配工作。
- 一个功能实际“干活”,用显卡算出结果,还能记住一些中间数据(键值缓存),下次用的时候更快。
主要特点
- 高性能表现
在内存受限配置下,FlashMLA 可实现高达 3000 GB/s 的吞吐量。
在计算受限配置下,达到 580 TFLOPS 的峰值性能(基于 H800 SXM5 和 CUDA 12.6)。
通过优化的内核设计,减少推理过程中的内存瓶颈,提升整体效率。
- 针对可变长度序列优化
- FlashMLA 支持动态序列长度,能够高效处理多样化的输入场景,特别适用于实时服务和批处理任务。
- 核心功能实现
项目提供了两个主要接口:
get\_mla\_metadata
:用于生成 MLA 的元数据,例如瓦片调度器元数据(tile_scheduler_metadata)和分割数量(num_splits),以适配不同的序列长度和注意力配置。flash\_mla\_with\_kvcache
:结合键值缓存(KV Cache)执行 MLA 计算,支持因果注意力(causal attention),并输出注意力结果和对数和指数(log-sum-exp)值。
- 技术灵感与依赖
- FlashMLA 借鉴了 FlashAttention 的内存高效注意力机制和 Cutlass 的高性能计算库特性,确保了其在 GPU 上的卓越表现。