@@ -26,9 +26,9 @@ ValueCategory::ValueCategory(mlir::Value val, bool isReference)
2626 val.getType ().isa <LLVM::LLVMPointerType>())) {
2727 llvm::errs () << " val: " << val << " \n " ;
2828 }
29- assert (val.getType ().isa <MemRefType>() ||
30- val.getType ().isa <LLVM::LLVMPointerType>() &&
31- " Reference value must have pointer/memref type" );
29+ assert (( val.getType ().isa <MemRefType>() ||
30+ val.getType ().isa <LLVM::LLVMPointerType>() ) &&
31+ " Reference value must have pointer/memref type" );
3232 }
3333}
3434
@@ -54,12 +54,26 @@ void ValueCategory::store(mlir::OpBuilder &builder, mlir::Value toStore) const {
5454 assert (val && " expect not-null" );
5555 auto loc = builder.getUnknownLoc ();
5656 if (auto pt = val.getType ().dyn_cast <mlir::LLVM::LLVMPointerType>()) {
57+ if (auto p2m = toStore.getDefiningOp <polygeist::Pointer2MemrefOp>()) {
58+ if (pt.getElementType () == p2m.source ().getType ())
59+ toStore = p2m.source ();
60+ else if (auto nt = p2m.source ().getDefiningOp <LLVM::NullOp>()) {
61+ if (pt.getElementType ().isa <LLVM::LLVMPointerType>())
62+ toStore =
63+ builder.create <LLVM::NullOp>(nt.getLoc (), pt.getElementType ());
64+ }
65+ }
5766 if (toStore.getType () != pt.getElementType ()) {
5867 if (auto mt = toStore.getType ().dyn_cast <MemRefType>()) {
5968 if (auto spt =
6069 pt.getElementType ().dyn_cast <mlir::LLVM::LLVMPointerType>()) {
61- assert (mt.getElementType () == spt.getElementType () &&
62- " expect same type" );
70+ if (mt.getElementType () != spt.getElementType ()) {
71+ // llvm::errs() << " func: " <<
72+ // val.getDefiningOp()->getParentOfType<FuncOp>() << "\n";
73+ llvm::errs () << " warning potential store type mismatch:\n " ;
74+ llvm::errs () << " val: " << val << " tosval: " << toStore << " \n " ;
75+ llvm::errs () << " mt: " << mt << " spt: " << spt << " \n " ;
76+ }
6377 toStore =
6478 builder.create <polygeist::Memref2PointerOp>(loc, spt, toStore);
6579 }
0 commit comments