-
Notifications
You must be signed in to change notification settings - Fork 9
Open
Description
import torch
import torchvision.models as models
import sys
import os
# Make the repository root (the parent of this script's directory) importable so
# that the local `tubevit` package resolves regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from tubevit.model import TubeViTLightningModule

# Input shape handed to the model as `video_shape`; the dummy export input below
# is this shape with a leading batch dimension of 1, so the two stay in sync.
video_shape = [3, 1, 448, 224]

# Load a pretrained PyTorch model.
# NOTE(review): weight_path is resolved relative to the current working directory,
# not relative to this script file — run from the matching directory or adjust.
model = TubeViTLightningModule(
    num_classes=3,
    video_shape=video_shape,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path="../weights/tubevit_vitbase_nc3_fpc3_448_224.pt",
    test_each_epoch=False,
)
# Put the model in evaluation mode before tracing it for export.
model.eval()

# Define an example input for the model (batch size 1).
dummy_input = torch.randn(1, *video_shape)

# Path where the exported ONNX file is written (relative to the working directory).
onnx_file_path = "./weights/test.onnx"
# torch.onnx.export does not create missing parent directories.
os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)

# Export the model to ONNX format.
# Original author's note: only opset versions 7-16 are available, and the export
# failed with every one of them.
torch.onnx.export(model, dummy_input, onnx_file_path, verbose=True, opset_version=12)
Metadata
Assignees
Labels
No labels