作者:杨志
在当前的机器学习领域,模型的研发与部署是两个紧密相连却又各具挑战的环节。PyTorch 以其灵活性、易用性和强大的社区支持,成为学术研究和模型训练的首选框架之一。然而,当模型从研究阶段走向实际应用部署时,开发者往往面临诸多挑战。
PyTorch:作为一款以 Python 优先,支持动态计算图的深度学习框架,PyTorch 极大地简化了复杂模型的构建和快速迭代过程。其完善的生态系统和丰富的预训练模型库,使其在科研领域占据主导地位。 PyTorch官方网站
模型部署的挑战:训练完成的 PyTorch 模型(通常为 .pth
或 .pt
文件)在部署时会遇到以下问题:
ONNX (Open Neural Network Exchange) 的诞生:为了解决上述模型互操作性问题,ONNX 应运而生。它是一种为机器学习模型设计的开放标准格式。 ONNX官方网站
ONNX 的核心价值在于实现"一次训练,多处部署"。它充当了不同深度学习框架和硬件加速器之间的桥梁,允许开发者将一个框架中训练的模型导出为 ONNX 格式,然后在另一个支持 ONNX 的框架或硬件上进行推理。 PyTorch ONNX导出教程
sherpa-onnx/scripts/whisper/export-onnx.py
脚本扮演了怎样的角色?其具体的实现方法和关键考量是什么?本文旨在以 sherpa-onnx
项目中用于导出 Whisper 模型的 export-onnx.py
脚本为例,深入剖析从 PyTorch 模型到 ONNX 格式的转换流程。我们将详细探讨其中的关键技术点、参数选择、常见问题及其解决方案。通过阅读本文,读者不仅能理解通用的模型转换原理,还能掌握针对特定复杂模型(如语音识别领域的 Transformer 模型)的导出技巧和最佳实践。
文章结构将围绕以下几个核心部分展开:ONNX 与 PyTorch torch.onnx.export
函数概览、sherpa-onnx
Whisper 导出脚本的深度解读、PyTorch 模型转 ONNX 的进阶技巧与常见问题、导出后验证与使用,最后进行总结并提供最佳实践建议。
torch.onnx.export
概览定义:ONNX (Open Neural Network Exchange) 是一种用于表示机器学习模型的开放格式。它旨在促进不同人工智能框架之间的互操作性。 ONNX简介
组成:一个 ONNX 模型主要由以下部分构成:
优势:
ONNX Runtime:由微软开发并开源的高性能推理引擎,专门用于执行 ONNX 模型。它支持多种硬件加速器和操作系统,能够显著提升模型在CPU和GPU上的推理速度。 ONNX Runtime官方网站
torch.onnx.export
函数详解PyTorch 通过 torch.onnx
模块提供将模型导出为 ONNX 格式的功能,其核心函数是 torch.onnx.export()
。
核心功能:该函数通过执行一次模型(使用提供的示例输入),记录下计算过程中涉及的 PyTorch 算子操作轨迹,然后将这些轨迹转换为符合 ONNX规范的计算图。 torch.onnx官方文档
model
(torch.nn.Module
): 需要导出的 PyTorch 模型实例。args
(tuple): 模型的示例输入,也称为 "dummy input"。其形状和类型应与模型实际期望的输入一致。对于有多个输入的模型,args
应为一个元组,包含所有输入。f
(str or file-like object): 导出的 ONNX 模型将保存到的文件路径或一个可写的类文件对象。export_params
(bool, default=True): 是否将模型中训练好的参数(权重和偏置)一并导出。通常应保持为 True
。opset_version
(int): 指定要使用的 ONNX 算子集版本。不同的版本支持不同的算子集和行为。选择合适的 opset_version
对于确保模型兼容性和算子支持至关重要。例如,opset_version=11
或 opset_version=17
是常用的选择。do_constant_folding
(bool, default=True): 是否执行常量折叠优化。常量折叠会在导出过程中预计算那些输入为常量的节点,从而简化图结构并可能提升性能。建议在部署时开启。input_names
(list of str, optional): 为 ONNX 模型的输入节点指定名称。例如 ['input_ids', 'attention_mask']
。这有助于后续使用 ONNX Runtime 推理时按名称提供输入。output_names
(list of str, optional): 为 ONNX 模型的输出节点指定名称。例如 ['logits']
。dynamic_axes
(dict, optional): 指定输入/输出张量中哪些维度是动态的(可变的)。这对于处理如可变序列长度(NLP, ASR)、可变批次大小等场景至关重要。
{'input_ids': {0: 'batch_size', 1: 'sequence_length'}, 'output_logits': {0: 'batch_size'}}
表示 input_ids
的第0维是动态的(名为 'batch_size'),第1维也是动态的(名为 'sequence_length'),而 output_logits
的第0维是动态的。verbose
(bool, default=False): 是否在导出过程中打印详细的日志信息,有助于调试。PyTorch 的 ONNX 导出器经历了发展,主要有两个版本:
torch.jit.trace()
或 torch.jit.script()
)来捕获模型图。
torch.jit.trace()
:通过执行一次模型来记录操作。对于包含数据依赖控制流(如依赖张量值的 if
语句)的模型,trace 模式可能无法正确捕捉所有路径。torch.jit.script()
:尝试直接解析 Python 代码来理解模型结构,对控制流有更好的支持,但要求代码符合 TorchScript 的语法子集。torch.onnx.export(..., dynamo=True)
(PyTorch 2.5+) 或早期的 torch.onnx.dynamo_export()
(已不推荐) 来使用。
一个至关重要的步骤是在调用 torch.onnx.export()
之前,将模型设置为评估模式:
model.eval()
或者等效的 model.train(False)
。这是因为像 Dropout
层和 BatchNorm
层这样的模块在训练模式和评估模式下的行为是不同的。在导出用于推理的模型时,必须确保它们处于评估(推理)状态,否则可能导致导出的 ONNX 模型行为不正确或性能不佳。 微软PyTorch转ONNX教程强调model.eval()
export-onnx.py
脚本解读以一个将 Whisper 模型从 PyTorch 实现转换为 ONNX 格式的脚本 export-onnx.py
为例,我们可以深入了解复杂 Transformer 模型导出的具体实现细节。这个脚本来自开源项目,是理解复杂模型导出流程的典型案例。
本案例中的脚本主要处理 Whisper 模型,以下是相关背景信息:
Whisper 模型: 由 OpenAI 开发的基于 Transformer 架构的多语言语音识别模型。该模型在大量弱监督数据上训练,具有较高的识别准确率和跨语言性能。 OpenAI Whisper介绍
Esperanto Technologies 在一篇博客中详细讨论了 Whisper 模型的架构及其到 ONNX 的转换,指出其遵循典型的 Transformer 结构。 Adapting Whisper Models to ET-SoC-1 Architecture and Exporting Them to ONNX
export-onnx.py
脚本目标与设计分析的 export-onnx.py
脚本的核心目标是将 Whisper 模型的 PyTorch 版本转换为能被 ONNX 运行时加载和使用的模型文件。
核心设计理念:
-encoder.onnx
和 -decoder.onnx
。这种分离设计更符合实际推理流程,并有助于优化。-tokens.txt
文件,其中包含了 Whisper 模型使用的 Tokenizer 的词汇表。这个文件对于将模型输出的 Token ID 转换回可读文本至关重要。tiny.en
, base
, small
, medium
, large
, large-v1
, large-v2
, large-v3
等。这些预训练模型可以直接从 Hugging Face Hub 下载。 sherpa-onnx Whisper模型导出文档-encoder.int8.onnx
的文件。根据提供的完整export-onnx.py
源码,我们可以系统分析其核心实现和关键技术点。
为了适应ONNX导出,脚本对Whisper的原始实现进行了多项修改和封装。以下是几个关键类的详细分析:
AudioEncoderTensorCache
类
class AudioEncoderTensorCache(nn.Module):
def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
super().__init__()
self.audioEncoder = inAudioEncoder
self.textDecoder = inTextDecoder
def forward(self, x: Tensor):
audio_features = self.audioEncoder(x)
n_layer_cross_k_list = []
n_layer_cross_v_list = []
for block in self.textDecoder.blocks:
n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
n_layer_cross_v_list.append(block.cross_attn.value(audio_features))
return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)
该类将Whisper的Encoder组件封装为能够提前生成cross-attention所需key和value的模块。其核心功能是:
audioEncoder
处理音频梅尔频谱图输入,得到audio_features
MultiHeadAttentionCross
类
class MultiHeadAttentionCross(nn.Module):
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
super().__init__()
self.multiHeadAttention = inMultiHeadAttention
def forward(
self,
x: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
):
q = self.multiHeadAttention.query(x)
wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask)
return self.multiHeadAttention.out(wv)
该类专门处理Decoder中的跨注意力机制,使其接口适应ONNX导出需求:
MultiHeadAttentionSelf
类
class MultiHeadAttentionSelf(nn.Module):
def __init__(self, inMultiHeadAttention: MultiHeadAttention):
super().__init__()
self.multiHeadAttention = inMultiHeadAttention
def forward(
self,
x: Tensor, # (b, n_ctx , n_state)
k_cache: Tensor, # (b, n_ctx_cache, n_state)
v_cache: Tensor, # (b, n_ctx_cache, n_state)
mask: Tensor,
):
q = self.multiHeadAttention.query(x) # (b, n_ctx, n_state)
k = self.multiHeadAttention.key(x) # (b, n_ctx, n_state)
v = self.multiHeadAttention.value(x) # (b, n_ctx, n_state)
k_cache[:, -k.shape[1] :, :] = k # (b, n_ctx_cache + n_ctx, n_state)
v_cache[:, -v.shape[1] :, :] = v # (b, n_ctx_cache + n_ctx, n_state)
wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
return self.multiHeadAttention.out(wv), k_cache, v_cache
该类专门处理Decoder中的自注意力机制,并实现了KV缓存功能:
ResidualAttentionBlockTensorCache
类
class ResidualAttentionBlockTensorCache(nn.Module):
def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
super().__init__()
self.originalBlock = inResidualAttentionBlock
self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
self.cross_attn = (
MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
if inResidualAttentionBlock.cross_attn
else None
)
def forward(
self,
x: Tensor,
self_k_cache: Tensor,
self_v_cache: Tensor,
cross_k: Tensor,
cross_v: Tensor,
mask: Tensor,
):
self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
)
x = x + self_attn_x
if self.cross_attn:
x = x + self.cross_attn(
self.originalBlock.cross_attn_ln(x), cross_k, cross_v
)
x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
return x, self_k_cache_updated, self_v_cache_updated
该类封装了Transformer Decoder块的完整功能,包括自注意力、跨注意力和前馈网络:
MultiHeadAttentionSelf
和MultiHeadAttentionCross
类来处理自注意力和跨注意力TextDecoderTensorCache
类
class TextDecoderTensorCache(nn.Module):
def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
super().__init__()
self.textDecoder = inTextDecoder
self.n_ctx = in_n_ctx
self.blocks = []
for orginal_block in self.textDecoder.blocks:
self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block))
def forward(
self,
tokens: Tensor,
n_layer_self_k_cache: Tensor,
n_layer_self_v_cache: Tensor,
n_layer_cross_k: Tensor,
n_layer_cross_v: Tensor,
offset: Tensor,
):
x = (
self.textDecoder.token_embedding(tokens)
+ self.textDecoder.positional_embedding[
offset[0] : offset[0] + tokens.shape[-1]
]
)
x = x.to(n_layer_cross_k[0].dtype)
i = 0
for block in self.blocks:
self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :]
self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :]
x, self_k_cache, self_v_cache = block(
x,
self_k_cache=self_k_cache,
self_v_cache=self_v_cache,
cross_k=n_layer_cross_k[i],
cross_v=n_layer_cross_v[i],
mask=self.textDecoder.mask,
)
n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache
n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache
i += 1
x = self.textDecoder.ln(x)
if False:
# x.shape (1, 3, 384)
# weight.shape (51684, 384)
logits = (
x
@ torch.transpose(
self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
)
).float()
else:
logits = (
torch.matmul(
self.textDecoder.token_embedding.weight.to(x.dtype),
x.permute(0, 2, 1),
)
.permute(0, 2, 1)
.float()
)
return logits, n_layer_self_k_cache, n_layer_self_v_cache
该类是整个Whisper Decoder的顶层封装,管理多层Transformer块的执行流程:
ResidualAttentionBlockTensorCache
实例modified_audio_encoder_forward
函数
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
if False:
# This branch contains the original code
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
else:
# This branch contains the actual changes
assert (
x.shape[2] == self.positional_embedding.shape[1]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
assert (
x.shape[1] == self.positional_embedding.shape[0]
), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
这个函数修改了Whisper AudioEncoder的原始forward方法,主要变化在于处理位置嵌入时支持动态长度:
self.positional_embedding[: x.shape[1]]
来支持动态长度AudioEncoder.forward = modified_audio_encoder_forward
覆盖了原始方法,使这一改动全局生效。
add_meta_data
函数用于向ONNX模型添加丰富的元数据,这些元数据对于模型的正确加载和使用至关重要:
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
if "large" in filename or "turbo" in filename:
external_filename = filename.split(".onnx")[0]
onnx.save(
model,
filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_filename + ".weights",
)
else:
onnx.save(model, filename)
核心功能包括:
.weights
),解决ONNX文件大小限制问题main()
函数中可见,添加的元数据十分丰富,包括模型类型、维度信息、tokenizer配置等关键信息脚本中的convert_tokens
函数负责提取和保存Whisper tokenizer的词汇表:
def convert_tokens(name, model):
whisper_dir = Path(whisper.__file__).parent
multilingual = model.is_multilingual
tokenizer = (
whisper_dir
/ "assets"
/ (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
)
if not tokenizer.is_file():
raise ValueError(f"Cannot find {tokenizer}")
with open(tokenizer, "r") as f:
contents = f.read()
tokens = {
token: int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
with open(f"{name}-tokens.txt", "w") as f:
for t, i in tokens.items():
f.write(f"{t} {i}\n")
该函数:
{name}-tokens.txt
文件,供sherpa-onnx运行时使用在main()
函数中,脚本实现了完整的模型加载、导出和量化流程:
args = get_args()
name = args.model
model = whisper.load_model(name) # 或特殊模型的加载
model.eval()
print(model.dims)
convert_tokens(name=name, model=model)
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
audio = torch.rand(16000 * 2)
audio = whisper.pad_or_trim(audio)
assert audio.shape == (16000 * 30,), audio.shape
if args.model in ("large", "large-v3", "turbo"):
n_mels = 128
else:
n_mels = 80
mel = (
whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device).unsqueeze(0)
)
batch_size = 1
assert mel.shape == (batch_size, n_mels, 30 * 100), mel.shape
encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
n_layer_cross_k, n_layer_cross_v = encoder(mel)
# ... 验证形状 ...
encoder_filename = f"{name}-encoder.onnx"
torch.onnx.export(
encoder,
mel,
encoder_filename,
opset_version=opset_version,
input_names=["mel"],
output_names=["n_layer_cross_k", "n_layer_cross_v"],
dynamic_axes={
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)
encoder_meta_data = {
"model_type": f"whisper-{name}",
"version": "1",
"maintainer": "k2-fsa",
"n_mels": model.dims.n_mels,
# ... 各种模型参数与维度信息 ...
"is_multilingual": int(model.is_multilingual),
# ... tokenizer相关信息 ...
}
add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
mel.device
) # [n_audio, 3]
decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)
n_layer_self_k_cache = torch.zeros(
(
len(model.decoder.blocks),
n_audio,
model.dims.n_text_ctx,
model.dims.n_text_state,
),
device=mel.device,
)
n_layer_self_v_cache = torch.zeros(
# ... 同样的形状 ...
)
offset = torch.zeros(1, dtype=torch.int64).to(mel.device)
logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
)
# 第二次推理示例 - 用于验证KV Cache更新
offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
)
# 实际导出
decoder_filename = f"{name}-decoder.onnx"
torch.onnx.export(
decoder,
(
tokens,
n_layer_self_k_cache,
n_layer_self_v_cache,
n_layer_cross_k,
n_layer_cross_v,
offset,
),
decoder_filename,
opset_version=opset_version,
input_names=[
"tokens",
"in_n_layer_self_k_cache",
"in_n_layer_self_v_cache",
"n_layer_cross_k",
"n_layer_cross_v",
"offset",
],
output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
dynamic_axes={
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
)
if "large" in args.model:
decoder_external_filename = decoder_filename.split(".onnx")[0]
decoder_model = onnx.load(decoder_filename)
onnx.save(
decoder_model,
decoder_filename,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=decoder_external_filename + ".weights",
)
print("Generate int8 quantization models")
encoder_filename_int8 = f"{name}-encoder.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = f"{name}-decoder.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
脚本中的动态轴配置是模型导出的关键部分,它定义了哪些维度允许在推理时变化:
dynamic_axes={
"mel": {0: "n_audio", 2: "T"}, # n_audio is also known as batch_size
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
mel
输入:批次大小(维度0)和时间轴/音频长度(维度2)允许动态变化n_layer_cross_k
和n_layer_cross_v
输出:这些是预计算的cross-attention keys和values,它们的批次大小(维度1)和时间轴长度(维度2)需要与输入保持匹配dynamic_axes={
"tokens": {0: "n_audio", 1: "n_tokens"},
"in_n_layer_self_k_cache": {1: "n_audio"},
"in_n_layer_self_v_cache": {1: "n_audio"},
"n_layer_cross_k": {1: "n_audio", 2: "T"},
"n_layer_cross_v": {1: "n_audio", 2: "T"},
},
tokens
:允许批次大小(维度0)和token序列长度(维度1)动态变化脚本使用ONNX Runtime的量化API对模型进行后处理量化:
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
关键特点:
MatMul
操作进行量化,这是Transformer模型中计算和内存开销最大的部分QInt8
,这是一个带符号的8位整数类型,适合权重的分布特性从代码中可以观察到几个针对Whisper模型的特殊处理:
TextDecoderTensorCache.forward
中,通过调整矩阵乘法的顺序提高计算效率:
if False:
# x.shape (1, 3, 384)
# weight.shape (51684, 384)
logits = (
x
@ torch.transpose(
self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
)
).float()
else:
logits = (
torch.matmul(
self.textDecoder.token_embedding.weight.to(x.dtype),
x.permute(0, 2, 1),
)
.permute(0, 2, 1)
.float()
)
export-onnx.py
脚本的命令行参数示例与解读根据get_args()
函数,脚本接受以下命令行参数:
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
# fmt: off
choices=[
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large-v1", "large-v2",
"large", "large-v3", "turbo", # these three have feature dim 128
"distil-medium.en", "distil-small.en", "distil-large-v2",
# "distil-large-v3", # distil-large-v3 is not supported!
# for fine-tuned models from icefall
"medium-aishell",
],
# fmt: on
)
return parser.parse_args()
导出命令示例:
# 导出 tiny.en 模型
python ./scripts/whisper/export-onnx.py --model tiny.en
# 导出 large-v3 模型
python ./scripts/whisper/export-onnx.py --model large-v3
# 导出蒸馏版模型
python ./scripts/whisper/export-onnx.py --model distil-medium.en
# 导出微调后的模型
python ./scripts/whisper/export-onnx.py --model medium-aishell
脚本支持的模型类型非常丰富,包括:
特别注意,对于蒸馏模型和微调模型,脚本会检查相应文件是否存在,否则会给出下载指导。
成功执行 export-onnx.py
脚本后,会在输出目录生成以下文件:
<model_name>-encoder.onnx
: Whisper Encoder 的 FP32 ONNX 模型。<model_name>-decoder.onnx
: Whisper Decoder 的 FP32 ONNX 模型,包含 KV Cache 处理逻辑。<model_name>-tokens.txt
: Tokenizer 的词汇表文件。<model_name>-encoder.int8.onnx
: INT8 量化后的 Encoder ONNX 模型。<model_name>-decoder.int8.onnx
: INT8 量化后的 Decoder ONNX 模型。对于 large
系列模型,还会生成额外的权重文件:
<model_name>-encoder.weights
: Encoder 的外部权重文件。<model_name>-decoder.weights
: Decoder 的外部权重文件。这些输出文件的组织方式与sherpa-onnx
的运行时预期完全匹配,可以直接被其识别和加载。
将 PyTorch 模型转换为 ONNX 格式,尤其对于复杂的 Transformer 模型,往往会遇到一些挑战。掌握以下进阶技巧和常见问题的处理方法至关重要。
torch.onnx.export
函数的 dynamic_axes
参数。这个参数是一个字典,键是输入/输出节点的名称,值是另一个字典,指定了哪些轴是动态的,并可以为这些动态轴命名。
dynamic_axes = {
'input_ids': {0: 'batch_size', 1: 'sequence_length'}, # input_ids的第0轴叫batch_size, 第1轴叫sequence_length
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
'logits': {0: 'batch_size'} # 输出logits的第0轴叫batch_size
}
torch.onnx.export(model, dummy_inputs, "model.onnx", ..., dynamic_axes=dynamic_axes)
view()
, reshape()
, einops.rearrange
)在存在动态维度时,ONNX 导出器可能难以正确推断输出形状,导致转换失败或运行时错误。torch.nn.MultiheadAttention
这样的复杂模块,在处理动态序列长度和 KV Cache 时,其 ONNX 导出尤其具有挑战性。GitHub 上有相关讨论指出特定 PyTorch 版本或导出器(如 TorchDynamo-based)可能对这类问题有更好的处理。例如,一个 PyTorch GitHub issue (#120075) 讨论了 MultiheadAttention
在动态形状导出时遇到的问题,并建议尝试使用 dynamo=True
导出器或更新 PyTorch 版本。 PyTorch Issue #120075: MultiheadAttention dynamic shapesif/else
条件判断、for/while
循环)如果其条件或迭代次数依赖于模型中间张量的值(即数据依赖的控制流),传统的基于追踪的 ONNX 导出器 (Trace-based TorchScript) 很难捕捉这种动态行为。追踪只会记录示例输入执行的那一条路径。torch.cond()
(PyTorch 2.x+): 对于条件分支,PyTorch 提供了 torch.cond()
运算符,它可以在 FX Graph 中表示条件逻辑,并且可以被 TorchDynamo-based 导出器转换为 ONNX 的 If
算子。这要求将原始的 if/else
逻辑封装到两个独立的函数(分别对应 true 和 false 分支)中,并作为参数传递给 torch.cond()
。 PyTorch教程:导出带控制流的模型到ONNXdynamo=True
): PyTorch 2.x 引入的基于 TorchDynamo 的导出器对 Python 原生控制流具有更好的支持,因为它通过字节码分析来捕获图,能更准确地理解和转换一些动态行为。然而,它仍然有其局限性,并非所有 Python 控制流都能完美转换。 TorchDynamo-based ONNX Exportertorch.onnx.export()
时,可能会遇到错误,提示某个 PyTorch 算子 (operator) 没有对应的 ONNX 实现,或者当前选择的 opset_version
不支持该算子。opset_version
: 较新的 ONNX 算子集版本通常会支持更多的 PyTorch 算子。尝试递增 opset_version
(e.g., 从 11 到 13, 13 到 17) 看是否能解决问题。但要注意目标推理引擎对高版本 opset 的支持情况。torch.fx
进行图改写 (Advanced): 在模型导出到 ONNX 之前,可以使用 PyTorch FX (torch.fx
) 对模型的计算图进行编程方式的修改,例如替换或分解不支持的算子模式。sherpa-onnx
的 export-onnx.py
脚本示例中,生成的量化模型如 -encoder.int8.onnx
,这通常是在导出 FP32 ONNX 模型之后,再利用 ONNX Runtime 提供的量化工具(如 onnxruntime.quantization
Python API)进行转换得到的。 sherpa-onnx Whisper INT8 modelstorch.quantization.convert
,然后再调用 torch.onnx.export
。QLinearConv
, MatMulInteger
等)。nn.MultiheadAttention
和 KV Cache:
-decoder.onnx
)通常只代表单步的 Decoder 计算:给定当前 Tokens、Encoder 输出和过去的 KV Cache,预测下一个 Token 的 Logits 并输出新的 KV Cache。sherpa-onnx
的 export-onnx.py
脚本正是遵循这种设计,将 Decoder 导出为一个单步计算模块。其运行时环境负责管理解码循环和 KV Cache 的传递。dynamic_axes
参数使导出的 ONNX Encoder 模型能够接受可变长度的音频输入(在一定范围内),这样可以减少预处理的复杂性,并更灵活地适应不同时长的音频。torch.cond()
及 TorchDynamo 导出器。成功将 PyTorch 模型导出为 ONNX 格式后,进行彻底的验证是确保模型正确性和可用性的关键步骤。之后,模型便可以在支持 ONNX 的环境中部署使用,例如通过 sherpa-onnx
。
.onnx
文件并展示模型的计算图结构、节点属性、输入输出信息等。 Netron官方网站
dynamic_axes
是否已正确应用?domain
和 op_version
是否与目标 ONNX Runtime 环境兼容。这是验证导出模型数值准确性的核心步骤。目标是确保 ONNX 模型在给定相同输入时,其输出与原始 PyTorch 模型尽可能一致。
import onnxruntime as ort
import numpy as np
# 加载ONNX模型
model_path = "your_model.onnx"
session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) # 可指定其他EP,如CUDAExecutionProvider
# 获取输入输出名称 (如果导出时未指定,Netron可查看)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
np.float32
, np.int64
) 和形状与 ONNX 模型输入节点的要求一致。# 假设dummy_input是之前用于导出或PyTorch推理的示例输入 (torch.Tensor)
# 需要将其转换为NumPy数组
dummy_input_np = dummy_input.cpu().numpy()
input_feed = {input_name: dummy_input_np}
run()
方法。
# 执行推理
# session.run() 的第一个参数是期望获取的输出节点名称列表
# 第二个参数是以字典形式提供的输入数据,键为输入节点名,值为NumPy数组
onnx_outputs = session.run([output_name], input_feed)
onnx_output_np = onnx_outputs[0] # 通常返回一个列表,对应output_names列表
torch.Tensor
)转换为 NumPy 数组。numpy.allclose()
或类似函数比较 ONNX Runtime 的输出和 PyTorch 模型的输出,检查它们在数值上是否足够接近。由于浮点数计算的差异,完全相等可能很难达到,因此需要设置一个合理的容忍度 (atol
, rtol
)。
# 获取PyTorch模型输出
# model.eval()
# with torch.no_grad():
# pytorch_output_tensor = model(dummy_input)
# pytorch_output_np = pytorch_output_tensor.cpu().numpy()
# 比较输出 (假设pytorch_output_np已准备好)
# if np.allclose(pytorch_output_np, onnx_output_np, rtol=1e-03, atol=1e-05):
# print("ONNX model output matches PyTorch model output within tolerance.")
# else:
# print("Output mismatch detected!")
# # 可以进一步打印差异的统计信息
# # print("Max absolute difference:", np.max(np.abs(pytorch_output_np - onnx_output_np)))
PyTorch 官方的 ONNX 导出教程中也包含了对比 PyTorch 和 ONNX Runtime 结果的步骤。 PyTorch ONNX Tutorial: Compare Results
sherpa-onnx
中使用导出的模型一旦 Whisper 模型被成功导出并验证为 ONNX 格式(包括 Encoder, Decoder 和 tokens.txt
文件),它们就可以被 sherpa-onnx
用来进行实际的语音识别任务。
sherpa-onnx
提供了 C++ 和 Python API,允许开发者加载这些 ONNX 模型文件。sherpa-onnx
的运行时会负责整个语音识别流水线:
sherpa-onnx
会管理一个解码循环,在循环的每一步:
tokens.txt
文件转换为最终的文本转录。sherpa-onnx
的文档和示例代码会展示如何配置和使用这些导出的 ONNX 模型文件。例如,其 Python API 通常需要指定 Encoder、Decoder 和 Tokens 文件的路径。 sherpa-onnx Documentation通过这种方式,sherpa-onnx
充分利用了 ONNX 的跨平台和高性能特性,使得强大的 Whisper 模型能够在各种设备上高效运行。
将 PyTorch 模型转换为 ONNX 格式是实现模型跨平台部署和性能优化的关键步骤。通过以 sherpa-onnx/scripts/whisper/export-onnx.py
为例的分析,我们深入了解了这一过程的复杂性和关键技术点。
model.eval()
),以保证 Dropout、BatchNorm 等层的行为正确。dummy_input
): 示例输入的形状、数据类型和数量必须与模型实际前向传播时所需的一致。这是 ONNX 导出器追踪计算图的基础。torch.onnx.export
: 这是核心导出函数。需要仔细设置关键参数,特别是:
opset_version
: 影响算子支持和兼容性。input_names
, output_names
: 提高模型的可读性和易用性。dynamic_axes
: 对于处理可变长度输入(如文本、音频)至关重要。opset_version
: 在开始导出前,调研目标推理环境(如特定版本的 ONNX Runtime、硬件SDK)支持的最高 opset_version
,并选择一个既能满足模型算子需求又被广泛支持的版本。通常,较新的版本支持更多算子,但兼容性可能稍差。model.eval()
: 这是最基本也是最容易被忽略的一点,不正确的模式会导致模型行为错误。input_names
, output_names
, dynamic_axes
: 清晰地命名输入输出节点,并准确定义动态轴,能极大地方便后续的模型集成和使用,避免因维度不匹配或节点名称混淆导致的问题。k2-fsa/sherpa-onnx
)也是解决疑难杂症的宝贵资源。dynamo=True
): 对于使用 PyTorch 2.x 及以上版本的用户,特别是处理包含复杂 Python 动态特性或控制流的模型时,新的 TorchDynamo 导出器(通过 torch.onnx.export(..., dynamo=True)
启用)通常能提供更好的支持和更准确的转换结果。 TorchDynamo-based ONNX Exporteropset_version
。这些版本之间的兼容性有时非常敏感。torch.onnx.export
打印的任何警告信息,它们可能预示着潜在的问题或不兼容性。