llvm-project/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp

372 lines
13 KiB
C++

//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
// Custom asm helpers
//===----------------------------------------------------------------------===//
/// Parse an optional ordering clause made of a (possibly empty) list of
/// consumed token operands and an optional produced token type.
///
/// Syntax:
///   ordering(%0, %1 -> !ml_program.token)
///   ordering(() -> !ml_program.token)
///   ordering(%0, %1)
///
/// If the `ordering` keyword is absent, nothing is consumed and
/// `consumeTokens` / `produceTokenType` are left untouched.
///
/// The `-> type` suffix is optional. printTokenOrdering omits it when the op
/// has no produced token, and this parser must accept everything the printer
/// emits so the syntax round-trips.
static ParseResult parseTokenOrdering(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
    Type &produceTokenType) {
  // The whole clause is optional.
  if (failed(parser.parseOptionalKeyword("ordering")))
    return success();

  // Once the keyword has been consumed the clause is committed. A missing '('
  // is a genuine syntax error: parseLParen has already emitted a diagnostic,
  // so the failure must be propagated rather than reported as success.
  if (failed(parser.parseLParen()))
    return failure();

  // Consumed tokens: either the explicit empty list '()' or an operand list.
  if (succeeded(parser.parseOptionalLParen())) {
    if (failed(parser.parseRParen()))
      return failure();
  } else if (failed(parser.parseOperandList(consumeTokens,
                                             /*requiredOperandCount=*/-1))) {
    return failure();
  }

  // Produced token: optional, mirroring printTokenOrdering which only prints
  // ' -> type' when a produced token type exists.
  if (succeeded(parser.parseOptionalArrow()) &&
      failed(parser.parseType(produceTokenType)))
    return failure();

  return parser.parseRParen();
}
/// Print the ` ordering(...)` clause read back by parseTokenOrdering.
///
/// Nothing is printed when the op has neither consumed tokens nor a produced
/// token type. An empty consumed list is spelled `()`. The ` -> type` suffix
/// appears only when a produced token type is present.
static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
                               OperandRange consumeTokens,
                               Type produceTokenType) {
  const bool hasConsumedTokens = !consumeTokens.empty();
  if (!hasConsumedTokens && !produceTokenType)
    return;

  p << " ordering(";
  if (hasConsumedTokens)
    p.printOperands(consumeTokens);
  else
    p << "()";

  if (produceTokenType) {
    p << " -> ";
    p.printType(produceTokenType);
  }
  p << ")";
}
/// Parse `custom<TypedInitialValue>($type, $attr)`: an optional parenthesized
/// initial-value attribute followed by a mandatory `: type`.
///
/// Uninitialized:
///   some.op : tensor<3xi32>
/// Initialized to narrower type than op:
///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
///
/// `attr` is only assigned when the parenthesized value is present; `typeAttr`
/// is always assigned on success.
static ParseResult parseTypedInitialValue(OpAsmParser &parser,
                                          TypeAttr &typeAttr, Attribute &attr) {
  const bool hasInitialValue = succeeded(parser.parseOptionalLParen());
  // Short-circuiting keeps the original order: attribute first, then ')'.
  if (hasInitialValue &&
      (failed(parser.parseAttribute(attr)) || failed(parser.parseRParen())))
    return failure();

  Type declaredType;
  if (failed(parser.parseColonType(declaredType)))
    return failure();
  typeAttr = TypeAttr::get(declaredType);
  return success();
}
/// Print the form accepted by parseTypedInitialValue: `(attr)` when an initial
/// value is present, then ` : ` and the declared type.
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
                                   TypeAttr type, Attribute attr) {
  // Optional initial value, wrapped in parentheses.
  if (attr) {
    p << "(";
    p.printAttribute(attr);
    p << ")";
  }

  // Mandatory declared type.
  p << " : ";
  p.printAttribute(type);
}
/// Parse the mandatory visibility keyword of
/// `custom<SymbolVisibility>($sym_visibility) $sym_name`, e.g.
///   some.op public @foo
///   some.op private @foo
///
/// Exactly one of `public`, `private`, or `nested` is accepted and stored as a
/// StringAttr; anything else is reported as a parse error.
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
                                         StringAttr &symVisibilityAttr) {
  StringRef keyword;
  // The result is deliberately ignored: an unmatched keyword leaves `keyword`
  // empty, which is diagnosed below with a more specific message.
  (void)parser.parseOptionalKeyword(&keyword, {"public", "private", "nested"});
  if (keyword.empty())
    return parser.emitError(parser.getCurrentLocation())
           << "expected 'public', 'private', or 'nested'";

  symVisibilityAttr = parser.getBuilder().getStringAttr(keyword);
  return success();
}
/// Print the visibility keyword read by parseSymbolVisibility. When no
/// visibility attribute is set, `public` is printed.
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
                                  StringAttr symVisibilityAttr) {
  if (symVisibilityAttr)
    p << symVisibilityAttr.getValue();
  else
    p << "public";
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
/// Parse an `ml_program.func` with the shared function-op syntax.
///
/// Variadic signatures are not allowed. The parsed argument and result types
/// are combined into a plain FunctionType.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputTypes,
                             ArrayRef<Type> outputTypes,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(inputTypes, outputTypes);
  };
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeFunctionType);
}

/// Print an `ml_program.func` with the shared function-op syntax; it is
/// never variadic.
void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
/// A global that is not marked mutable must be given an initial value at its
/// definition.
LogicalResult GlobalOp::verify() {
  const bool isImmutable = !getIsMutable();
  if (isImmutable && !getValue())
    return emitOpError() << "immutable global must have an initial value";
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadOp
//===----------------------------------------------------------------------===//
/// Resolve the referenced global via the nearest symbol table, starting the
/// lookup from this op's parent.
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *scope = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
}

/// Check that the referenced global exists and that its declared type equals
/// the type of the loaded result.
LogicalResult
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  Type globalType = global.getType();
  Type resultType = getResult().getType();
  if (globalType != resultType)
    return emitOpError() << "cannot load from global typed " << globalType
                         << " as " << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadConstOp
//===----------------------------------------------------------------------===//
/// Resolve the referenced global via the nearest symbol table, starting the
/// lookup from this op's parent.
GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *scope = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
}

/// Check that the referenced global exists, is not mutable, and has exactly
/// the type of the loaded result.
LogicalResult
GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();
  if (global.getIsMutable())
    return emitOpError() << "cannot load as const from mutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type resultType = getResult().getType();
  if (globalType != resultType)
    return emitOpError() << "cannot load from global typed " << globalType
                         << " as " << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadGraphOp
//===----------------------------------------------------------------------===//
/// Resolve the referenced global via the nearest symbol table, starting the
/// lookup from this op's parent.
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *scope = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
}

/// Check that the referenced global exists and that its declared type equals
/// the type of the loaded result.
LogicalResult
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  Type globalType = global.getType();
  Type resultType = getResult().getType();
  if (globalType != resultType)
    return emitOpError() << "cannot load from global typed " << globalType
                         << " as " << resultType;
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
/// Resolve the referenced global via the nearest symbol table, starting the
/// lookup from this op's parent.
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *scope = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
}

/// Check that the referenced global exists, is mutable, and has exactly the
/// type of the stored value.
LogicalResult
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type storedType = getValue().getType();
  if (globalType != storedType)
    return emitOpError() << "cannot store to a global typed " << globalType
                         << " from " << storedType;
  return success();
}
//===----------------------------------------------------------------------===//
// GlobalStoreGraphOp
//===----------------------------------------------------------------------===//
/// Resolve the referenced global via the nearest symbol table, starting the
/// lookup from this op's parent.
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *scope = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
}

/// Check that the referenced global exists, is mutable, and has exactly the
/// type of the stored value.
LogicalResult
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  Type globalType = global.getType();
  Type storedType = getValue().getType();
  if (globalType != storedType)
    return emitOpError() << "cannot store to a global typed " << globalType
                         << " from " << storedType;
  return success();
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
/// Parse an `ml_program.subgraph` with the shared function-op syntax.
///
/// Variadic signatures are not allowed. The parsed argument and result types
/// are combined into a plain FunctionType.
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputTypes,
                             ArrayRef<Type> outputTypes,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(inputTypes, outputTypes);
  };
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeFunctionType);
}

/// Print an `ml_program.subgraph` with the shared function-op syntax; it is
/// never variadic.
void SubgraphOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// OutputOp
//===----------------------------------------------------------------------===//
/// Check that the operands of this terminator match the result types declared
/// by the enclosing SubgraphOp, both in count and position by position.
LogicalResult OutputOp::verify() {
  auto subgraph = cast<SubgraphOp>((*this)->getParentOp());
  ArrayRef<Type> expectedTypes = subgraph.getFunctionType().getResults();

  unsigned numOperands = getNumOperands();
  if (numOperands != expectedTypes.size())
    return emitOpError("has ")
           << numOperands << " operands, but enclosing function (@"
           << subgraph.getName() << ") outputs " << expectedTypes.size();

  for (unsigned idx = 0; idx < numOperands; ++idx) {
    Type actualType = getOperand(idx).getType();
    if (actualType == expectedTypes[idx])
      continue;
    return emitError() << "type of output operand " << idx << " ("
                       << actualType
                       << ") doesn't match function result type ("
                       << expectedTypes[idx] << ")"
                       << " in function @" << subgraph.getName();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
/// Check that the operands of this terminator match the result types declared
/// by the enclosing FuncOp, both in count and position by position.
LogicalResult ReturnOp::verify() {
  auto func = cast<FuncOp>((*this)->getParentOp());
  ArrayRef<Type> expectedTypes = func.getFunctionType().getResults();

  unsigned numOperands = getNumOperands();
  if (numOperands != expectedTypes.size())
    return emitOpError("has ")
           << numOperands << " operands, but enclosing function (@"
           << func.getName() << ") returns " << expectedTypes.size();

  for (unsigned idx = 0; idx < numOperands; ++idx) {
    Type actualType = getOperand(idx).getType();
    if (actualType == expectedTypes[idx])
      continue;
    return emitError() << "type of return operand " << idx << " ("
                       << actualType
                       << ") doesn't match function result type ("
                       << expectedTypes[idx] << ")"
                       << " in function @" << func.getName();
  }
  return success();
}