llvm-project/mlir/lib/LLVMIR/IR/LLVMDialect.cpp

892 lines
35 KiB
C++

//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect itself. It also registers the dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::LLVM;
namespace mlir {
namespace LLVM {
namespace detail {
struct LLVMTypeStorage : public ::mlir::TypeStorage {
LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
// LLVM types are pointer-unique.
using KeyTy = llvm::Type *;
bool operator==(const KeyTy &key) const { return key == underlyingType; }
static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
llvm::Type *ty) {
return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
}
llvm::Type *underlyingType;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
// Returns the MLIR wrapper for `llvmType`, obtained through the base-class
// getter with the FIRST_LLVM_TYPE kind and the LLVM type as the key.
LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
  auto wrapped = Base::get(context, FIRST_LLVM_TYPE, llvmType);
  return wrapped;
}
// Returns the LLVM IR type held in this wrapper's storage.
llvm::Type *LLVMType::getUnderlyingType() const {
  llvm::Type *wrapped = getImpl()->underlyingType;
  return wrapped;
}
// Prints a two-operand, single-result LLVM dialect operation in the custom
// form
//   op-name ssa-use `,` ssa-use attribute-dict? `:` type
// where the single trailing type is shared by both operands and the result.
// Any operation that does not have this shape is printed in the generic form
// instead.
static void printLLVMBinaryOp(OpAsmPrinter *p, Operation *op) {
  const auto *abstract = op->getAbstractOperation();
  (void)abstract;
  assert(abstract && "pretty printing an unregistered operation");

  // Fallback to the generic form if the op is not well-formed (may happen
  // during incomplete rewrites, and used for debugging). The operand and
  // result counts are checked first so that the accesses below stay in bounds
  // and no extra operand is silently dropped from the output.
  if (op->getNumOperands() != 2 || op->getNumResults() != 1)
    return p->printGenericOp(op);
  auto resultType = op->getResult(0)->getType();
  if (resultType != op->getOperand(0)->getType() ||
      resultType != op->getOperand(1)->getType())
    return p->printGenericOp(op);

  *p << op->getName().getStringRef() << ' ' << *op->getOperand(0) << ", "
     << *op->getOperand(1);
  p->printOptionalAttrDict(op->getAttrs());
  *p << " : " << resultType;
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ICmpOp.
//===----------------------------------------------------------------------===//
// Returns the table of ICmp predicate mnemonics. The entry at index `i` is the
// mnemonic of the predicate whose integer value is `i`.
static const char *const *getICmpPredicateNames() {
  static const char *const kPredicateNames[] = {
      "eq",  // EQ
      "ne",  // NE
      "slt", // SLT
      "sle", // SLE
      "sgt", // SGT
      "sge", // SGE
      "ult", // ULT
      "ule", // ULE
      "ugt", // UGT
      "uge", // UGE
  };
  return kPredicateNames;
}
// Maps an ICmp predicate mnemonic to its integer value, i.e. to the position
// of the mnemonic in the list below. Returns -1 when `name` is not a known
// mnemonic.
static int getICmpPredicateByName(StringRef name) {
  static const char *const kMnemonics[] = {"eq",  "ne",  "slt", "sle", "sgt",
                                           "sge", "ult", "ule", "ugt", "uge"};
  int value = 0;
  for (const char *mnemonic : kMnemonics) {
    if (name == StringRef(mnemonic))
      return value;
    ++value;
  }
  return -1;
}
// Prints `llvm.icmp` in its custom form: the predicate as a quoted mnemonic,
// both operands, the remaining attributes and the type of the left-hand side.
static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
  // The integer predicate attribute indexes the mnemonic table.
  const char *mnemonic =
      getICmpPredicateNames()[op.predicate().getZExtValue()];
  *p << op.getOperationName() << " \"" << mnemonic << "\" ";
  *p << *op.getOperand(0) << ", " << *op.getOperand(1);
  // The predicate was already printed as a mnemonic, so elide it here.
  p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
  *p << " : " << op.lhs()->getType();
}
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
// attribute-dict? `:` type
// Parses the custom form of `llvm.icmp`. The quoted predicate mnemonic is
// converted into a 64-bit integer `predicate` attribute, both operands take
// the trailing type, and the result type is i1 or a vector of i1 with as many
// elements as the operand type.
static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
  Builder &builder = parser->getBuilder();

  // Predicate literal; it is stored as the first parsed attribute.
  Attribute predicateAttr;
  SmallVector<NamedAttribute, 4> parsedAttrs;
  llvm::SMLoc predicateLoc;
  if (parser->getCurrentLocation(&predicateLoc) ||
      parser->parseAttribute(predicateAttr, "predicate", parsedAttrs))
    return failure();

  // Operands and the optional attribute dictionary.
  OpAsmParser::OperandType lhsOperand, rhsOperand;
  if (parser->parseOperand(lhsOperand) || parser->parseComma() ||
      parser->parseOperand(rhsOperand) ||
      parser->parseOptionalAttributeDict(parsedAttrs))
    return failure();

  // Trailing type, shared by both operands.
  Type operandType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(operandType))
    return failure();
  if (parser->resolveOperand(lhsOperand, operandType, result->operands) ||
      parser->resolveOperand(rhsOperand, operandType, result->operands))
    return failure();

  // Replace the string attribute `predicate` with an integer attribute.
  auto mnemonicAttr = predicateAttr.dyn_cast<StringAttr>();
  if (!mnemonicAttr)
    return parser->emitError(predicateLoc,
                             "expected 'predicate' attribute of string type");
  const int predicateValue = getICmpPredicateByName(mnemonicAttr.getValue());
  if (predicateValue == -1)
    return parser->emitError(predicateLoc)
           << "'" << mnemonicAttr.getValue()
           << "' is an incorrect value of the 'predicate' attribute";
  parsedAttrs[0].second = builder.getI64IntegerAttr(predicateValue);

  // The result type is either i1 or a vector type <? x i1> if the inputs are
  // vectors.
  auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
  llvm::Type *i1Type = llvm::Type::getInt1Ty(dialect->getLLVMContext());
  auto wrappedOperandType = operandType.dyn_cast<LLVM::LLVMType>();
  if (!wrappedOperandType)
    return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
  llvm::Type *llvmOperandType = wrappedOperandType.getUnderlyingType();
  llvm::Type *llvmResultType = i1Type;
  if (llvmOperandType->isVectorTy())
    llvmResultType = llvm::VectorType::get(
        i1Type, llvmOperandType->getVectorNumElements());
  auto resultType = builder.getType<LLVM::LLVMType>(llvmResultType);

  result->attributes = parsedAttrs;
  result->addTypes({resultType});
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::AllocaOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.alloca` as: array-size `x` element-type attribute-dict? `:`
// function-type. The element type is the pointee of the result pointer type;
// the function type maps the array-size type to the result type.
static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
  auto wrappedResultType = op.getType().cast<LLVM::LLVMType>();
  auto *pointerType =
      llvm::cast<llvm::PointerType>(wrappedResultType.getUnderlyingType());
  auto elementType =
      LLVM::LLVMType::get(op.getContext(), pointerType->getElementType());

  auto signature = FunctionType::get({op.arraySize()->getType()},
                                     {op.getType()}, op.getContext());

  *p << op.getOperationName() << ' ' << *op.arraySize() << " x "
     << elementType;
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : " << signature;
}
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
// `:` function-type
// Parses the custom form of `llvm.alloca`. The trailing function type carries
// the type of the array-size operand (its only input) and the result type
// (its only result). The element type written after `x` is parsed but not
// otherwise used here.
static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType sizeOperand;
  Type elementType;
  if (parser->parseOperand(sizeOperand) || parser->parseKeyword("x") ||
      parser->parseType(elementType))
    return failure();

  SmallVector<NamedAttribute, 4> parsedAttrs;
  Type trailingType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseOptionalAttributeDict(parsedAttrs) ||
      parser->parseColon() ||
      parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(trailingType))
    return failure();

  // Extract the operand and result types from the trailing function type.
  auto signature = trailingType.dyn_cast<FunctionType>();
  if (!signature || signature.getNumInputs() != 1 ||
      signature.getNumResults() != 1)
    return parser->emitError(
        trailingTypeLoc,
        "expected trailing function type with one argument and one result");

  if (parser->resolveOperand(sizeOperand, signature.getInput(0),
                             result->operands))
    return failure();

  result->attributes = parsedAttrs;
  result->addTypes({signature.getResult(0)});
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::GEPOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.getelementptr` as: base `[` indices `]` attribute-dict? `:`
// function-type, where the function type lists the types of all operands
// (base first) and the result type.
static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
  *p << op.getOperationName() << ' ' << *op.base() << '[';
  // Every operand after the base is an index.
  p->printOperands(std::next(op.operand_begin()), op.operand_end());
  *p << ']';
  p->printOptionalAttrDict(op.getAttrs());

  SmallVector<Type, 8> operandTypes;
  for (auto *operand : op.getOperands())
    operandTypes.push_back(operand->getType());
  auto signature = FunctionType::get(operandTypes, op.getResult()->getType(),
                                     op.getContext());
  *p << " : " << signature;
}
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
// attribute-dict? `:` type
// Parses the custom form of `llvm.getelementptr`. The trailing function type
// supplies the type of the base pointer (first input), the types of the
// indices (remaining inputs) and the result type.
static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType baseOperand;
  SmallVector<OpAsmParser::OperandType, 8> indexOperands;
  if (parser->parseOperand(baseOperand) ||
      parser->parseOperandList(indexOperands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Square))
    return failure();

  SmallVector<NamedAttribute, 4> parsedAttrs;
  Type trailingType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseOptionalAttributeDict(parsedAttrs) ||
      parser->parseColon() ||
      parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(trailingType))
    return failure();

  // Deconstruct the trailing function type: it needs at least the base
  // pointer as input and exactly one result.
  auto signature = trailingType.dyn_cast<FunctionType>();
  if (!signature || signature.getNumResults() != 1 ||
      signature.getNumInputs() == 0)
    return parser->emitError(trailingTypeLoc,
                             "expected trailing function type with at least "
                             "one argument and one result");

  if (parser->resolveOperand(baseOperand, signature.getInput(0),
                             result->operands))
    return failure();
  if (parser->resolveOperands(indexOperands,
                              signature.getInputs().drop_front(),
                              parser->getNameLoc(), result->operands))
    return failure();

  result->attributes = parsedAttrs;
  result->addTypes(signature.getResults());
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.load` as: address attribute-dict? `:` address-type.
static void printLoadOp(OpAsmPrinter *p, LoadOp &op) {
  *p << op.getOperationName();
  *p << ' ' << *op.addr();
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : ";
  *p << op.addr()->getType();
}
// Given `type`, which must be an LLVM dialect type wrapping an LLVM pointer
// type, returns the pointee type wrapped as an LLVM dialect type. Otherwise
// emits a diagnostic at `trailingTypeLoc` and returns a null type.
static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
                                    llvm::SMLoc trailingTypeLoc) {
  auto wrappedType = type.dyn_cast<LLVM::LLVMType>();
  if (!wrappedType) {
    parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
    return nullptr;
  }

  auto *pointerType =
      dyn_cast<llvm::PointerType>(wrappedType.getUnderlyingType());
  if (!pointerType) {
    parser->emitError(trailingTypeLoc, "expected LLVM pointer type");
    return nullptr;
  }

  return LLVM::LLVMType::get(parser->getBuilder().getContext(),
                             pointerType->getElementType());
}
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
// Parses the custom form of `llvm.load`. The single trailing type is the type
// of the address operand; the result type is its pointee type. Fails if the
// trailing type is not an LLVM dialect pointer type.
static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<NamedAttribute, 4> attrs;
  OpAsmParser::OperandType addr;
  Type type;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
      parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(type) ||
      parser->resolveOperand(addr, type, result->operands))
    return failure();

  // The helper has already emitted a diagnostic when it returns a null type;
  // propagate the failure instead of registering a null result type.
  Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
  if (!elemTy)
    return failure();

  result->attributes = attrs;
  result->addTypes(elemTy);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.store` as: value `,` address attribute-dict? `:` address-type.
static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
  *p << op.getOperationName() << ' ';
  *p << *op.value() << ", " << *op.addr();
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : ";
  *p << op.addr()->getType();
}
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
// Parses the custom form of `llvm.store`. The single trailing type is the type
// of the address operand; the stored value must have its pointee type.
// Operands are added in the order value, address.
static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType valueOperand, addressOperand;
  if (parser->parseOperand(valueOperand) || parser->parseComma() ||
      parser->parseOperand(addressOperand))
    return failure();

  SmallVector<NamedAttribute, 4> parsedAttrs;
  Type addressType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseOptionalAttributeDict(parsedAttrs) ||
      parser->parseColon() ||
      parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(addressType))
    return failure();

  // A null pointee type means a diagnostic has already been emitted.
  Type valueType =
      getLoadStoreElementType(parser, addressType, trailingTypeLoc);
  if (!valueType)
    return failure();

  if (parser->resolveOperand(valueOperand, valueType, result->operands))
    return failure();
  if (parser->resolveOperand(addressOperand, addressType, result->operands))
    return failure();

  result->attributes = parsedAttrs;
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::BitcastOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.bitcast` as: argument attribute-dict? `:` source-type `to`
// result-type.
static void printBitcastOp(OpAsmPrinter *p, BitcastOp &op) {
  *p << op.getOperationName();
  *p << ' ' << *op.arg();
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : " << op.arg()->getType();
  *p << " to " << op.getType();
}
// <operation> ::= `llvm.bitcast` ssa-use attribute-dict? `:` type `to` type
// Parses the custom form of `llvm.bitcast`: the type before `to` is the type
// of the argument, the type after it is the result type.
static ParseResult parseBitcastOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType argument;
  SmallVector<NamedAttribute, 4> parsedAttrs;
  if (parser->parseOperand(argument) ||
      parser->parseOptionalAttributeDict(parsedAttrs))
    return failure();

  Type sourceType, destinationType;
  if (parser->parseColonType(sourceType) || parser->parseKeyword("to") ||
      parser->parseType(destinationType))
    return failure();

  if (parser->resolveOperand(argument, sourceType, result->operands))
    return failure();

  result->attributes = parsedAttrs;
  result->addTypes(destinationType);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CallOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.call` as: callee `(` arguments `)` attribute-dict? `:`
// function-type. The callee is `@name` for a direct call (taken from the
// `callee` attribute) or the first operand for an indirect call.
static void printCallOp(OpAsmPrinter *p, CallOp &op) {
  auto callee = op.callee();
  const bool isDirect = callee.hasValue();
  // An indirect call carries its callee as operand #0; that operand is not
  // part of the argument list.
  const int numCalleeOperands = isDirect ? 0 : 1;

  *p << op.getOperationName() << ' ';
  if (isDirect)
    *p << '@' << callee.getValue();
  else
    *p << *op.getOperand(0);

  *p << '(';
  p->printOperands(llvm::drop_begin(op.getOperands(), numCalleeOperands));
  *p << ')';
  p->printOptionalAttrDict(op.getAttrs(), {"callee"});

  // Rebuild the MLIR function type from the argument and result types.
  SmallVector<Type, 8> argTypes;
  argTypes.reserve(op.getNumOperands());
  for (auto *argument :
       llvm::drop_begin(op.getOperands(), numCalleeOperands))
    argTypes.push_back(argument->getType());
  SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
  *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
}
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
// attribute-dict? `:` function-type
// Parses the custom form of `llvm.call`. A direct call names its callee with
// a function identifier, which is stored in the `callee` attribute. An
// indirect call takes the callee as its first operand, whose type must be the
// LLVM pointer-to-function type built from the trailing function type.
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
SmallVector<OpAsmParser::OperandType, 8> operands;
Type type;
StringRef calleeName;
llvm::SMLoc calleeLoc, trailingTypeLoc;
// Parse an operand list that will, in practice, contain 0 or 1 operand. In
// case of an indirect call, there will be 1 operand before `(`. In case of a
// direct call, there will be no operands and the parser will stop at the
// function identifier without complaining.
if (parser->parseOperandList(operands))
return failure();
bool isDirect = operands.empty();
// Optionally parse a function identifier.
if (isDirect)
if (parser->parseFunctionName(calleeName, calleeLoc))
return failure();
// The parenthesized arguments are appended to `operands`, so for an indirect
// call the callee stays at index 0.
if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return failure();
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType)
return parser->emitError(trailingTypeLoc, "expected function type");
if (isDirect) {
// Add the direct callee as an Op attribute.
auto funcAttr = parser->getBuilder().getFunctionAttr(calleeName);
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
// Make sure types match.
if (parser->resolveOperands(operands, funcType.getInputs(),
parser->getNameLoc(), result->operands))
return failure();
result->addTypes(funcType.getResults());
} else {
// Construct the LLVM IR Dialect function type that the first operand
// should match.
if (funcType.getNumResults() > 1)
return parser->emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
Builder &builder = parser->getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
llvm::Type *llvmResultType;
Type wrappedResultType;
// A function type without results maps to an LLVM function returning void.
if (funcType.getNumResults() == 0) {
llvmResultType = llvm::Type::getVoidTy(llvmDialect->getLLVMContext());
wrappedResultType = builder.getType<LLVM::LLVMType>(llvmResultType);
} else {
wrappedResultType = funcType.getResult(0);
auto wrappedLLVMResultType = wrappedResultType.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMResultType)
return parser->emitError(trailingTypeLoc,
"expected result to have LLVM type");
llvmResultType = wrappedLLVMResultType.getUnderlyingType();
}
// Unwrap every input type; all of them must be LLVM dialect types.
SmallVector<llvm::Type *, 8> argTypes;
argTypes.reserve(funcType.getNumInputs());
for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
if (!argType)
return parser->emitError(trailingTypeLoc,
"expected LLVM types as inputs");
argTypes.push_back(argType.getUnderlyingType());
}
// The callee operand is expected to be a pointer to the (non-variadic) LLVM
// function type assembled above.
auto *llvmFuncType = llvm::FunctionType::get(llvmResultType, argTypes,
/*isVarArg=*/false);
auto wrappedFuncType =
builder.getType<LLVM::LLVMType>(llvmFuncType->getPointerTo());
auto funcArguments =
ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
// Make sure that the first operand (indirect callee) matches the wrapped
// LLVM IR function type, and that the types of the other call operands
// match the types of the function arguments.
if (parser->resolveOperand(operands[0], wrappedFuncType,
result->operands) ||
parser->resolveOperands(funcArguments, funcType.getInputs(),
parser->getNameLoc(), result->operands))
return failure();
// Note that the wrapped void type is added as a result type when the
// trailing function type has no results.
result->addTypes(wrappedResultType);
}
result->attributes = attrs;
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ExtractValueOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.extractvalue` as: container position attribute-dict? `:`
// container-type. The position attribute is printed right after the container
// and therefore elided from the attribute dictionary.
static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) {
  *p << op.getOperationName() << ' ';
  *p << *op.container() << op.position();
  p->printOptionalAttrDict(op.getAttrs(), {"position"});
  *p << " : ";
  *p << op.container()->getType();
}
// Extract the type at `position` in the wrapped LLVM IR aggregate type
// `containerType`. Position is an integer array attribute where each value
// is a zero-based position of the element in the aggregate type. Return the
// resulting type wrapped in MLIR, or nullptr on error. Diagnostics about the
// position are reported at `attributeLoc`, diagnostics about the container
// type at `typeLoc`.
static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
                                                       Type containerType,
                                                       Attribute positionAttr,
                                                       llvm::SMLoc attributeLoc,
                                                       llvm::SMLoc typeLoc) {
  auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
  if (!wrappedContainerType)
    return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;

  auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>();
  if (!positionArrayAttr)
    return parser->emitError(attributeLoc, "expected an array attribute"),
           nullptr;

  // Infer the element type from the structure type: iteratively step inside the
  // type by taking the element type, indexed by the position attribute for
  // structures. Check the position index before accessing, it is supposed to be
  // in bounds.
  llvm::Type *llvmContainerType = wrappedContainerType.getUnderlyingType();
  for (Attribute subAttr : positionArrayAttr) {
    auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
    if (!positionElementAttr)
      return parser->emitError(attributeLoc,
                               "expected an array of integer literals"),
             nullptr;

    // Keep the full 64-bit value: narrowing it to `int` could wrap an
    // out-of-range position back into bounds and defeat the checks below.
    const int64_t position = positionElementAttr.getInt();
    if (llvmContainerType->isArrayTy()) {
      if (position < 0 || static_cast<uint64_t>(position) >=
                              llvmContainerType->getArrayNumElements())
        return parser->emitError(attributeLoc, "position out of bounds"),
               nullptr;
      llvmContainerType = llvmContainerType->getArrayElementType();
    } else if (llvmContainerType->isStructTy()) {
      if (position < 0 || static_cast<uint64_t>(position) >=
                              llvmContainerType->getStructNumElements())
        return parser->emitError(attributeLoc, "position out of bounds"),
               nullptr;
      // In bounds of an `unsigned` element count, so the cast is lossless.
      llvmContainerType = llvmContainerType->getStructElementType(
          static_cast<unsigned>(position));
    } else {
      return parser->emitError(typeLoc,
                               "expected wrapped LLVM IR structure/array type"),
             nullptr;
    }
  }

  Builder &builder = parser->getBuilder();
  return builder.getType<LLVM::LLVMType>(llvmContainerType);
}
// <operation> ::= `llvm.extractvalue` ssa-use
// `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type
// Parses the custom form of `llvm.extractvalue`. The trailing type is the type
// of the container; the result type is the type of the element found at
// `position` inside it.
static ParseResult parseExtractValueOp(OpAsmParser *parser,
                                       OperationState *result) {
  OpAsmParser::OperandType containerOperand;
  if (parser->parseOperand(containerOperand))
    return failure();

  // The position is parsed as an attribute directly following the container.
  SmallVector<NamedAttribute, 4> parsedAttrs;
  Attribute positionAttr;
  llvm::SMLoc positionLoc;
  if (parser->getCurrentLocation(&positionLoc) ||
      parser->parseAttribute(positionAttr, "position", parsedAttrs) ||
      parser->parseOptionalAttributeDict(parsedAttrs))
    return failure();

  Type containerType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(containerType))
    return failure();
  if (parser->resolveOperand(containerOperand, containerType,
                             result->operands))
    return failure();

  // A null element type means a diagnostic has already been emitted.
  auto elementType = getInsertExtractValueElementType(
      parser, containerType, positionAttr, positionLoc, trailingTypeLoc);
  if (!elementType)
    return failure();

  result->attributes = parsedAttrs;
  result->addTypes(elementType);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::InsertValueOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.insertvalue` as: value `,` container position attribute-dict?
// `:` container-type. The position attribute is printed right after the
// container and therefore elided from the attribute dictionary.
static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
  *p << op.getOperationName() << ' ';
  *p << *op.value() << ", " << *op.container();
  *p << op.position();
  p->printOptionalAttrDict(op.getAttrs(), {"position"});
  *p << " : " << op.container()->getType();
}
// <operation> ::= `llvm.insertvalue` ssa-use `,` ssa-use
// `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type
// Parses the custom form of `llvm.insertvalue`. The trailing type is the type
// of the container and of the result; the inserted value must have the type
// of the element found at `position` inside the container. Operands are added
// in the order container, value.
static ParseResult parseInsertValueOp(OpAsmParser *parser,
                                      OperationState *result) {
  OpAsmParser::OperandType valueOperand, containerOperand;
  if (parser->parseOperand(valueOperand) || parser->parseComma() ||
      parser->parseOperand(containerOperand))
    return failure();

  // The position is parsed as an attribute directly following the container.
  Attribute positionAttr;
  llvm::SMLoc positionLoc;
  if (parser->getCurrentLocation(&positionLoc) ||
      parser->parseAttribute(positionAttr, "position", result->attributes) ||
      parser->parseOptionalAttributeDict(result->attributes))
    return failure();

  Type containerType;
  llvm::SMLoc trailingTypeLoc;
  if (parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
      parser->parseType(containerType))
    return failure();

  // A null element type means a diagnostic has already been emitted.
  auto elementType = getInsertExtractValueElementType(
      parser, containerType, positionAttr, positionLoc, trailingTypeLoc);
  if (!elementType)
    return failure();

  if (parser->resolveOperand(containerOperand, containerType,
                             result->operands))
    return failure();
  if (parser->resolveOperand(valueOperand, elementType, result->operands))
    return failure();

  result->addTypes(containerType);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::SelectOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.select` as: condition `,` true-value `,` false-value
// attribute-dict? `:` condition-type `,` true-value-type.
static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
  *p << op.getOperationName() << ' ';
  *p << *op.condition() << ", " << *op.trueValue() << ", "
     << *op.falseValue();
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : " << op.condition()->getType();
  *p << ", " << op.trueValue()->getType();
}
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
// attribute-dict? `:` type, type
// Parses the custom form of `llvm.select`. The first trailing type is the type
// of the condition; the second is shared by both selected values and the
// result.
static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType conditionOperand, trueOperand, falseOperand;
  if (parser->parseOperand(conditionOperand) || parser->parseComma() ||
      parser->parseOperand(trueOperand) || parser->parseComma() ||
      parser->parseOperand(falseOperand))
    return failure();

  Type conditionType, valueType;
  if (parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(conditionType) || parser->parseComma() ||
      parser->parseType(valueType))
    return failure();

  if (parser->resolveOperand(conditionOperand, conditionType,
                             result->operands))
    return failure();
  if (parser->resolveOperand(trueOperand, valueType, result->operands))
    return failure();
  if (parser->resolveOperand(falseOperand, valueType, result->operands))
    return failure();

  result->addTypes(valueType);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::BrOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.br` as: successor #0 with its use list, then the attribute
// dictionary.
static void printBrOp(OpAsmPrinter *p, BrOp &op) {
  *p << op.getOperationName();
  *p << ' ';
  p->printSuccessorAndUseList(op.getOperation(), /*index=*/0);
  p->printOptionalAttrDict(op.getAttrs());
}
// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
// attribute-dict?
// Parses the custom form of `llvm.br`: a single successor with its operands,
// followed by an optional attribute dictionary.
static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) {
  Block *destination = nullptr;
  SmallVector<Value *, 4> destinationOperands;
  if (parser->parseSuccessorAndUseList(destination, destinationOperands))
    return failure();
  if (parser->parseOptionalAttributeDict(result->attributes))
    return failure();

  result->addSuccessor(destination, destinationOperands);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CondBrOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.cond_br` as: condition `,` successor #0 `,` successor #1
// attribute-dict?, each successor with its use list.
static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
  Operation *operation = op.getOperation();
  *p << op.getOperationName() << ' ' << *op.getOperand(0);
  *p << ", ";
  p->printSuccessorAndUseList(operation, 0);
  *p << ", ";
  p->printSuccessorAndUseList(operation, 1);
  p->printOptionalAttrDict(op.getAttrs());
}
// <operation> ::= `llvm.cond_br` ssa-use `,`
// bb-id (`[` ssa-use-and-type-list `]`)? `,`
// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
// Parses the custom form of `llvm.cond_br`. The condition carries no explicit
// type in the syntax and is resolved against the wrapped LLVM i1 type.
// Successors are added in the order true, false.
static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
  // Build the wrapped i1 type the condition is required to have.
  Builder &builder = parser->getBuilder();
  auto *llvmDialect =
      builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
  llvm::Type *llvmI1Type =
      llvm::Type::getInt1Ty(llvmDialect->getLLVMContext());
  auto conditionType = builder.getType<LLVM::LLVMType>(llvmI1Type);

  OpAsmParser::OperandType conditionOperand;
  if (parser->parseOperand(conditionOperand) || parser->parseComma())
    return failure();

  Block *trueBlock = nullptr;
  Block *falseBlock = nullptr;
  SmallVector<Value *, 4> trueBlockOperands;
  SmallVector<Value *, 4> falseBlockOperands;
  if (parser->parseSuccessorAndUseList(trueBlock, trueBlockOperands) ||
      parser->parseComma() ||
      parser->parseSuccessorAndUseList(falseBlock, falseBlockOperands))
    return failure();

  if (parser->parseOptionalAttributeDict(result->attributes) ||
      parser->resolveOperand(conditionOperand, conditionType,
                             result->operands))
    return failure();

  result->addSuccessor(trueBlock, trueBlockOperands);
  result->addSuccessor(falseBlock, falseBlockOperands);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ReturnOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.return` as: operand? attribute-dict? (`:` operand-type)?.
// The operand is printed before the attribute dictionary because the parser
// reads the operand list first; with the opposite order, a return that has
// both an operand and attributes could not be parsed back.
static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
  assert(op.getNumOperands() <= 1);
  const bool hasOperand = op.getNumOperands() != 0;

  *p << op.getOperationName();
  if (hasOperand)
    *p << ' ' << *op.getOperand(0);
  p->printOptionalAttrDict(op.getAttrs());
  // The trailing type is only present when there is an operand to type.
  if (hasOperand)
    *p << " : " << op.getOperand(0)->getType();
}
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
// type-list-no-parens
// Parses the custom form of `llvm.return`. With no operand, nothing follows
// the optional attribute dictionary. With one operand, a trailing `:` type
// gives the operand type. More than one operand is rejected with a
// diagnostic instead of being silently dropped.
static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 1> operands;
  Type type;
  llvm::SMLoc operandsLoc;
  if (parser->getCurrentLocation(&operandsLoc) ||
      parser->parseOperandList(operands) ||
      parser->parseOptionalAttributeDict(result->attributes))
    return failure();
  if (operands.empty())
    return success();
  // Only a single returned value can be resolved below.
  if (operands.size() > 1)
    return parser->emitError(operandsLoc, "expected at most one operand");
  if (parser->parseColonType(type) ||
      parser->resolveOperand(operands[0], type, result->operands))
    return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::UndefOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.undef` as: attribute-dict? `:` result-type.
static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
  *p << op.getOperationName();
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : ";
  *p << op.res()->getType();
}
// <operation> ::= `llvm.undef` attribute-dict? : type
// Parses the custom form of `llvm.undef`: an optional attribute dictionary
// followed by `:` and the result type.
static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
  if (parser->parseOptionalAttributeDict(result->attributes))
    return failure();

  Type resultType;
  if (parser->parseColonType(resultType))
    return failure();

  result->addTypes(resultType);
  return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
// Prints `llvm.constant` as: `(` value (`:` attribute-type)? `)`
// attribute-dict? `:` result-type. The value attribute is printed inside the
// parentheses and therefore elided from the attribute dictionary.
static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
  auto value = op.value();
  *p << op.getOperationName() << '(' << value;

  // Print attribute types other than i64 and f64 because attribute parsing will
  // assume those in absence of explicit attribute type.
  if (auto integerValue = value.dyn_cast<IntegerAttr>()) {
    auto integerType = integerValue.getType();
    if (!integerType.isInteger(64))
      *p << " : " << integerType;
  } else if (auto floatValue = value.dyn_cast<FloatAttr>()) {
    auto floatType = floatValue.getType();
    if (!floatType.isF64())
      *p << " : " << floatType;
  }

  *p << ')';
  p->printOptionalAttrDict(op.getAttrs(), {"value"});
  *p << " : " << op.res()->getType();
}
// <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type
// Parses the custom form of `llvm.constant`: the parenthesized value is stored
// in the `value` attribute and the trailing type becomes the result type.
static ParseResult parseConstantOp(OpAsmParser *parser,
                                   OperationState *result) {
  Attribute constantValue;
  if (parser->parseLParen() ||
      parser->parseAttribute(constantValue, "value", result->attributes) ||
      parser->parseRParen())
    return failure();

  Type resultType;
  if (parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(resultType))
    return failure();

  result->addTypes(resultType);
  return success();
}
//===----------------------------------------------------------------------===//
// LLVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
// Constructs the dialect under the namespace returned by
// getDialectNamespace() and creates the LLVM module "LLVMDialectModule" in
// `llvmContext`. Registers LLVMType and the operations listed in the
// generated LLVMOps.cpp.inc.
// NOTE(review): `module` is initialized from `llvmContext`, so `llvmContext`
// presumably has to be declared before `module` in the class - confirm in the
// header.
LLVMDialect::LLVMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context),
module("LLVMDialectModule", llvmContext) {
addTypes<LLVMType>();
// GET_OP_LIST makes the generated include expand to the list of operation
// classes used as template arguments here.
addOperations<
#define GET_OP_LIST
#include "mlir/LLVMIR/LLVMOps.cpp.inc"
>();
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
}
#define GET_OP_CLASSES
#include "mlir/LLVMIR/LLVMOps.cpp.inc"
/// Parse a type registered to this dialect. `tyData` is handed to the LLVM
/// assembly type parser together with the dialect's LLVM module. On failure
/// the parser's message is emitted as an error at `loc` and a null type is
/// returned.
Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
  llvm::SMDiagnostic diagnostic;
  if (llvm::Type *parsed = llvm::parseType(tyData, diagnostic, module))
    return LLVMType::get(getContext(), parsed);

  getContext()->emitError(loc, diagnostic.getMessage());
  return nullptr;
}
/// Print a type registered to this dialect by delegating to the printer of
/// the wrapped LLVM IR type. `type` must be an LLVMType with a non-null
/// underlying type.
void LLVMDialect::printType(Type type, raw_ostream &os) const {
  auto wrapped = type.dyn_cast<LLVMType>();
  assert(wrapped && "printing wrong type");
  llvm::Type *underlying = wrapped.getUnderlyingType();
  assert(underlying && "no underlying LLVM type");
  underlying->print(os);
}
/// Verify LLVMIR function argument attributes. The only attribute checked is
/// `llvm.noalias`, which must be a boolean attribute; every other attribute
/// is accepted.
LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
                                                      unsigned argIdx,
                                                      NamedAttribute argAttr) {
  const bool isNoAlias = argAttr.first == "llvm.noalias";
  if (!isNoAlias)
    return success();

  if (argAttr.second.isa<BoolAttr>())
    return success();
  return func->emitError()
         << "llvm.noalias argument attribute of non boolean type";
}
// File-local registration object for LLVMDialect.
// NOTE(review): presumably the DialectRegistration constructor registers the
// dialect during static initialization - confirm against its definition.
static DialectRegistration<LLVMDialect> llvmDialect;