NFC: Make the ModuleState field in the ModulePrinter optional.

The ModuleState is only used for printing aliases, which is only done when printing the top-level module.

PiperOrigin-RevId: 264664138
This commit is contained in:
River Riddle 2019-08-21 12:16:23 -07:00 committed by A. Unique TensorFlower
parent b9dc2e4818
commit 2e59b86541
1 changed files with 35 additions and 43 deletions

View File

@ -85,13 +85,7 @@ static constexpr int kNonAttrKindAlias = -1;
class ModuleState {
public:
/// This is the current context if it is knowable, otherwise this is null.
MLIRContext *const context;
explicit ModuleState(MLIRContext *context)
: context(context), interfaces(context) {}
// Initializes module state, populating affine map state.
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
void initialize(Operation *op);
Twine getAttributeAlias(Attribute attr) const {
@ -308,7 +302,6 @@ void ModuleState::initializeSymbolAliases() {
typeToAlias.insert(typeAliasPair);
}
// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(Operation *op) {
// Initialize the symbol aliases.
initializeSymbolAliases();
@ -324,7 +317,8 @@ void ModuleState::initialize(Operation *op) {
namespace {
class ModulePrinter {
public:
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
ModulePrinter(raw_ostream &os, ModuleState *state = nullptr)
: os(os), state(state) {}
explicit ModulePrinter(ModulePrinter &printer)
: os(printer.os), state(printer.state) {}
@ -351,9 +345,6 @@ public:
void printIntegerSet(IntegerSet set);
protected:
raw_ostream &os;
ModuleState &state;
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {});
void printTrailingLocation(Location loc);
@ -370,6 +361,12 @@ protected:
void printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
/// The output stream for the printer.
raw_ostream &os;
/// An optional printer state for the module.
ModuleState *state;
};
} // end anonymous namespace
@ -593,10 +590,12 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
}
// Check for an alias for this attribute.
Twine alias = state.getAttributeAlias(attr);
if (!alias.isTriviallyEmpty()) {
os << '#' << alias;
return;
if (state) {
Twine alias = state->getAttributeAlias(attr);
if (!alias.isTriviallyEmpty()) {
os << '#' << alias;
return;
}
}
switch (attr.getKind()) {
@ -805,10 +804,12 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
void ModulePrinter::printType(Type type) {
// Check for an alias for this type.
StringRef alias = state.getTypeAlias(type);
if (!alias.empty()) {
os << '!' << alias;
return;
if (state) {
StringRef alias = state->getTypeAlias(type);
if (!alias.empty()) {
os << '!' << alias;
return;
}
}
switch (type.getKind()) {
@ -1623,8 +1624,10 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
void ModulePrinter::print(ModuleOp module) {
// Output the aliases at the top level.
state.printAttributeAliases(os);
state.printTypeAliases(os);
if (state) {
state->printAttributeAliases(os);
state->printTypeAliases(os);
}
// Print the module.
OperationPrinter(module, *this).print(module);
@ -1636,8 +1639,7 @@ void ModulePrinter::print(ModuleOp module) {
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
ModuleState state(getContext());
ModulePrinter(os, state).printAttribute(*this);
ModulePrinter(os).printAttribute(*this);
}
void Attribute::dump() const {
@ -1645,10 +1647,7 @@ void Attribute::dump() const {
llvm::errs() << "\n";
}
void Type::print(raw_ostream &os) {
ModuleState state(getContext());
ModulePrinter(os, state).printType(*this);
}
void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
void Type::dump() { print(llvm::errs()); }
@ -1667,8 +1666,7 @@ void AffineExpr::print(raw_ostream &os) const {
os << "null affine expr";
return;
}
ModuleState state(getContext());
ModulePrinter(os, state).printAffineExpr(*this);
ModulePrinter(os).printAffineExpr(*this);
}
void AffineExpr::dump() const {
@ -1681,13 +1679,11 @@ void AffineMap::print(raw_ostream &os) const {
os << "null affine map";
return;
}
ModuleState state(getContext());
ModulePrinter(os, state).printAffineMap(*this);
ModulePrinter(os).printAffineMap(*this);
}
void IntegerSet::print(raw_ostream &os) const {
ModuleState state(getContext());
ModulePrinter(os, state).printIntegerSet(*this);
ModulePrinter(os).printIntegerSet(*this);
}
void Value::print(raw_ostream &os) {
@ -1706,8 +1702,7 @@ void Value::dump() { print(llvm::errs()); }
void Operation::print(raw_ostream &os) {
// Handle top-level operations.
if (!getParent()) {
ModuleState state(getContext());
ModulePrinter modulePrinter(os, state);
ModulePrinter modulePrinter(os);
OperationPrinter(this, modulePrinter).print(this);
return;
}
@ -1722,8 +1717,7 @@ void Operation::print(raw_ostream &os) {
while (auto *nextRegion = region->getParentRegion())
region = nextRegion;
ModuleState state(getContext());
ModulePrinter modulePrinter(os, state);
ModulePrinter modulePrinter(os);
OperationPrinter(region, modulePrinter).print(this);
}
@ -1743,8 +1737,7 @@ void Block::print(raw_ostream &os) {
while (auto *nextRegion = region->getParentRegion())
region = nextRegion;
ModuleState state(region->getContext());
ModulePrinter modulePrinter(os, state);
ModulePrinter modulePrinter(os);
OperationPrinter(region, modulePrinter).print(this);
}
@ -1762,15 +1755,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
while (auto *nextRegion = region->getParentRegion())
region = nextRegion;
ModuleState state(region->getContext());
ModulePrinter modulePrinter(os, state);
ModulePrinter modulePrinter(os);
OperationPrinter(region, modulePrinter).printBlockName(this);
}
void ModuleOp::print(raw_ostream &os) {
ModuleState state(getContext());
state.initialize(*this);
ModulePrinter(os, state).print(*this);
ModulePrinter(os, &state).print(*this);
}
void ModuleOp::dump() { print(llvm::errs()); }