NFC: Make the ModuleState field in the ModulePrinter optional.
The ModuleState is only used for printing aliases, which is only done when printing the top-level module. PiperOrigin-RevId: 264664138
This commit is contained in:
parent
b9dc2e4818
commit
2e59b86541
|
@ -85,13 +85,7 @@ static constexpr int kNonAttrKindAlias = -1;
|
|||
|
||||
class ModuleState {
|
||||
public:
|
||||
/// This is the current context if it is knowable, otherwise this is null.
|
||||
MLIRContext *const context;
|
||||
|
||||
explicit ModuleState(MLIRContext *context)
|
||||
: context(context), interfaces(context) {}
|
||||
|
||||
// Initializes module state, populating affine map state.
|
||||
explicit ModuleState(MLIRContext *context) : interfaces(context) {}
|
||||
void initialize(Operation *op);
|
||||
|
||||
Twine getAttributeAlias(Attribute attr) const {
|
||||
|
@ -308,7 +302,6 @@ void ModuleState::initializeSymbolAliases() {
|
|||
typeToAlias.insert(typeAliasPair);
|
||||
}
|
||||
|
||||
// Initializes module state, populating affine map and integer set state.
|
||||
void ModuleState::initialize(Operation *op) {
|
||||
// Initialize the symbol aliases.
|
||||
initializeSymbolAliases();
|
||||
|
@ -324,7 +317,8 @@ void ModuleState::initialize(Operation *op) {
|
|||
namespace {
|
||||
class ModulePrinter {
|
||||
public:
|
||||
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
|
||||
ModulePrinter(raw_ostream &os, ModuleState *state = nullptr)
|
||||
: os(os), state(state) {}
|
||||
explicit ModulePrinter(ModulePrinter &printer)
|
||||
: os(printer.os), state(printer.state) {}
|
||||
|
||||
|
@ -351,9 +345,6 @@ public:
|
|||
void printIntegerSet(IntegerSet set);
|
||||
|
||||
protected:
|
||||
raw_ostream &os;
|
||||
ModuleState &state;
|
||||
|
||||
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<StringRef> elidedAttrs = {});
|
||||
void printTrailingLocation(Location loc);
|
||||
|
@ -370,6 +361,12 @@ protected:
|
|||
void printAffineExprInternal(
|
||||
AffineExpr expr, BindingStrength enclosingTightness,
|
||||
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
|
||||
|
||||
/// The output stream for the printer.
|
||||
raw_ostream &os;
|
||||
|
||||
/// An optional printer state for the module.
|
||||
ModuleState *state;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -593,10 +590,12 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
|
|||
}
|
||||
|
||||
// Check for an alias for this attribute.
|
||||
Twine alias = state.getAttributeAlias(attr);
|
||||
if (!alias.isTriviallyEmpty()) {
|
||||
os << '#' << alias;
|
||||
return;
|
||||
if (state) {
|
||||
Twine alias = state->getAttributeAlias(attr);
|
||||
if (!alias.isTriviallyEmpty()) {
|
||||
os << '#' << alias;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
switch (attr.getKind()) {
|
||||
|
@ -805,10 +804,12 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
|||
|
||||
void ModulePrinter::printType(Type type) {
|
||||
// Check for an alias for this type.
|
||||
StringRef alias = state.getTypeAlias(type);
|
||||
if (!alias.empty()) {
|
||||
os << '!' << alias;
|
||||
return;
|
||||
if (state) {
|
||||
StringRef alias = state->getTypeAlias(type);
|
||||
if (!alias.empty()) {
|
||||
os << '!' << alias;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
switch (type.getKind()) {
|
||||
|
@ -1623,8 +1624,10 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
|
|||
|
||||
void ModulePrinter::print(ModuleOp module) {
|
||||
// Output the aliases at the top level.
|
||||
state.printAttributeAliases(os);
|
||||
state.printTypeAliases(os);
|
||||
if (state) {
|
||||
state->printAttributeAliases(os);
|
||||
state->printTypeAliases(os);
|
||||
}
|
||||
|
||||
// Print the module.
|
||||
OperationPrinter(module, *this).print(module);
|
||||
|
@ -1636,8 +1639,7 @@ void ModulePrinter::print(ModuleOp module) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Attribute::print(raw_ostream &os) const {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printAttribute(*this);
|
||||
ModulePrinter(os).printAttribute(*this);
|
||||
}
|
||||
|
||||
void Attribute::dump() const {
|
||||
|
@ -1645,10 +1647,7 @@ void Attribute::dump() const {
|
|||
llvm::errs() << "\n";
|
||||
}
|
||||
|
||||
void Type::print(raw_ostream &os) {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printType(*this);
|
||||
}
|
||||
void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
|
||||
|
||||
void Type::dump() { print(llvm::errs()); }
|
||||
|
||||
|
@ -1667,8 +1666,7 @@ void AffineExpr::print(raw_ostream &os) const {
|
|||
os << "null affine expr";
|
||||
return;
|
||||
}
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printAffineExpr(*this);
|
||||
ModulePrinter(os).printAffineExpr(*this);
|
||||
}
|
||||
|
||||
void AffineExpr::dump() const {
|
||||
|
@ -1681,13 +1679,11 @@ void AffineMap::print(raw_ostream &os) const {
|
|||
os << "null affine map";
|
||||
return;
|
||||
}
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printAffineMap(*this);
|
||||
ModulePrinter(os).printAffineMap(*this);
|
||||
}
|
||||
|
||||
void IntegerSet::print(raw_ostream &os) const {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).printIntegerSet(*this);
|
||||
ModulePrinter(os).printIntegerSet(*this);
|
||||
}
|
||||
|
||||
void Value::print(raw_ostream &os) {
|
||||
|
@ -1706,8 +1702,7 @@ void Value::dump() { print(llvm::errs()); }
|
|||
void Operation::print(raw_ostream &os) {
|
||||
// Handle top-level operations.
|
||||
if (!getParent()) {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
ModulePrinter modulePrinter(os);
|
||||
OperationPrinter(this, modulePrinter).print(this);
|
||||
return;
|
||||
}
|
||||
|
@ -1722,8 +1717,7 @@ void Operation::print(raw_ostream &os) {
|
|||
while (auto *nextRegion = region->getParentRegion())
|
||||
region = nextRegion;
|
||||
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
ModulePrinter modulePrinter(os);
|
||||
OperationPrinter(region, modulePrinter).print(this);
|
||||
}
|
||||
|
||||
|
@ -1743,8 +1737,7 @@ void Block::print(raw_ostream &os) {
|
|||
while (auto *nextRegion = region->getParentRegion())
|
||||
region = nextRegion;
|
||||
|
||||
ModuleState state(region->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
ModulePrinter modulePrinter(os);
|
||||
OperationPrinter(region, modulePrinter).print(this);
|
||||
}
|
||||
|
||||
|
@ -1762,15 +1755,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
|
|||
while (auto *nextRegion = region->getParentRegion())
|
||||
region = nextRegion;
|
||||
|
||||
ModuleState state(region->getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
ModulePrinter modulePrinter(os);
|
||||
OperationPrinter(region, modulePrinter).printBlockName(this);
|
||||
}
|
||||
|
||||
void ModuleOp::print(raw_ostream &os) {
|
||||
ModuleState state(getContext());
|
||||
state.initialize(*this);
|
||||
ModulePrinter(os, state).print(*this);
|
||||
ModulePrinter(os, &state).print(*this);
|
||||
}
|
||||
|
||||
void ModuleOp::dump() { print(llvm::errs()); }
|
||||
|
|
Loading…
Reference in New Issue