- Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][ODS] Fix default inferReturnTypes generation for variadic operands#131483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base:main
Are you sure you want to change the base?
[mlir][ODS] Fix default inferReturnTypes generation for variadic operands #131483
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesFor variadic operands, Full diff: https://github.com/llvm/llvm-project/pull/131483.diff 2 Files Affected:
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index a4f7af6dbcf1c..334ca118e31c0 100644 --- a/mlir/test/mlir-tblgen/op-result.td+++ b/mlir/test/mlir-tblgen/op-result.td@@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 0)-// CHECK-NEXT: return ::mlir::failure();-// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();+// CHECK: OpL1::Adaptor adaptor+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; def OpL2 : NS_Op<"op_with_all_types_constraint", @@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 2)-// CHECK-NEXT: return ::mlir::failure();-// CHECK-NOT: if (operands.size() <= 0)-// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();-// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();+// CHECK: OpL2::Adaptor adaptor+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType();+// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; // CHECK: inferredReturnTypes[1] = odsInferredType1; @@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes -// CHECK: if (operands.size() <= 0)-// CHECK-NEXT: return ::mlir::failure();-// CHECK: odsInferredType0 = fromInput(operands[0].getType())+// CHECK: OpL4::Adaptor adaptor+// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) // CHECK: inferredReturnTypes[0] = odsInferredType0 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b957c8ee9f8ab..8288e77b8f653 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() { // Avoid emitting "resultTypes.size() >= 0u" which is always true. if (!hasVariadicResult || numNonVariadicResults != 0) - body << " "- << "assert(resultTypes.size() "+ body << " " << "assert(resultTypes.size() " << (hasVariadicResult ? ">=" : "==") << " " << numNonVariadicResults << "u && \"mismatched number of results\");\n"; @@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; - // Preprocessing stage to verify all accesses to operands are valid.- int maxAccessedIndex = -1;- for (int i = 0, e = op.getNumResults(); i != e; ++i) {- const InferredResultType &infer = op.getInferredResultType(i);- if (!infer.isArg())- continue;- Operator::OperandOrAttribute arg =- op.getArgToOperandOrAttribute(infer.getIndex());- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {- maxAccessedIndex =- std::max(maxAccessedIndex, arg.operandOrAttributeIndex());- }- }- if (maxAccessedIndex != -1) {- body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";- body << " return ::mlir::failure();\n";- }+ // Emit an adaptor to access right ranges for ods operands.+ body << " " << op.getCppClassName()+ << "::Adaptor adaptor(operands, attributes, properties, regions);\n";- // Process the type inference graph in topological order, starting from types- // that are always fully-inferred: operands and results with constructible- // types. The type inference graph here will always be a DAG, so this gives- // us the correct order for generating the types. -1 is a placeholder to- // indicate the type for a result has not been generated.+ // TODO: Ideally, we should be doing some sort of verification here. This+ // is however problemetic due to 2 reasons:+ //+ // 1. Adaptor::verify only verifies attributes. It really should verify+ // if the number of given attributes is right too.+ // 2. PDL passes empty properties to inferReturnTypes, which does not verify.+ // Without properties, it's not really possible to verify the number of+ // operands as we do not know the variadic operand segment sizes.++ // Process the type inference graph in topological order, starting from+ // types that are always fully-inferred: operands and results with+ // constructible types. The type inference graph here will always be a+ // DAG, so this gives us the correct order for generating the types. -1 is+ // a placeholder to indicate the type for a result has not been generated. SmallVector<int> constructedIndices(op.getNumResults(), -1); int inferredTypeIdx = 0; for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { @@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() { Operator::OperandOrAttribute arg = op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { - typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +- "].getType()")- .str();-+ std::string getter =+ "adaptor." ++ op.getGetterName(+ op.getOperand(arg.operandOrAttributeIndex()).name);+ typeStr = (getter + "().getType()"); // If this is an attribute, index into the attribute dictionary. } else { auto *attr = |
I kind of remember that InferReturnTypes can be called by the verifier on invalid IR, but the adaptor aren't resilient to invalid IR I believe: are we risking crashes during verification here? |
Also: could the new behavior be tested in the test dialect? |
3cdf6e7
to 1beaad3
Compare
This is okay, inferReturnTypes documents 1 that the operands passed should already be verified. We are only using the operand accessors for verifiers so if we crash, it is expected documented behavior. It's really hard to verify if the operands are correct otherwise (we might just go into an infinite verification loop otherwise, although i could be wrong here). |
Added, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall, thanks!
@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() { | |||
// Avoid emitting "resultTypes.size() >= 0u" which is always true. | |||
if (!hasVariadicResult || numNonVariadicResults != 0) | |||
body << " " | |||
<< "assert(resultTypes.size() " | |||
body << " " << "assert(resultTypes.size() " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably can remove the second << here and merge the strings, not sure its doing much here (if it were handling indentation other story)
} | ||
} | ||
if (maxAccessedIndex != -1) { | ||
body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this check now handled in the adaptor? (I don't recall if there is a test for this upstream)
For variadic operands,
operands[odsOperandIndex]
is incorrect, because the operand can be variadic. Instead, create an adaptor and use it to get the correct operand.