PyTorch模型转ONNX指南

作者:杨志

问题背景与核心要素

在当前的机器学习领域,模型的研发与部署是两个紧密相连却又各具挑战的环节。PyTorch 以其灵活性、易用性和强大的社区支持,成为学术研究和模型训练的首选框架之一。然而,当模型从研究阶段走向实际应用部署时,开发者往往面临诸多挑战。

引言:为什么需要模型转换?

PyTorch:作为一款以 Python 优先,支持动态计算图的深度学习框架,PyTorch 极大地简化了复杂模型的构建和快速迭代过程。其完善的生态系统和丰富的预训练模型库,使其在科研领域占据主导地位。 PyTorch官方网站

模型部署的挑战:训练完成的 PyTorch 模型(通常为 .pth.pt 文件)在部署时会遇到以下问题:

ONNX (Open Neural Network Exchange) 的诞生:为了解决上述模型互操作性问题,ONNX 应运而生。它是一种为机器学习模型设计的开放标准格式。 ONNX官方网站

ONNX 的核心价值在于实现"一次训练,多处部署"。它充当了不同深度学习框架和硬件加速器之间的桥梁,允许开发者将一个框架中训练的模型导出为 ONNX 格式,然后在另一个支持 ONNX 的框架或硬件上进行推理。 PyTorch ONNX导出教程

核心问题提炼:

本文目标与结构概述:

本文旨在以 sherpa-onnx 项目中用于导出 Whisper 模型的 export-onnx.py 脚本为例,深入剖析从 PyTorch 模型到 ONNX 格式的转换流程。我们将详细探讨其中的关键技术点、参数选择、常见问题及其解决方案。通过阅读本文,读者不仅能理解通用的模型转换原理,还能掌握针对特定复杂模型(如语音识别领域的 Transformer 模型)的导出技巧和最佳实践。

文章结构将围绕以下几个核心部分展开:ONNX 与 PyTorch torch.onnx.export 函数概览、sherpa-onnx Whisper 导出脚本的深度解读、PyTorch 模型转 ONNX 的进阶技巧与常见问题、导出后验证与使用,最后进行总结并提供最佳实践建议。

ONNX与PyTorch torch.onnx.export 概览

ONNX 基础

定义:ONNX (Open Neural Network Exchange) 是一种用于表示机器学习模型的开放格式。它旨在促进不同人工智能框架之间的互操作性。 ONNX简介

组成:一个 ONNX 模型主要由以下部分构成:

优势

ONNX Runtime:由微软开发并开源的高性能推理引擎,专门用于执行 ONNX 模型。它支持多种硬件加速器和操作系统,能够显著提升模型在CPU和GPU上的推理速度。 ONNX Runtime官方网站

PyTorch torch.onnx.export 函数详解

PyTorch 通过 torch.onnx 模块提供将模型导出为 ONNX 格式的功能,其核心函数是 torch.onnx.export()

核心功能:该函数通过执行一次模型(使用提供的示例输入),记录下计算过程中涉及的 PyTorch 算子操作轨迹,然后将这些轨迹转换为符合 ONNX规范的计算图。 torch.onnx官方文档

关键参数解析:

导出器版本 (Exporter Versions):

PyTorch 的 ONNX 导出器经历了发展,主要有两个版本:

导出前置准备:

一个至关重要的步骤是在调用 torch.onnx.export() 之前,将模型设置为评估模式:

model.eval()

或者等效的 model.train(False)。这是因为像 Dropout 层和 BatchNorm 层这样的模块在训练模式和评估模式下的行为是不同的。在导出用于推理的模型时,必须确保它们处于评估(推理)状态,否则可能导致导出的 ONNX 模型行为不正确或性能不佳。 微软PyTorch转ONNX教程强调model.eval()

案例分析:export-onnx.py 脚本解读

以一个将 Whisper 模型从 PyTorch 实现转换为 ONNX 格式的脚本 export-onnx.py 为例,我们可以深入了解复杂 Transformer 模型导出的具体实现细节。这个脚本来自开源项目,是理解复杂模型导出流程的典型案例。

案例中使用的模型简介

本案例中的脚本主要处理 Whisper 模型,以下是相关背景信息:

Whisper 模型: 由 OpenAI 开发的基于 Transformer 架构的多语言语音识别模型。该模型在大量弱监督数据上训练,具有较高的识别准确率和跨语言性能。 OpenAI Whisper介绍

Esperanto Technologies 在一篇博客中详细讨论了 Whisper 模型的架构及其到 ONNX 的转换,指出其遵循典型的 Transformer 结构。 Adapting Whisper Models to ET-SoC-1 Architecture and Exporting Them to ONNX

export-onnx.py 脚本目标与设计

分析的 export-onnx.py 脚本的核心目标是将 Whisper 模型的 PyTorch 版本转换为能被 ONNX 运行时加载和使用的模型文件。

核心设计理念

脚本关键步骤与技术细节分析

根据提供的完整export-onnx.py源码,我们可以系统分析其核心实现和关键技术点。

1. 核心类与结构分析

为了适应ONNX导出,脚本对Whisper的原始实现进行了多项修改和封装。以下是几个关键类的详细分析:

  1. AudioEncoderTensorCache
    class AudioEncoderTensorCache(nn.Module):
        def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
            super().__init__()
            self.audioEncoder = inAudioEncoder
            self.textDecoder = inTextDecoder
    
        def forward(self, x: Tensor):
            audio_features = self.audioEncoder(x)
    
            n_layer_cross_k_list = []
            n_layer_cross_v_list = []
            for block in self.textDecoder.blocks:
                n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
                n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
    
            return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
    该类将Whisper的Encoder组件封装为能够提前生成cross-attention所需key和value的模块。其核心功能是:
    • 首先通过audioEncoder处理音频梅尔频谱图输入,得到audio_features
    • 对于Decoder中的每个Transformer层,预先计算cross-attention(跨注意力)的key和value投影。
    • 返回两个张量,分别为所有层的cross-attention keys和values的堆叠结果。
    这种设计使得Encoder的计算结果可以被缓存并重复用于Decoder的每一步推理,而不需要重复计算,优化了推理效率。
  2. MultiHeadAttentionCross
    class MultiHeadAttentionCross(nn.Module):
        def __init__(self, inMultiHeadAttention: MultiHeadAttention):
            super().__init__()
            self.multiHeadAttention = inMultiHeadAttention
    
        def forward(
            self,
            x: Tensor,
            k: Tensor,
            v: Tensor,
            mask: Optional[Tensor] = None,
        ):
            q = self.multiHeadAttention.query(x)
            wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
            return self.multiHeadAttention.out(wv)
    该类专门处理Decoder中的跨注意力机制,使其接口适应ONNX导出需求:
    • 将原始MultiHeadAttention类进行包装,使其可以接受预先计算好的key(k)和value(v)作为输入
    • 首先对输入x计算query投影,然后利用传入的预计算key和value执行注意力计算
    • 最后应用输出投影并返回结果
    这种设计支持Encoder和Decoder分离导出模式,同时允许复用预计算的key和value。
  3. MultiHeadAttentionSelf
    class MultiHeadAttentionSelf(nn.Module):
        def __init__(self, inMultiHeadAttention: MultiHeadAttention):
            super().__init__()
            self.multiHeadAttention = inMultiHeadAttention
    
        def forward(
            self,
            x: Tensor,  # (b, n_ctx      , n_state)
            k_cache: Tensor,  # (b, n_ctx_cache, n_state)
            v_cache: Tensor,  # (b, n_ctx_cache, n_state)
            mask: Tensor,
        ):
            q = self.multiHeadAttention.query(x)  # (b, n_ctx, n_state)
            k = self.multiHeadAttention.key(x)  # (b, n_ctx, n_state)
            v = self.multiHeadAttention.value(x)  # (b, n_ctx, n_state)
    
            k_cache[:, -k.shape[1] :, :] = k  # (b, n_ctx_cache + n_ctx, n_state)
            v_cache[:, -v.shape[1] :, :] = v  # (b, n_ctx_cache + n_ctx, n_state)
    
            wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
            return self.multiHeadAttention.out(wv), k_cache, v_cache
    该类专门处理Decoder中的自注意力机制,并实现了KV缓存功能:
    • 接受当前输入x及其对应的key和value缓存(k_cache和v_cache)作为输入
    • 计算当前输入的query, key和value投影
    • 将新计算的k和v更新到对应缓存中(注意其更新策略是在缓存的尾部进行替换)
    • 利用query和完整的key/value缓存计算注意力结果
    • 返回注意力输出以及更新后的key和value缓存
    这种KV缓存机制是实现高效自回归解码的关键,避免了重复计算已处理的token对应的key和value投影。
  4. ResidualAttentionBlockTensorCache
    class ResidualAttentionBlockTensorCache(nn.Module):
        def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
            super().__init__()
            self.originalBlock = inResidualAttentionBlock
            self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
            self.cross_attn = (
                MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
                if inResidualAttentionBlock.cross_attn
                else None
            )
    
        def forward(
            self,
            x: Tensor,
            self_k_cache: Tensor,
            self_v_cache: Tensor,
            cross_k: Tensor,
            cross_v: Tensor,
            mask: Tensor,
        ):
            self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
                self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
            )
            x = x + self_attn_x
    
            if self.cross_attn:
                x = x + self.cross_attn(
                    self.originalBlock.cross_attn_ln(x), cross_k, cross_v
                )
    
            x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
            return x, self_k_cache_updated, self_v_cache_updated
    该类封装了Transformer Decoder块的完整功能,包括自注意力、跨注意力和前馈网络:
    • 同时使用前面定义的MultiHeadAttentionSelfMultiHeadAttentionCross类来处理自注意力和跨注意力
    • 实现了完整的残差连接和层归一化,与原始Transformer结构一致
    • 将自注意力的KV缓存管理集成到forward过程中
    • 保持了原始前馈网络(MLP)的计算不变
    通过这种封装,Transformer块可以无缝支持KV缓存和预计算的跨注意力keys/values,同时保留原始计算逻辑。
  5. TextDecoderTensorCache
    class TextDecoderTensorCache(nn.Module):
        def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
            super().__init__()
            self.textDecoder = inTextDecoder
            self.n_ctx = in_n_ctx
    
            self.blocks = []
            for orginal_block in self.textDecoder.blocks:
                self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block))
    
        def forward(
            self,
            tokens: Tensor,
            n_layer_self_k_cache: Tensor,
            n_layer_self_v_cache: Tensor,
            n_layer_cross_k: Tensor,
            n_layer_cross_v: Tensor,
            offset: Tensor,
        ):
            x = (
                self.textDecoder.token_embedding(tokens)
                + self.textDecoder.positional_embedding[
                    offset[0] : offset[0] + tokens.shape[-1]
                ]
            )
            x = x.to(n_layer_cross_k[0].dtype)
    
            i = 0
            for block in self.blocks:
                self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
                self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
                x, self_k_cache, self_v_cache = block(
                    x,
                    self_k_cache=self_k_cache,
                    self_v_cache=self_v_cache,
                    cross_k=n_layer_cross_k[i],
                    cross_v=n_layer_cross_v[i],
                    mask=self.textDecoder.mask,
                )
                n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
                n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
                i += 1
    
            x = self.textDecoder.ln(x)
    
            if False:
                # x.shape (1, 3, 384)
                # weight.shape (51684, 384)
    
                logits = (
                    x
                    @ torch.transpose(
                        self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
                    )
                ).float()
            else:
                logits = (
                    torch.matmul(
                        self.textDecoder.token_embedding.weight.to(x.dtype),
                        x.permute(0, 2, 1),
                    )
                    .permute(0, 2, 1)
                    .float()
                )
    
            return logits, n_layer_self_k_cache, n_layer_self_v_cache
    该类是整个Whisper Decoder的顶层封装,管理多层Transformer块的执行流程:
    • 初始化时,将原始Decoder的所有Transformer块封装为ResidualAttentionBlockTensorCache实例
    • 接受tokens输入、自注意力KV缓存、预计算的跨注意力KV以及偏移量(offset)作为参数
    • 首先应用token嵌入和位置嵌入(注意使用offset来选择正确的位置嵌入)
    • 依次通过所有Transformer块处理,同时管理每层的KV缓存
    • 应用最终的层归一化
    • 计算logits输出(通过与token嵌入权重的矩阵乘法)
    • 返回logits和更新后的自注意力KV缓存
    注意代码中特殊的logits计算方法,修改了矩阵乘法的顺序以提高某些硬件上的性能。
  6. modified_audio_encoder_forward函数
    def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
    
        if False:
            # This branch contains the original code
            assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
            x = (x + self.positional_embedding).to(x.dtype)
        else:
            # This branch contains the actual changes
            assert (
                x.shape[2] == self.positional_embedding.shape[1]
            ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
            assert (
                x.shape[1] == self.positional_embedding.shape[0]
            ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
    
        for block in self.blocks:
            x = block(x)
    
        x = self.ln_post(x)
        return x
    这个函数修改了Whisper AudioEncoder的原始forward方法,主要变化在于处理位置嵌入时支持动态长度:
    • 原始代码(if False分支)要求输入形状严格匹配位置嵌入形状
    • 修改后的代码(else分支)允许输入长度小于或等于预设的最大长度,通过切片self.positional_embedding[: x.shape[1]]来支持动态长度
    • 这一修改对于处理不同长度的音频输入至关重要,使模型能接受短于标准30秒的音频
    脚本通过AudioEncoder.forward = modified_audio_encoder_forward覆盖了原始方法,使这一改动全局生效。

2. 元数据处理与模型信息保存

add_meta_data函数用于向ONNX模型添加丰富的元数据,这些元数据对于模型的正确加载和使用至关重要:

def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    if "large" in filename or "turbo" in filename:
        external_filename = filename.split(".onnx")[0]
        onnx.save(
            model,
            filename,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_filename + ".weights",
        )
    else:
        onnx.save(model, filename)

核心功能包括:

3. Tokenizer处理

脚本中的convert_tokens函数负责提取和保存Whisper tokenizer的词汇表:

def convert_tokens(name, model):
    whisper_dir = Path(whisper.__file__).parent
    multilingual = model.is_multilingual
    tokenizer = (
        whisper_dir
        / "assets"
        / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
    )
    if not tokenizer.is_file():
        raise ValueError(f"Cannot find {tokenizer}")

    with open(tokenizer, "r") as f:
        contents = f.read()
        tokens = {
            token: int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    with open(f"{name}-tokens.txt", "w") as f:
        for t, i in tokens.items():
            f.write(f"{t} {i}\n")

该函数:

4. 主函数流程与模型导出

main()函数中,脚本实现了完整的模型加载、导出和量化流程:

  1. 模型加载与前处理
    args = get_args()
    name = args.model
    model = whisper.load_model(name)  # 或特殊模型的加载
    model.eval()
    print(model.dims)
    
  2. Tokenizer处理
    convert_tokens(name=name, model=model)
    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual, num_languages=model.num_languages
    )
    
  3. 示例输入准备
    audio = torch.rand(16000 * 2)
    audio = whisper.pad_or_trim(audio)
    assert audio.shape == (16000 * 30,), audio.shape
    
    if args.model in ("large", "large-v3", "turbo"):
        n_mels = 128
    else:
        n_mels = 80
    mel = (
        whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
    )
    batch_size = 1
    assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
    
  4. Encoder导出
    encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
    n_layer_cross_k, n_layer_cross_v = encoder(mel)
    # ... 验证形状 ...
    encoder_filename = f"{name}-encoder.onnx"
    torch.onnx.export(
        encoder,
        mel,
        encoder_filename,
        opset_version=opset_version,
        input_names=["mel"],
        output_names=["n_layer_cross_k", "n_layer_cross_v"],
        dynamic_axes={
            "mel": {0: "n_audio", 2: "T"},  # n_audio is also known as batch_size
            "n_layer_cross_k": {1: "n_audio", 2: "T"},
            "n_layer_cross_v": {1: "n_audio", 2: "T"},
        },
    )
    
  5. Encoder元数据添加
    encoder_meta_data = {
        "model_type": f"whisper-{name}",
        "version": "1",
        "maintainer": "k2-fsa",
        "n_mels": model.dims.n_mels,
        # ... 各种模型参数与维度信息 ...
        "is_multilingual": int(model.is_multilingual),
        # ... tokenizer相关信息 ...
    }
    add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
    
  6. Decoder导出
    tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
        mel.device
    )  # [n_audio, 3]
    decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
    n_layer_self_k_cache = torch.zeros(
        (
            len(model.decoder.blocks),
            n_audio,
            model.dims.n_text_ctx,
            model.dims.n_text_state,
        ),
        device=mel.device,
    )
    n_layer_self_v_cache = torch.zeros(
        # ... 同样的形状 ...
    )
    offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    
    # 第二次推理示例 - 用于验证KV Cache更新
    offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
    tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    
    # 实际导出
    decoder_filename = f"{name}-decoder.onnx"
    torch.onnx.export(
        decoder,
        (
            tokens,
            n_layer_self_k_cache,
            n_layer_self_v_cache,
            n_layer_cross_k,
            n_layer_cross_v,
            offset,
        ),
        decoder_filename,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "in_n_layer_self_k_cache",
            "in_n_layer_self_v_cache",
            "n_layer_cross_k",
            "n_layer_cross_v",
            "offset",
        ],
        output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
        dynamic_axes={
            "tokens": {0: "n_audio", 1: "n_tokens"},
            "in_n_layer_self_k_cache": {1: "n_audio"},
            "in_n_layer_self_v_cache": {1: "n_audio"},
            "n_layer_cross_k": {1: "n_audio", 2: "T"},
            "n_layer_cross_v": {1: "n_audio", 2: "T"},
        },
    )
    
  7. 特殊处理large模型
    if "large" in args.model:
        decoder_external_filename = decoder_filename.split(".onnx")[0]
        decoder_model = onnx.load(decoder_filename)
        onnx.save(
            decoder_model,
            decoder_filename,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=decoder_external_filename + ".weights",
        )
    
  8. INT8量化
    print("Generate int8 quantization models")
    
    encoder_filename_int8 = f"{name}-encoder.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
        model_output=encoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
    
    decoder_filename_int8 = f"{name}-decoder.int8.onnx"
    quantize_dynamic(
        model_input=decoder_filename,
        model_output=decoder_filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
    

5. 动态轴设置的详细分析

脚本中的动态轴配置是模型导出的关键部分,它定义了哪些维度允许在推理时变化:

6. INT8量化实现细节

脚本使用ONNX Runtime的量化API对模型进行后处理量化:

quantize_dynamic(
    model_input=encoder_filename,
    model_output=encoder_filename_int8,
    op_types_to_quantize=["MatMul"],
    weight_type=QuantType.QInt8,
)

关键特点:

7. Whisper模型转ONNX的特殊处理

从代码中可以观察到几个针对Whisper模型的特殊处理:

export-onnx.py 脚本的命令行参数示例与解读

根据get_args()函数,脚本接受以下命令行参数:

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # fmt: off
        choices=[
            "tiny", "tiny.en", "base", "base.en",
            "small", "small.en", "medium", "medium.en",
            "large-v1", "large-v2",
            "large", "large-v3", "turbo", # these three have feature dim 128
            "distil-medium.en", "distil-small.en", "distil-large-v2",
            # "distil-large-v3", # distil-large-v3 is not supported!
            # for fine-tuned models from icefall
            "medium-aishell",
            ],
        # fmt: on
    )
    return parser.parse_args()

导出命令示例:

# 导出 tiny.en 模型
python ./scripts/whisper/export-onnx.py --model tiny.en

# 导出 large-v3 模型
python ./scripts/whisper/export-onnx.py --model large-v3

# 导出蒸馏版模型
python ./scripts/whisper/export-onnx.py --model distil-medium.en

# 导出微调后的模型
python ./scripts/whisper/export-onnx.py --model medium-aishell

脚本支持的模型类型非常丰富,包括:

特别注意,对于蒸馏模型和微调模型,脚本会检查相应文件是否存在,否则会给出下载指导。

输出文件解读

成功执行 export-onnx.py 脚本后,会在输出目录生成以下文件:

对于 large 系列模型,还会生成额外的权重文件:

这些输出文件的组织方式与sherpa-onnx的运行时预期完全匹配,可以直接被其识别和加载。

PyTorch 模型转 ONNX 的进阶技巧与常见问题

将 PyTorch 模型转换为 ONNX 格式,尤其对于复杂的 Transformer 模型,往往会遇到一些挑战。掌握以下进阶技巧和常见问题的处理方法至关重要。

处理动态输入形状 (Dynamic Shapes)

处理控制流 (Control Flow)

处理不支持的算子 (Unsupported Operators)

模型量化 (Quantization) 与导出

Transformer 模型 (如 Whisper) 导出特定难点

关键要点总结:高级技巧

  • 动态轴 (Dynamic Axes) 是处理可变序列长度模型的基石。
  • 对于控制流,优先考虑模型重构,或使用 torch.cond() 及 TorchDynamo 导出器。
  • 遇到不支持的算子,尝试升级 opset,修改模型,或在极端情况下考虑自定义算子。
  • 模型量化是性能优化的重要手段,需关注精度与硬件支持,ONNX Runtime 提供了强大工具。
  • Transformer 的 KV Cache 是导出时的主要复杂点,需精心设计输入输出和动态轴。

导出后验证与使用

成功将 PyTorch 模型导出为 ONNX 格式后,进行彻底的验证是确保模型正确性和可用性的关键步骤。之后,模型便可以在支持 ONNX 的环境中部署使用,例如通过 sherpa-onnx

模型可视化与检查

使用 ONNX Runtime 进行推理验证

这是验证导出模型数值准确性的核心步骤。目标是确保 ONNX 模型在给定相同输入时,其输出与原始 PyTorch 模型尽可能一致。

  1. 加载 ONNX 模型: 使用 ONNX Runtime 的 Python API 创建一个推理会话 (InferenceSession)。
    import onnxruntime as ort
    import numpy as np
    
    # 加载ONNX模型
    model_path = "your_model.onnx"
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) # 可指定其他EP,如CUDAExecutionProvider
    
    # 获取输入输出名称 (如果导出时未指定,Netron可查看)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
                        
  2. 准备输入数据:
    • 准备与 PyTorch 模型推理时使用的完全相同的输入数据。
    • 将输入数据转换为 NumPy 数组,并确保其数据类型 (e.g., np.float32, np.int64) 和形状与 ONNX 模型输入节点的要求一致。
    • 对于有动态轴的输入,确保提供的 NumPy 数组形状在该动态轴的有效范围内。
    # 假设dummy_input是之前用于导出或PyTorch推理的示例输入 (torch.Tensor)
    # 需要将其转换为NumPy数组
    dummy_input_np = dummy_input.cpu().numpy()
    input_feed = {input_name: dummy_input_np}
                        
  3. 执行推理: 调用会话的 run() 方法。
    # 执行推理
    # session.run() 的第一个参数是期望获取的输出节点名称列表
    # 第二个参数是以字典形式提供的输入数据,键为输入节点名,值为NumPy数组
    onnx_outputs = session.run([output_name], input_feed)
    onnx_output_np = onnx_outputs[0] # 通常返回一个列表,对应output_names列表
                        
  4. 与 PyTorch 模型输出对比:
    • 使用相同的输入数据,在原始 PyTorch 模型上执行一次推理。
    • 将 PyTorch 模型的输出(通常是 torch.Tensor)转换为 NumPy 数组。
    • 使用 numpy.allclose() 或类似函数比较 ONNX Runtime 的输出和 PyTorch 模型的输出,检查它们在数值上是否足够接近。由于浮点数计算的差异,完全相等可能很难达到,因此需要设置一个合理的容忍度 (atol, rtol)。
      # 获取PyTorch模型输出
      # model.eval()
      # with torch.no_grad():
      #     pytorch_output_tensor = model(dummy_input)
      # pytorch_output_np = pytorch_output_tensor.cpu().numpy()
      
      # 比较输出 (假设pytorch_output_np已准备好)
      # if np.allclose(pytorch_output_np, onnx_output_np, rtol=1e-03, atol=1e-05):
      #     print("ONNX model output matches PyTorch model output within tolerance.")
      # else:
      #     print("Output mismatch detected!")
      #     # 可以进一步打印差异的统计信息
      #     # print("Max absolute difference:", np.max(np.abs(pytorch_output_np - onnx_output_np)))
                                  

    PyTorch 官方的 ONNX 导出教程中也包含了对比 PyTorch 和 ONNX Runtime 结果的步骤。 PyTorch ONNX Tutorial: Compare Results

sherpa-onnx 中使用导出的模型

一旦 Whisper 模型被成功导出并验证为 ONNX 格式(包括 Encoder, Decoder 和 tokens.txt 文件),它们就可以被 sherpa-onnx 用来进行实际的语音识别任务。

通过这种方式,sherpa-onnx 充分利用了 ONNX 的跨平台和高性能特性,使得强大的 Whisper 模型能够在各种设备上高效运行。

总结与最佳实践

将 PyTorch 模型转换为 ONNX 格式是实现模型跨平台部署和性能优化的关键步骤。通过以 sherpa-onnx/scripts/whisper/export-onnx.py 为例的分析,我们深入了解了这一过程的复杂性和关键技术点。

核心流程回顾

  1. 准备 PyTorch 模型: 确保模型处于评估模式 (model.eval()),以保证 Dropout、BatchNorm 等层的行为正确。
  2. 构造正确的示例输入 (dummy_input): 示例输入的形状、数据类型和数量必须与模型实际前向传播时所需的一致。这是 ONNX 导出器追踪计算图的基础。
  3. 调用 torch.onnx.export: 这是核心导出函数。需要仔细设置关键参数,特别是:
    • opset_version: 影响算子支持和兼容性。
    • input_names, output_names: 提高模型的可读性和易用性。
    • dynamic_axes: 对于处理可变长度输入(如文本、音频)至关重要。
  4. 组件化导出 (针对复杂模型): 对于像 Whisper 这样由多个独立组件(如 Encoder 和 Decoder)构成的复杂模型,通常需要将每个组件分别导出为独立的 ONNX 模型。这更符合实际的推理流程,并有助于管理复杂性。
  5. (可选) 进行 INT8 量化: 为了进一步优化性能和减小模型体积,可以在导出 FP32 ONNX 模型后,使用 ONNX Runtime 等工具将其转换为 INT8 量化模型。这通常需要校准数据。
  6. 验证模型: 使用 Netron 可视化检查模型结构,并使用 ONNX Runtime 执行推理,将其输出与原始 PyTorch 模型输出进行对比,确保数值准确性。

最佳实践建议

参考资料