
Comments (2)

hayden-donnelly commented on August 9, 2024

I'm not sure whether a TensorFlow SavedModel is any different, but I use this code to convert my TF SavedModels to ONNX:

import tensorflow as tf
import tf2onnx

input_path = "../data/models/"
model_name = "simple_upsampler_bilinear"
file_type = ""  # file extension such as ".h5"; left empty here
output_path = "../data/models/onnx_models/"

# Load the Keras model.
pre_model = tf.keras.models.load_model(input_path + model_name + file_type)

# Convert the loaded Keras model to ONNX.
tf2onnx.convert.from_keras(pre_model, output_path=output_path + model_name + ".onnx", opset=9)
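If it helps, here is a minimal sketch of the same conversion wrapped in a function, using pathlib instead of string concatenation so the output directory is created when it is missing. The names `onnx_path_for` and `convert_keras_to_onnx` are hypothetical helpers, not part of tf2onnx, and the TensorFlow and tf2onnx imports are deferred so the path helper can be used on its own.

```python
from pathlib import Path


def onnx_path_for(model_name, output_dir):
    """Return <output_dir>/<model_name>.onnx as a Path (hypothetical helper)."""
    return Path(output_dir) / f"{model_name}.onnx"


def convert_keras_to_onnx(model_path, output_dir, opset=9):
    """Load a Keras model and write it out as ONNX; returns the output path."""
    # Imported lazily so onnx_path_for works without TensorFlow installed.
    import tensorflow as tf
    import tf2onnx

    model_path = Path(model_path)
    target = onnx_path_for(model_path.stem, output_dir)
    # Create the output directory if it does not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)

    model = tf.keras.models.load_model(str(model_path))
    tf2onnx.convert.from_keras(model, opset=opset, output_path=str(target))
    return target
```

With the paths from the snippet above, the call would be `convert_keras_to_onnx("../data/models/simple_upsampler_bilinear", "../data/models/onnx_models/")`.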

from barracuda-release.

dilne commented on August 9, 2024

ONNX is required. You can convert a saved PyTorch model to ONNX as shown below. I also include the ONNX checker, which should confirm that the export worked:

import torch
import torch.onnx
import onnx

# load path
load_path = r"your/model/directory/name_of_model.pt"

# output path
output_path = r"your/output/directory/name_of_model.onnx"

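# Pick the device first so the saved model is mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# This loads a fully pickled model; on PyTorch 2.6+ that also needs weights_only=False.
model = torch.load(load_path, map_location=device)
model = model.to(device)
model.eval()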

# Input to the model (you need to define the size of the model inputs).
# This example is a batch of one single-channel tensor of size 28 x 28.
x = torch.randn(1, 1, 28, 28, requires_grad=True).to(device)
torch_out = model(x)

# Export the model
torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  output_path,               # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=9,           # the ONNX opset version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  )

onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
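Note that `onnx.checker.check_model` only validates the structure of the exported graph; it does not compare outputs. Below is a rough sketch of a numerical check, assuming onnxruntime is installed (it is not used in the snippet above) and that the model returns a single tensor. `all_close` and `matches_onnxruntime` are hypothetical helper names, and the tolerances are arbitrary starting points.

```python
import math


def all_close(a, b, rel_tol=1e-3, abs_tol=1e-5):
    """Element-wise tolerance check over two flat sequences of floats."""
    a, b = list(a), list(b)
    return len(a) == len(b) and all(
        math.isclose(p, q, rel_tol=rel_tol, abs_tol=abs_tol) for p, q in zip(a, b)
    )


def matches_onnxruntime(model, x, onnx_path):
    """Run the same input through PyTorch and ONNX Runtime and compare outputs."""
    # Imported lazily so all_close can be used without these packages.
    import onnxruntime as ort
    import torch

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name

    # Reference output from the original PyTorch model.
    with torch.no_grad():
        torch_out = model(x).cpu().numpy().ravel().tolist()

    # Output from the exported ONNX model for the same input.
    ort_out = session.run(None, {input_name: x.detach().cpu().numpy()})[0]
    return all_close(torch_out, ort_out.ravel().tolist())
```

After the export above, `matches_onnxruntime(model, x, output_path)` should return True if the ONNX model reproduces the PyTorch output within tolerance.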

