diff --git a/example/ExampleDialect.td b/example/ExampleDialect.td index 8f6d172..1c7d7fd 100644 --- a/example/ExampleDialect.td +++ b/example/ExampleDialect.td @@ -367,7 +367,7 @@ def StructBackedType : DialectType { let description = [{ Test that a struct-backed type works correctly. }]; - let typeArguments = (args AttrI32:$field0, AttrI32:$field1, AttrI32:$field2); + let typeArguments = (args AttrI32:$field0, AttrI8:$field1, AttrVectorKind:$field2); let representation = (repr_struct (IntegerType 41)); let defaultGetterHasExplicitContextArgument = 1; diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index b78c601..763a9ab 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -149,7 +149,7 @@ void createFunctionExample(Module &module, const Twine &name) { b.create("Hello world!"); - xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, 2); + xd::cpp::StructBackedType *structBackedTy = xd::cpp::StructBackedType::get(bb->getContext(), 1, 0, xd::cpp::VectorKind::BigEndian); auto *structBackedVal = b.create(structBackedTy, b.getInt32(42), "gen.struct.backed.val"); b.create(structBackedVal, "consume.struct.backed.val"); diff --git a/lib/TableGen/DialectType.cpp b/lib/TableGen/DialectType.cpp index 17c2916..fc0b53c 100644 --- a/lib/TableGen/DialectType.cpp +++ b/lib/TableGen/DialectType.cpp @@ -201,17 +201,24 @@ void DialectType::emitDeclaration(raw_ostream &out, GenDialect *dialect) const { out << " static bool classof(const ::llvm::Type *t);\n\n"; unsigned fieldIdx = 1; // sentinel + auto getCastExpr = [&fmt](const NamedValue &argument, + llvm::StringRef expr) -> std::string { + return tgfmt(cast(argument.type)->getFromUnsigned(), &fmt, expr); + }; for (const auto &argument : typeArguments()) { std::string camel = convertToCamelFromSnakeCase(argument.name, true); out << tgfmt( - R"( unsigned get$0() const { - ::llvm::Type *elt = getElementType($1); + R"( $0 get$1() const { + ::llvm::Type *elt = getElementType($2); if (elt->isStructTy()) - return 0; - return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + return $3; + return $4; } )", - &fmt, camel, fieldIdx++); + &fmt, argument.type->getCppType(), camel, fieldIdx++, + getCastExpr(argument, "0"), + getCastExpr(argument, + "::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()")); } out << " };\n\n"; @@ -307,14 +314,17 @@ void DialectType::emitDefinition(raw_ostream &out, GenDialect *dialect) const { " $fields.push_back(::llvm::IntegerType::get($_context, $0));\n", &fmt, Twine(m_structSentinelBitWidth)); - for (const auto &getterArg : getterArgs) { + for (const auto &[argument, getterArg] : + llvm::zip(typeArguments(), getterArgs)) { + std::string castExpr = tgfmt(cast(argument.type)->getToUnsigned(), + &fmt, getterArg.name); out << tgfmt(R"( if ($0 == 0) $fields.push_back(::llvm::StructType::get($_context)); else $fields.push_back(::llvm::IntegerType::get($_context, $0)); )", - &fmt, getterArg.name); + &fmt, castExpr); } out << tgfmt(" auto *$st = ::llvm::StructType::create($_context, " "$fields, $os.str(), /*isPacked=*/false);\n", diff --git a/test/example/generated/ExampleDialect.cpp.inc b/test/example/generated/ExampleDialect.cpp.inc index 2a39b5b..902dc84 100644 --- a/test/example/generated/ExampleDialect.cpp.inc +++ b/test/example/generated/ExampleDialect.cpp.inc @@ -258,10 +258,10 @@ m_attributeLists[6] = argAttrList.addFnAttributes(context, attrBuilder); } } -StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2) { - +StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2) { +static_assert(sizeof(field2) <= sizeof(unsigned)); std::string name; ::llvm::raw_string_ostream os(name); os << "struct.backed"; os << '.' << (uint64_t)field0; @@ -280,10 +280,10 @@ StructBackedType* StructBackedType::get(::llvm::LLVMContext & ctx, uint32_t fiel else fields.push_back(::llvm::IntegerType::get(ctx, field1)); - if (field2 == 0) + if (static_cast(field2) == 0) fields.push_back(::llvm::StructType::get(ctx)); else - fields.push_back(::llvm::IntegerType::get(ctx, field2)); + fields.push_back(::llvm::IntegerType::get(ctx, static_cast(field2))); auto *st = ::llvm::StructType::create(ctx, fields, os.str(), /*isPacked=*/false); return static_cast(st); } diff --git a/test/example/generated/ExampleDialect.h.inc b/test/example/generated/ExampleDialect.h.inc index 46189b9..8530a3a 100644 --- a/test/example/generated/ExampleDialect.h.inc +++ b/test/example/generated/ExampleDialect.h.inc @@ -58,27 +58,27 @@ namespace xd::cpp { using ::llvm::StructType::getElementType; static StructBackedType *get( - ::llvm::LLVMContext & ctx, uint32_t field0, uint32_t field1, uint32_t field2); + ::llvm::LLVMContext & ctx, uint32_t field0, uint8_t field1, VectorKind field2); static bool classof(const ::llvm::Type *t); - unsigned getField0() const { + uint32_t getField0() const { ::llvm::Type *elt = getElementType(1); if (elt->isStructTy()) return 0; return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); } - unsigned getField1() const { + uint8_t getField1() const { ::llvm::Type *elt = getElementType(2); if (elt->isStructTy()) return 0; return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); } - unsigned getField2() const { + VectorKind getField2() const { ::llvm::Type *elt = getElementType(3); if (elt->isStructTy()) - return 0; - return ::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth(); + return static_cast(0); + return static_cast(::llvm::cast<::llvm::IntegerType>(elt)->getBitWidth()); } };