//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/FunctionImplementation.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::ml_program;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Custom asm helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse and print an ordering clause for a variadic of consuming tokens
|
|
/// and an producing token.
|
|
///
|
|
/// Syntax:
|
|
/// ordering(%0, %1 -> !ml_program.token)
|
|
/// ordering(() -> !ml_program.token)
|
|
///
|
|
/// If both the consuming and producing token are not present on the op, then
|
|
/// the clause prints nothing.
|
|
static ParseResult parseTokenOrdering(
|
|
OpAsmParser &parser,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
|
|
Type &produceTokenType) {
|
|
if (failed(parser.parseOptionalKeyword("ordering")) ||
|
|
failed(parser.parseLParen()))
|
|
return success();
|
|
|
|
// Parse consuming token list. If there are no consuming tokens, the
|
|
// '()' null list represents this.
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
} else {
|
|
if (failed(parser.parseOperandList(consumeTokens,
|
|
/*requiredOperandCount=*/-1)))
|
|
return failure();
|
|
}
|
|
|
|
// Parse producer token.
|
|
if (failed(parser.parseArrow()))
|
|
return failure();
|
|
if (failed(parser.parseType(produceTokenType)))
|
|
return failure();
|
|
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Prints the ordering clause: ` ordering(<consumed tokens or ()> -> <type>)`.
///
/// Nothing is printed when there are no consumed tokens and no produced token
/// type. The `-> <type>` part is only emitted when `produceTokenType` is set.
///
/// NOTE(review): with consumed tokens but no produced token type this emits
/// `ordering(%0)` without an arrow, while parseTokenOrdering requires the
/// arrow and a type -- confirm whether that combination can occur.
static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
                               OperandRange consumeTokens,
                               Type produceTokenType) {
  const bool hasConsumedTokens = !consumeTokens.empty();
  if (!hasConsumedTokens && !produceTokenType)
    return;

  p << " ordering(";

  // An empty consumer list is spelled as the explicit null list '()'.
  if (hasConsumedTokens)
    p.printOperands(consumeTokens);
  else
    p << "()";

  if (produceTokenType) {
    p << " -> ";
    p.printType(produceTokenType);
  }

  p << ")";
}
|
|
|
|
/// some.op custom<TypeOrAttr>($type, $attr)
|
|
///
|
|
/// Uninitialized:
|
|
/// some.op : tensor<3xi32>
|
|
/// Initialized to narrower type than op:
|
|
/// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
|
|
static ParseResult parseTypedInitialValue(OpAsmParser &parser,
|
|
TypeAttr &typeAttr, Attribute &attr) {
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (failed(parser.parseAttribute(attr)))
|
|
return failure();
|
|
if (failed(parser.parseRParen()))
|
|
return failure();
|
|
}
|
|
|
|
Type type;
|
|
if (failed(parser.parseColonType(type)))
|
|
return failure();
|
|
typeAttr = TypeAttr::get(type);
|
|
return success();
|
|
}
|
|
|
|
/// Prints `(<attr>)` when an initial value attribute is present, followed
/// unconditionally by ` : <type>`.
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
                                   TypeAttr type, Attribute attr) {
  // The initial value is optional; a null attribute prints nothing here.
  if (attr) {
    p << "(";
    p.printAttribute(attr);
    p << ")";
  }

  // The type is always printed, through its TypeAttr wrapper.
  p << " : ";
  p.printAttribute(type);
}
|
|
|
|
/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
|
|
/// ->
|
|
/// some.op public @foo
|
|
/// some.op private @foo
|
|
static ParseResult parseSymbolVisibility(OpAsmParser &parser,
|
|
StringAttr &symVisibilityAttr) {
|
|
StringRef symVisibility;
|
|
(void)parser.parseOptionalKeyword(&symVisibility,
|
|
{"public", "private", "nested"});
|
|
if (symVisibility.empty())
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "expected 'public', 'private', or 'nested'";
|
|
if (!symVisibility.empty())
|
|
symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
|
|
return success();
|
|
}
|
|
|
|
/// Prints the symbol visibility keyword. A null attribute is printed as
/// `public`; otherwise the attribute's string value is printed verbatim.
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
                                  StringAttr symVisibilityAttr) {
  if (symVisibilityAttr) {
    p << symVisibilityAttr.getValue();
    return;
  }

  // No explicit visibility attribute on the op: spell out "public".
  p << "public";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FuncOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a FuncOp by delegating to the shared function-op parser with
/// variadic arguments disallowed.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Callback handed to the shared parser to assemble the function type from
  // the parsed argument and result types. The variadic flag and the error
  // string are unused here.
  auto makeFunctionType = [](Builder &b, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return b.getFunctionType(inputs, outputs);
  };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeFunctionType);
}
|
|
|
|
/// Prints this FuncOp through the shared function-op printer, as a
/// non-variadic function.
void FuncOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that an immutable global carries an initial value. Mutable
/// globals pass with or without one.
LogicalResult GlobalOp::verify() {
  // Either being mutable or having a value is enough to be well-formed.
  if (getIsMutable() || getValue())
    return success();

  return emitOpError() << "immutable global must have an initial value";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Looks up the GlobalOp named by this op's `global` attribute, searching
/// from the op's parent operation. Callers in this file treat a null result
/// as an undefined global.
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupRoot = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupRoot,
                                                       getGlobalAttr());
}
|
|
|
|
/// Verifies that the referenced symbol resolves to a GlobalOp and that the
/// global's type equals this op's result type.
LogicalResult
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  // The load must produce exactly the type the global was declared with.
  if (global.getType() == getResult().getType())
    return success();

  return emitOpError() << "cannot load from global typed " << global.getType()
                       << " as " << getResult().getType();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadConstOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Looks up the GlobalOp named by this op's `global` attribute, searching
/// from the op's parent operation. Callers in this file treat a null result
/// as an undefined global.
GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupRoot = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupRoot,
                                                       getGlobalAttr());
}
|
|
|
|
/// Verifies that the referenced symbol resolves to a GlobalOp, that the
/// global is not mutable, and that its type equals this op's result type.
/// The checks run in that order and the first failure is reported.
LogicalResult
GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  // A const load is only permitted from an immutable global.
  if (global.getIsMutable()) {
    return emitOpError() << "cannot load as const from mutable global "
                         << getGlobal();
  }

  if (global.getType() != getResult().getType()) {
    return emitOpError() << "cannot load from global typed "
                         << global.getType() << " as "
                         << getResult().getType();
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalLoadGraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Looks up the GlobalOp named by this op's `global` attribute, searching
/// from the op's parent operation. Callers in this file treat a null result
/// as an undefined global.
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupRoot = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupRoot,
                                                       getGlobalAttr());
}
|
|
|
|
/// Verifies that the referenced symbol resolves to a GlobalOp and that the
/// global's type equals the type of this op's result.
LogicalResult
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  // The load must produce exactly the type the global was declared with.
  if (global.getType() == getResult().getType())
    return success();

  return emitOpError() << "cannot load from global typed " << global.getType()
                       << " as " << getResult().getType();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalStoreOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Looks up the GlobalOp named by this op's `global` attribute, searching
/// from the op's parent operation. Callers in this file treat a null result
/// as an undefined global.
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupRoot = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupRoot,
                                                       getGlobalAttr());
}
|
|
|
|
/// Verifies that the referenced symbol resolves to a GlobalOp, that the
/// global is mutable, and that its type equals the type of the stored value.
/// The checks run in that order and the first failure is reported.
LogicalResult
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  // Stores are only permitted to mutable globals.
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  // The stored value must have exactly the global's declared type.
  if (global.getType() == getValue().getType())
    return success();

  return emitOpError() << "cannot store to a global typed " << global.getType()
                       << " from " << getValue().getType();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GlobalStoreGraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Looks up the GlobalOp named by this op's `global` attribute, searching
/// from the op's parent operation. Callers in this file treat a null result
/// as an undefined global.
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  Operation *lookupRoot = getOperation()->getParentOp();
  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupRoot,
                                                       getGlobalAttr());
}
|
|
|
|
/// Verifies that the referenced symbol resolves to a GlobalOp, that the
/// global is mutable, and that its type equals the type of the stored value.
/// The checks run in that order and the first failure is reported.
LogicalResult
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  GlobalOp global = getGlobalOp(symbolTable);
  if (!global)
    return emitOpError() << "undefined global: " << getGlobal();

  // Stores are only permitted to mutable globals.
  if (!global.getIsMutable())
    return emitOpError() << "cannot store to an immutable global "
                         << getGlobal();

  // The stored value must have exactly the global's declared type.
  if (global.getType() == getValue().getType())
    return success();

  return emitOpError() << "cannot store to a global typed " << global.getType()
                       << " from " << getValue().getType();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubgraphOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses a SubgraphOp by delegating to the shared function-op parser with
/// variadic arguments disallowed.
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
  // Callback handed to the shared parser to assemble the function type from
  // the parsed argument and result types. The variadic flag and the error
  // string are unused here.
  auto makeFunctionType = [](Builder &b, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return b.getFunctionType(inputs, outputs);
  };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, makeFunctionType);
}
|
|
|
|
/// Prints this SubgraphOp through the shared function-op printer, as a
/// non-variadic function.
void SubgraphOp::print(OpAsmPrinter &p) {
  const bool isVariadic = false;
  function_interface_impl::printFunctionOp(p, *this, isVariadic);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OutputOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that this op's operands match the result signature of the
/// enclosing SubgraphOp, both in count and in per-position type.
LogicalResult OutputOp::verify() {
  // `cast` asserts that the parent operation is a SubgraphOp.
  auto subgraph = cast<SubgraphOp>((*this)->getParentOp());

  // The operand number and types must match the function signature.
  const auto &resultTypes = subgraph.getFunctionType().getResults();
  if (getNumOperands() != resultTypes.size()) {
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << subgraph.getName() << ") outputs " << resultTypes.size();
  }

  // Report the first operand whose type differs from the declared result.
  for (unsigned idx = 0, count = resultTypes.size(); idx != count; ++idx) {
    if (getOperand(idx).getType() == resultTypes[idx])
      continue;

    return emitError() << "type of output operand " << idx << " ("
                       << getOperand(idx).getType()
                       << ") doesn't match function result type ("
                       << resultTypes[idx] << ")"
                       << " in function @" << subgraph.getName();
  }

  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReturnOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verifies that this op's operands match the result signature of the
/// enclosing FuncOp, both in count and in per-position type.
LogicalResult ReturnOp::verify() {
  // `cast` asserts that the parent operation is a FuncOp.
  auto enclosingFunc = cast<FuncOp>((*this)->getParentOp());

  // The operand number and types must match the function signature.
  const auto &resultTypes = enclosingFunc.getFunctionType().getResults();
  if (getNumOperands() != resultTypes.size()) {
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing function (@"
           << enclosingFunc.getName() << ") returns " << resultTypes.size();
  }

  // Report the first operand whose type differs from the declared result.
  for (unsigned idx = 0, count = resultTypes.size(); idx != count; ++idx) {
    if (getOperand(idx).getType() == resultTypes[idx])
      continue;

    return emitError() << "type of return operand " << idx << " ("
                       << getOperand(idx).getType()
                       << ") doesn't match function result type ("
                       << resultTypes[idx] << ")"
                       << " in function @" << enclosingFunc.getName();
  }

  return success();
}
|