Building a Minimal IR for ONNX Model Compilation: From Concept to Implementation
Introduction: The Role of Intermediate Representations
Our compiler's Intermediate Representation (IR) is the bridge between high-level ONNX models and optimized executable code. This post documents my implementation of a minimal Relay-inspired IR and the ONNX-to-IR translation pipeline, the first step of my TVM compiler project.
Core Implementation Components
1. Operator Registry: The Conversion Backbone
Purpose: central mapping from ONNX operator types to conversion functions.
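#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "onnx/onnx_pb.h"  // protobuf types generated from onnx.proto (onnx::NodeProto)

class RelayExpr;  // base class of all Relay IR expressions, declared elsewhere

// Signature inferred from convertNode: a converter receives the ONNX node and its
// already-converted inputs, and returns one expression per node output.
using ConversionFunc = std::function<std::vector<std::shared_ptr<RelayExpr>>(
    const onnx::NodeProto&, const std::vector<std::shared_ptr<RelayExpr>>&)>;

class OperatorRegistry {
private:
    std::unordered_map<std::string, ConversionFunc> registry;  // op_type -> converter
    static OperatorRegistry* _operator_registry_instance;
    void register_all_ops();  // registers the built-in converters (e.g. Gemm)
public:
    OperatorRegistry() {
        register_all_ops();
    }
    // Returns the global registry instance.
    static OperatorRegistry* get_instance();
    void registerOp(const std::string& optype, ConversionFunc func) {
        registry[optype] = std::move(func);
    }
    ConversionFunc getConversionFunc(const std::string& optype) const {
        auto it = registry.find(optype);
        if (it != registry.end()) {
            return it->second;
        }
        throw std::runtime_error("Operator not found in registry: " + optype);
    }
    // Looks up the converter for node.op_type() and applies it to the inputs.
    std::vector<std::shared_ptr<RelayExpr>> convertNode(
        const onnx::NodeProto& node,
        const std::vector<std::shared_ptr<RelayExpr>>& inputs) const {
        return getConversionFunc(node.op_type())(node, inputs);
    }
};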
The operator registry is the central dispatch mechanism for ONNX-to-Relay conversion, in the style of a classic factory. At its core is a map from ONNX operator type strings (such as "Gemm") to conversion functions. Each function takes an ONNX NodeProto together with its already-converted input expressions and returns the Relay expressions for the node's outputs. The static get_instance() accessor provides a single global registry, while the public constructor still lets tests build an isolated registry of their own.
Key design aspects:
1. Extensibility: a new operator is added with a single registerOp call.
2. Type safety: every converter must match the ConversionFunc signature, so an incompatible function is rejected at compile time.
This design enables clean separation between operator definitions and graph traversal logic, critical for maintaining a modular codebase as we expand supported operators.
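To make the dispatch concrete, here is a self-contained sketch of the same pattern. `Expr`, `Node`, `call`, `show`, and `registerGemm` are hypothetical stand-ins for RelayExpr, onnx::NodeProto, and the project's real converters. The Gemm converter assumes alpha = beta = 1 and transB = 1 (weights stored as (out_features, in_features), which matches the weight layout Relay's nn.dense expects); it ignores all attributes.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for RelayExpr: a variable (no args) or an operator call.
struct Expr {
    std::string name;
    std::vector<std::shared_ptr<Expr>> args;
};
using ExprPtr = std::shared_ptr<Expr>;

// Hypothetical stand-in for onnx::NodeProto: only the op type is needed here.
struct Node {
    std::string op_type;
};

using ConversionFunc =
    std::function<std::vector<ExprPtr>(const Node&, const std::vector<ExprPtr>&)>;

class Registry {
public:
    void registerOp(const std::string& op_type, ConversionFunc func) {
        table_[op_type] = std::move(func);
    }
    // Dispatch on op_type; unknown operators are reported, not silently skipped.
    std::vector<ExprPtr> convertNode(const Node& node,
                                     const std::vector<ExprPtr>& inputs) const {
        auto it = table_.find(node.op_type);
        if (it == table_.end())
            throw std::runtime_error("Operator not found in registry: " + node.op_type);
        return it->second(node, inputs);
    }
private:
    std::unordered_map<std::string, ConversionFunc> table_;
};

// Build an operator-call expression.
ExprPtr call(const std::string& op, std::vector<ExprPtr> args) {
    return std::make_shared<Expr>(Expr{op, std::move(args)});
}

// Render an expression in a Relay-like text form.
std::string show(const ExprPtr& e) {
    if (e->args.empty()) return "%" + e->name;
    std::string s = e->name + "(";
    for (size_t i = 0; i < e->args.size(); ++i) {
        if (i > 0) s += ", ";
        s += show(e->args[i]);
    }
    return s + ")";
}

// Gemm(A, B, C) -> bias_add(dense(A, B), C), assuming alpha = beta = 1, transB = 1.
void registerGemm(Registry& r) {
    r.registerOp("Gemm", [](const Node&, const std::vector<ExprPtr>& in) {
        if (in.size() != 3) throw std::runtime_error("Gemm: expected 3 inputs");
        return std::vector<ExprPtr>{
            call("relay.nn.bias_add", {call("relay.nn.dense", {in[0], in[1]}), in[2]})};
    });
}
```

Converting a Gemm node with inputs input, fc.weight, and fc.bias yields `relay.nn.bias_add(relay.nn.dense(%input, %fc.weight), %fc.bias)`. Supporting another operator means one more registerOp call; the traversal code does not change.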
2. Type System Implementation
The core type system implements a hierarchy of type representations essential for validating neural network operations:
Base Type Class
#include <ostream>

class Type {
public:
    virtual ~Type() = default;  // allows safe deletion of derived types through a Type*
    // Pure virtual: every concrete type must implement print(), checked at compile time.
    virtual void print(std::ostream& os) = 0;
};
Primitive Types
Handles fundamental data types with enum-based kind tracking:
class PrimType : public Type {
public:
enum TypeKind { kInt, kFloat, kBool };
TypeKind kind;
void print(std::ostream& os) override {
switch(kind) {
case kInt: os << "int"; break;
case kFloat: os << "float"; break;
case kBool: os << "bool"; break;
}
}
};
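The kind has to come from the model itself: ONNX stores each tensor's element type as an integer TensorProto.DataType code (FLOAT = 1, INT32 = 6, INT64 = 7, BOOL = 9). A minimal sketch of that mapping, assuming a hypothetical `kindFromOnnx` helper and a local `Kind` enum that mirrors PrimType::TypeKind so the block compiles on its own:

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Local mirror of PrimType::TypeKind so the sketch is self-contained.
enum class Kind { kInt, kFloat, kBool };

// A few TensorProto.DataType codes as defined in onnx.proto.
constexpr int32_t kOnnxFloat = 1;
constexpr int32_t kOnnxInt32 = 6;
constexpr int32_t kOnnxInt64 = 7;
constexpr int32_t kOnnxBool = 9;

// Hypothetical helper: map an ONNX element type code to a primitive kind.
Kind kindFromOnnx(int32_t elem_type) {
    switch (elem_type) {
        case kOnnxFloat: return Kind::kFloat;
        case kOnnxInt32:
        case kOnnxInt64: return Kind::kInt;  // both widths collapse to kInt
        case kOnnxBool: return Kind::kBool;
        default:
            throw std::runtime_error("unsupported ONNX elem_type: " +
                                     std::to_string(elem_type));
    }
}
```

With only three kinds, int32 and int64 become indistinguishable; a richer IR would need separate kinds (and a bit width) before lowering to code.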
Tensor Type
Captures shape and dtype information critical for neural network tensors:
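#include <utility>
#include <vector>

class TensorType : public Type {
public:
    std::vector<int> shape;
    PrimType dtype;
    // Enables TensorType({1, 3}, PrimType::kFloat), as used in the examples below.
    TensorType(std::vector<int> shape_, PrimType::TypeKind kind) : shape(std::move(shape_)) {
        dtype.kind = kind;
    }
    // Prints e.g. "Tensor[(1, 3), float]".
    void print(std::ostream& os) override {
        os << "Tensor[(";
        for (size_t i = 0; i < shape.size(); i++) {
            os << shape[i];
            if (i + 1 < shape.size()) os << ", ";
        }
        os << "), ";
        dtype.print(os);
        os << "]";
    }
};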
Key Features:
1. Polymorphic Storage
Types are stored as shared_ptr<Type>, so any concrete type can be held in a container or IR node and used through the common Type interface:
std::shared_ptr<Type> inputVarType =
std::make_shared<TensorType>(input_shape, PrimType::kFloat);
2. Type-Aware Variables
Relay variables embed type information directly:
class RelayVar : public RelayExpr {
std::shared_ptr<Type> type;
// ...
};
3. Validation Through Printing
The print() method serves a dual purpose: debugging output and, in unit tests, checking that types are constructed as expected:
// Test case example
TEST(RelayTypes, TensorTypePrint) {
TensorType tensor({1,3}, PrimType::kFloat);
std::stringstream ss;
tensor.print(ss);
ASSERT_EQ(ss.str(), "Tensor[(1, 3), float]");
}
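Types only describe values; validation happens when an operator's type rule checks its inputs. A sketch of shape rules for the two operators the Gemm conversion emits, working on bare shape vectors rather than the TensorType class to stay self-contained. `inferDense` and `inferBiasAdd` are hypothetical names, and the bias is applied along the last axis (Relay's own bias_add defaults to axis=1, which coincides with the last axis for 2-D inputs like this model's).

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>;

// Type rule for dense: data (..., in_units) x weight (units, in_units) -> (..., units).
Shape inferDense(const Shape& data, const Shape& weight) {
    if (data.empty() || weight.size() != 2)
        throw std::runtime_error("dense: expected data rank >= 1 and weight rank 2");
    if (data.back() != weight[1])
        throw std::runtime_error("dense: reduction dimension mismatch");
    Shape out = data;
    out.back() = weight[0];  // replace in_units with units
    return out;
}

// Type rule for bias_add on the last axis: bias is 1-D with that axis's length.
Shape inferBiasAdd(const Shape& data, const Shape& bias) {
    if (data.empty() || bias.size() != 1 || bias[0] != data.back())
        throw std::runtime_error("bias_add: bias does not match last axis");
    return data;  // output shape equals input shape
}
```

For the model in this post, a (1, 3) input and a (2, 3) weight give (1, 2) after dense, and a (2,) bias keeps it at (1, 2), matching the output shape the parser reports. A (2, 4) weight would be rejected before any code is generated.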
3. ONNX Model Parsing Architecture
The parser implements a three-phase translation process from ONNX models to Relay IR:
1. Model Inspection & Validation
2. Symbolic Variable Creation
3. Graph Translation Workflow
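// Core conversion loop
auto* converter = relay::OperatorRegistry::get_instance();
for (auto node : model.graph().node()) {
    // 1. Collect input arguments: outputs of earlier nodes first,
    //    then graph inputs/initializers created as variables
    std::vector<std::shared_ptr<relay::RelayExpr>> args;
    for (const auto& input : node.input()) {
        auto produced = output2relayExprs.find(input);
        if (produced != output2relayExprs.end()) {
            args.push_back(produced->second);
            continue;
        }
        auto var = input2relayVars.find(input);
        if (var == input2relayVars.end()) {
            throw std::runtime_error("Undefined input: " + input);
        }
        args.push_back(var->second);
    }
    // 2. Dispatch to operator registry
    auto output = converter->convertNode(node, args);
    // 3. Map outputs so later nodes can consume them
    for (size_t i = 0; i < output.size(); i++) {
        output2relayExprs[node.output(i)] = output[i];
    }
}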
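In a graph deeper than one layer, a node's inputs can be graph inputs, initializers, or outputs of earlier nodes, so name resolution needs one environment that grows as nodes are translated. A self-contained sketch of that idea, where expressions are plain strings and `ToyNode` and `translate` are hypothetical stand-ins for onnx::NodeProto and the real loop. It relies on ONNX listing graph nodes in topological order.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for onnx::NodeProto.
struct ToyNode {
    std::string op_type;
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// `env` starts with graph inputs and initializers; each node adds its output.
std::unordered_map<std::string, std::string>
translate(const std::vector<ToyNode>& nodes,
          std::unordered_map<std::string, std::string> env) {
    for (const auto& node : nodes) {
        std::string call = node.op_type + "(";
        for (size_t i = 0; i < node.inputs.size(); ++i) {
            auto it = env.find(node.inputs[i]);
            if (it == env.end())
                throw std::runtime_error("undefined input: " + node.inputs[i]);
            if (i > 0) call += ", ";
            call += it->second;
        }
        call += ")";
        // Single-output operators only in this sketch.
        env[node.outputs.at(0)] = call;
    }
    return env;
}
```

With a Gemm node producing `h` followed by a Relu consuming it, the output resolves to `Relu(Gemm(%input, %fc.weight, %fc.bias))`; a node that references a name nobody defined fails loudly instead of receiving a null expression.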
4. Final Outcome
At this point, the ONNX-to-Relay converter can parse a very simple ONNX model, such as this single Gemm layer:
$ ./myTvm ../misc/onnx-input/simple_model.onnx
Parsing ONNX model: ../misc/onnx-input/simple_model.onnx
Graph name: main_graph
Inputs:
input : (1, 3, )
Outputs:
output : (1, 2, )
Nodes:
/fc/Gemm : Gemm
input: input
input: fc.weight
input: fc.bias
output: output
def @main(%input: Tensor[(1, 3), float]) {
relay.nn.bias_add(relay.nn.dense(%input, %fc.weight), %fc.bias)
}
This post covers commit 66006905 in the repo.