Skip to content

[mlir][ODS] Fix default inferReturnTypes generation for variadic operands#131483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base:main
Choose a base branch
from

Conversation

Groverkss
Copy link
Member

For variadic operands, operands[odsOperandIndex] is incorrect, because the operand can be variadic. Instead, create an adaptor and use it to get the correct operand.

@llvmbotllvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 16, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

For variadic operands, operands[odsOperandIndex] is incorrect, because the operand can be variadic. Instead, create an adaptor and use it to get the correct operand.


Full diff: https://github.com/llvm/llvm-project/pull/131483.diff

2 Files Affected:

  • (modified) mlir/test/mlir-tblgen/op-result.td (+7-11)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+23-28)
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index a4f7af6dbcf1c..334ca118e31c0 100644 --- a/mlir/test/mlir-tblgen/op-result.td+++ b/mlir/test/mlir-tblgen/op-result.td@@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 0)-// CHECK-NEXT: return ::mlir::failure();-// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();+// CHECK: OpL1::Adaptor adaptor+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; def OpL2 : NS_Op<"op_with_all_types_constraint", @@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 2)-// CHECK-NEXT: return ::mlir::failure();-// CHECK-NOT: if (operands.size() <= 0)-// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();-// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();+// CHECK: OpL2::Adaptor adaptor+// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType();+// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; // CHECK: inferredReturnTypes[1] = odsInferredType1; @@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes -// CHECK: if (operands.size() <= 0)-// CHECK-NEXT: return ::mlir::failure();-// CHECK: odsInferredType0 = fromInput(operands[0].getType())+// CHECK: OpL4::Adaptor adaptor+// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) // CHECK: inferredReturnTypes[0] = odsInferredType0 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b957c8ee9f8ab..8288e77b8f653 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() { // Avoid emitting "resultTypes.size() >= 0u" which is always true. if (!hasVariadicResult || numNonVariadicResults != 0) - body << " "- << "assert(resultTypes.size() "+ body << " " << "assert(resultTypes.size() " << (hasVariadicResult ? ">=" : "==") << " " << numNonVariadicResults << "u && \"mismatched number of results\");\n"; @@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; - // Preprocessing stage to verify all accesses to operands are valid.- int maxAccessedIndex = -1;- for (int i = 0, e = op.getNumResults(); i != e; ++i) {- const InferredResultType &infer = op.getInferredResultType(i);- if (!infer.isArg())- continue;- Operator::OperandOrAttribute arg =- op.getArgToOperandOrAttribute(infer.getIndex());- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {- maxAccessedIndex =- std::max(maxAccessedIndex, arg.operandOrAttributeIndex());- }- }- if (maxAccessedIndex != -1) {- body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";- body << " return ::mlir::failure();\n";- }+ // Emit an adaptor to access right ranges for ods operands.+ body << " " << op.getCppClassName()+ << "::Adaptor adaptor(operands, attributes, properties, regions);\n";- // Process the type inference graph in topological order, starting from types- // that are always fully-inferred: operands and results with constructible- // types. The type inference graph here will always be a DAG, so this gives- // us the correct order for generating the types. -1 is a placeholder to- // indicate the type for a result has not been generated.+ // TODO: Ideally, we should be doing some sort of verification here. This+ // is however problemetic due to 2 reasons:+ //+ // 1. Adaptor::verify only verifies attributes. It really should verify+ // if the number of given attributes is right too.+ // 2. PDL passes empty properties to inferReturnTypes, which does not verify.+ // Without properties, it's not really possible to verify the number of+ // operands as we do not know the variadic operand segment sizes.++ // Process the type inference graph in topological order, starting from+ // types that are always fully-inferred: operands and results with+ // constructible types. The type inference graph here will always be a+ // DAG, so this gives us the correct order for generating the types. -1 is+ // a placeholder to indicate the type for a result has not been generated. SmallVector<int> constructedIndices(op.getNumResults(), -1); int inferredTypeIdx = 0; for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { @@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() { Operator::OperandOrAttribute arg = op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { - typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +- "].getType()")- .str();-+ std::string getter =+ "adaptor." ++ op.getGetterName(+ op.getOperand(arg.operandOrAttributeIndex()).name);+ typeStr = (getter + "().getType()"); // If this is an attribute, index into the attribute dictionary. } else { auto *attr = 
@joker-eph
Copy link
Collaborator

I kind of remember that InferReturnTypes can be called by the verifier on invalid IR, but the adaptor aren't resilient to invalid IR I believe: are we risking crashes during verification here?

@joker-eph
Copy link
Collaborator

Also: could the new behavior be tested in the test dialect?

@GroverkssGroverkssforce-pushed the fix-ods-variadic-infer-return-types branch from 3cdf6e7 to 1beaad3CompareMarch 19, 2025 11:49
@Groverkss
Copy link
MemberAuthor

I kind of remember that InferReturnTypes can be called by the verifier on invalid IR, but the adaptor aren't resilient to invalid IR I believe: are we risking crashes during verification here?

This is okay, inferReturnTypes documents 1 that the operands passed should already be verified. We are only using the operand accessors for verifiers so if we crash, it is expected documented behavior. It's really hard to verify if the operands are correct otherwise (we might just go into an infinite verification loop otherwise, although i could be wrong here).

@Groverkss
Copy link
MemberAuthor

Also: could the new behavior be tested in the test dialect?

Added, thanks!

Copy link
Member

@jpienaarjpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall, thanks!

@@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() {

// Avoid emitting "resultTypes.size() >= 0u" which is always true.
if (!hasVariadicResult || numNonVariadicResults != 0)
body << " "
<< "assert(resultTypes.size() "
body << " " << "assert(resultTypes.size() "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably can remove the second << here and merge the strings, not sure its doing much here (if it were handling indentation other story)

}
}
if (maxAccessedIndex != -1) {
body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check now handled in the adaptor? (I don't recall if there is a test for this upstream)

Sign up for freeto join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:coreMLIR Core Infrastructuremlir
5 participants
@Groverkss@llvmbot@joker-eph@jpienaar@Mogball
close