The Workload Generator is a project that takes a PyTorch model and a set of test inputs, exports the model to an ONNX file, and analyzes the operator types in the resulting ONNX graph. It also estimates the multiply-accumulate (MAC) count, the number of parameters, the input/output sizes, and other metrics, and collects them in a report object.
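To show what the MAC and parameter estimates measure, the sketch below applies the standard per-layer formulas for a 2-D convolution and a fully connected layer, assuming a batch size of 1. It is an illustration under those assumptions and is not necessarily how WorkloadGenerator computes its numbers. The function names are hypothetical.

```python
def conv2d_params(c_in, c_out, k_h, k_w, groups=1, bias=True):
    """Weights: C_out * (C_in / groups) * K_h * K_w, plus one bias per output channel."""
    return c_out * (c_in // groups) * k_h * k_w + (c_out if bias else 0)


def conv2d_macs(c_in, c_out, k_h, k_w, h_out, w_out, groups=1):
    """Each of the C_out * H_out * W_out outputs needs (C_in / groups) * K_h * K_w MACs."""
    return c_out * h_out * w_out * (c_in // groups) * k_h * k_w


def linear_params(in_features, out_features, bias=True):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + (out_features if bias else 0)


def linear_macs(in_features, out_features):
    """One MAC per weight for a single input vector (batch size 1)."""
    return in_features * out_features


# Example: a 3x3 convolution from 3 to 64 channels producing a 224x224 map,
# followed by a 512 -> 1000 classifier layer.
print(conv2d_params(3, 64, 3, 3))          # 1792
print(conv2d_macs(3, 64, 3, 3, 224, 224))  # 86704128
print(linear_params(512, 1000))            # 513000
print(linear_macs(512, 1000))              # 512000
```

Bias additions are left out of the MAC count here. Tools differ on whether to include them, which is one reason reported MAC figures are approximations.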
Clone the repository and change into it:
git clone https://github.com/idslab-skku/WorkloadGenerator.git
cd WorkloadGenerator
Create a new conda environment and activate it:
conda env create -f environment.yml
conda activate workload-generator
Import the necessary modules:
import torch
import onnx
from WorkloadGenerator import WorkloadGenerator
Load your PyTorch model and test input set:
# On PyTorch >= 2.6, pass weights_only=False when loading a fully pickled model.
model = torch.load('path/to/your/model.pth')
test_input = torch.load('path/to/your/test_input.pth')
Generate the ONNX file:
generator = WorkloadGenerator(model)
onnx_file = generator.generate_onnx('path/to/save/onnx/model.onnx')
Analyze the ONNX file and generate the report:
report = generator.analyze_onnx(onnx_file)
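The internals of `analyze_onnx` are not documented here. As a rough sketch of the operator-type analysis, the function below tallies `op_type` across the nodes of an ONNX graph. It assumes an object shaped like `onnx.ModelProto` (`model.graph.node`, where each node has an `op_type` string). A `SimpleNamespace` stand-in keeps the example runnable without the `onnx` package. `count_op_types` is a hypothetical helper, not part of WorkloadGenerator.

```python
from collections import Counter
from types import SimpleNamespace


def count_op_types(model):
    """Tally operator types in the top-level graph of an ONNX-like model.

    Nodes nested inside control-flow subgraphs (If, Loop, Scan) are not
    visited by this simple version.
    """
    return Counter(node.op_type for node in model.graph.node)


# Stand-in for a loaded model with five nodes.
fake_model = SimpleNamespace(
    graph=SimpleNamespace(
        node=[SimpleNamespace(op_type=t) for t in ["Conv", "Relu", "Conv", "Relu", "Gemm"]]
    )
)
print(count_op_types(fake_model).most_common())  # [('Conv', 2), ('Relu', 2), ('Gemm', 1)]
```

With the real package, `count_op_types(onnx.load('path/to/save/onnx/model.onnx'))` should behave the same way, since `onnx.load` returns a `ModelProto`.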
Access the information in the report object:
mac_counts = report.mac_counts
num_parameters = report.num_parameters
input_size = report.input_size
output_size = report.output_size
# ... and more
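The README does not state the units of `input_size` and `output_size`. If they are memory footprints, they follow from a tensor's shape and element type: the number of elements times the bytes per element. The sketch below assumes that interpretation. `tensor_bytes` and the dtype table are illustrative and are not taken from the project.

```python
from math import prod

# Bytes per element for some common dtypes (illustrative subset).
DTYPE_BYTES = {
    "float32": 4, "float16": 2, "bfloat16": 2,
    "int64": 8, "int32": 4, "int8": 1, "uint8": 1, "bool": 1,
}


def tensor_bytes(shape, dtype="float32"):
    """Dense tensor footprint: product of the dimensions times the element size."""
    return prod(shape) * DTYPE_BYTES[dtype]


# A single 224x224 RGB image in float32, a 1000-class logit vector,
# and the same image stored in float16.
print(tensor_bytes((1, 3, 224, 224)))             # 602112
print(tensor_bytes((1, 1000)))                    # 4000
print(tensor_bytes((1, 3, 224, 224), "float16"))  # 301056
```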
This project is licensed under the MIT License.