Deepseek 宣布连续5天发布一些最新的开源成果,今天开始第一天,今天发布的开源成果为:FlashMLA。 它通过结合先进的 MLA 算法和 GPU 优化技术,为大模型推理提供了高性能、低延迟的解码方案。

FlashMLA 是一个针对 Hopper GPU(例如 H800 SXM5)优化的高效 MLA(多头潜在注意力,Multi-head Latent Attention)解码内核。

FlashMLA 用来让GPU(特别是 NVIDIA 的 H800 这种高端显卡)更快地处理大模型的计算任务。它主要是用来加速一种叫做“多头潜在注意力”(MLA)的计算过程,这是大模型生成文字时的重要步骤。

FlashMLA 通过优化显卡的使用,让这个模型的推理过程更快、更省资源。比如,它能在一秒钟内处理海量数据(高达 3000 GB/s),或者达到很高的计算能力(580 TFLOPS,相当于每秒做几百亿亿次计算)。这对实时聊天、翻译或者其他需要快速响应的场景特别有用。

  • 内存带宽提升了 2-3 倍(3000 GB/s vs. 1000-1500 GB/s)
  • 计算性能提升了约 2 倍(580 TFLOPS vs. 200-300 TFLOPS)
  • 推理速度提升 30%-50%,特别是在长序列和大规模推理场景中
  • 显存利用率提升 20%-30%

它解决了什么问题?

在大模型(比如 ChatGPT)进行 推理(即生成文本)时,每一步的计算都会涉及到大量的 注意力计算,这个计算的 速度和显存占用 是影响推理效率的关键因素。

  1. 传统方法的 计算效率较低,会浪费 GPU 计算资源。
  2. 计算过程中 显存占用过高,导致在同样的硬件上,能处理的序列长度有限。
  3. 大部分 Transformer 计算库不适用于变长序列,而 FlashMLA 针对这个问题进行了优化。

FlashMLA 通过优化计算方式(Paged KV 缓存 + 高效 MLA 计算),提升推理速度,并减少显存占用。

它怎么做到快的?

  • 聪明地用内存:它减少了显卡在计算时来回读写数据的时间(参考了 FlashAttention 的方法)。
  • 灵活处理任务:不管输入的文字长短,它都能高效应对。
  • 省力设计:通过预先准备好一些“计算计划”(元数据),让显卡直接按最优方案干活,不浪费时间。

FlashMLA 提升效果

FlashMLA 的优化为大模型推理提供了 显著的性能提升,特别是在 内存带宽计算性能 上。以下是具体的提升量:

1. 内存带宽

  • 传统方法: 在内存带宽受限时,大多数推理方法的带宽可能在 数百 GB/s1500 GB/s 之间,具体取决于硬件配置。
  • FlashMLA: 通过优化 分页 KV 缓存高效内存访问模式,FlashMLA 能够在 内存受限配置下 实现 3000 GB/s 的带宽。这是 传统方法的 2 到 3 倍

2. 计算性能

  • 传统方法: 通常,在使用 NVIDIA H800 GPU 时,推理性能可能会限制在 200-300 TFLOPS,尤其是在计算受限的情况下。
  • FlashMLA:H800 SXM5 GPU 上,FlashMLA 的计算性能可以高达 580 TFLOPS几乎是传统方法的两倍以上

3. 推理速度与延迟

  • FlashMLA 通过 优化多头注意力计算,显著减少了每次推理所需的计算时间。对于 长序列 的推理任务,FlashMLA 的 延迟 要比传统方法低得多。这意味着:

    • 在处理 变长序列 时,FlashMLA 能够以 更少的计算资源 完成任务。
    • 响应时间 得到了大幅度缩短,尤其是在 大规模模型推理(如 ChatGPT 类型的模型)中,推理延迟 减少了 30%-50%

4. 显存利用率

  • FlashMLA 还通过 优化显存的使用,减少了每个推理步骤所需的显存占用,使得即便在相同的硬件配置下,能够处理 更长的输入序列。相比传统方法,显存利用率提升了 20%-30%,从而支持更大规模的推理任务。

代码里有什么?

项目里有两块核心内容:

  1. 一个功能帮你“规划任务”,告诉显卡怎么分配工作。
  2. 一个功能实际“干活”,用显卡算出结果,还能记住一些中间数据(键值缓存),下次用的时候更快。

主要特点

  1. 高性能表现
  • 在内存受限配置下,FlashMLA 可实现高达 3000 GB/s 的吞吐量。

  • 在计算受限配置下,达到 580 TFLOPS 的峰值性能(基于 H800 SXM5 和 CUDA 12.6)。

  • 通过优化的内核设计,减少推理过程中的内存瓶颈,提升整体效率。

  1. 针对可变长度序列优化
  • FlashMLA 支持动态序列长度,能够高效处理多样化的输入场景,特别适用于实时服务和批处理任务。
  1. 核心功能实现
  • 项目提供了两个主要接口:

  • get\_mla\_metadata:用于生成 MLA 的元数据,例如瓦片调度器元数据(tile_scheduler_metadata)和分割数量(num_splits),以适配不同的序列长度和注意力配置。

  • flash\_mla\_with\_kvcache:结合键值缓存(KV Cache)执行 MLA 计算,支持因果注意力(causal attention),并输出注意力结果和对数和指数(log-sum-exp)值。

  1. 技术灵感与依赖
  • FlashMLA 借鉴了 FlashAttention 的内存高效注意力机制和 Cutlass 的高性能计算库特性,确保了其在 GPU 上的卓越表现。

GitHub:https://github.com/deepseek-ai/FlashMLA