Skip to content

ONNX inference #24

@yxl502

Description

@yxl502

import torch
import torchvision.models as models
import sys
import os
# Export a TubeViT model (loaded from pretrained weights) to an ONNX file.
#
# The script's parent directory is put on sys.path so the `tubevit` package
# located there can be imported when the script is run from its own folder.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from tubevit.model import TubeViTLightningModule

# Per-sample input shape handed to the model; the export input below adds a
# leading batch dimension of 1 to this same shape so the two cannot drift.
# NOTE(review): presumably (channels, frames, height, width) -- confirm.
video_shape = [3, 1, 448, 224]

# Load a pretrained PyTorch model.
# The weight file is resolved against the script's parent directory (the same
# directory added to sys.path above) instead of the current working directory,
# so the script does not depend on where it is launched from.
model = TubeViTLightningModule(
    num_classes=3,
    video_shape=video_shape,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path=os.path.join(
        parent_dir, "weights", "tubevit_vitbase_nc3_fpc3_448_224.pt"
    ),
    test_each_epoch=False,
)
# Put the model in evaluation mode before tracing it for export.
model.eval()

# Example input used to trace the model: a batch containing one sample.
dummy_input = torch.randn(1, *video_shape)

# Path of the ONNX file to write (relative to the current working directory).
onnx_file_path = "./weights/test.onnx"
# Make sure the output directory exists; writing into a missing directory fails.
os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)

# Export the model to ONNX format.
# NOTE(review): the original author reports that export fails for every opset
# version they tried (7-16); opset 12 is kept here. The unsupported operator is
# not visible from this script -- inspect the exporter's error message.
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    verbose=True,
    opset_version=12,
)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions