Merge SSAValue, CFGValue, and MLValue together into a single Value class, which

is the new base of the SSA value hierarchy.  This CL also standardizes all the
nomenclature and comments to use 'Value' where appropriate.  This also eliminates a large number of cast<MLValue>(x)'s, which is very soothing.

This is step 11/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227064624
This commit is contained in:
Chris Lattner 2018-12-27 14:35:10 -08:00 committed by jpienaar
parent 776b035646
commit 3f190312f8
61 changed files with 959 additions and 1162 deletions

View File

@ -37,9 +37,9 @@ class ForStmt;
class MLIRContext; class MLIRContext;
class FlatAffineConstraints; class FlatAffineConstraints;
class IntegerSet; class IntegerSet;
class MLValue;
class OperationStmt; class OperationStmt;
class Statement; class Statement;
class Value;
/// Simplify an affine expression through flattening and some amount of /// Simplify an affine expression through flattening and some amount of
/// simple analysis. This has complexity linear in the number of nodes in /// simple analysis. This has complexity linear in the number of nodes in
@ -78,7 +78,7 @@ AffineExpr composeWithUnboundedMap(AffineExpr e, AffineMap g);
/// 'affineApplyOps', which are reachable via a search starting from 'operands', /// 'affineApplyOps', which are reachable via a search starting from 'operands',
/// and ending at operands which are not defined by AffineApplyOps. /// and ending at operands which are not defined by AffineApplyOps.
void getReachableAffineApplyOps( void getReachableAffineApplyOps(
llvm::ArrayRef<MLValue *> operands, llvm::ArrayRef<Value *> operands,
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps); llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
/// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the /// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
@ -122,9 +122,9 @@ bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain); FlatAffineConstraints *domain);
struct MemRefAccess { struct MemRefAccess {
const MLValue *memref; const Value *memref;
const OperationStmt *opStmt; const OperationStmt *opStmt;
llvm::SmallVector<MLValue *, 4> indices; llvm::SmallVector<Value *, 4> indices;
// Populates 'accessMap' with composition of AffineApplyOps reachable from // Populates 'accessMap' with composition of AffineApplyOps reachable from
// 'indices'. // 'indices'.
void getAccessMap(AffineValueMap *accessMap) const; void getAccessMap(AffineValueMap *accessMap) const;

View File

@ -25,7 +25,6 @@
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
namespace mlir { namespace mlir {
@ -37,7 +36,7 @@ class AffineMap;
class ForStmt; class ForStmt;
class IntegerSet; class IntegerSet;
class MLIRContext; class MLIRContext;
class MLValue; class Value;
class HyperRectangularSet; class HyperRectangularSet;
/// A mutable affine map. Its affine expressions are however unique. /// A mutable affine map. Its affine expressions are however unique.
@ -132,7 +131,7 @@ public:
AffineValueMap(const AffineApplyOp &op); AffineValueMap(const AffineApplyOp &op);
AffineValueMap(const AffineBound &bound); AffineValueMap(const AffineBound &bound);
AffineValueMap(AffineMap map); AffineValueMap(AffineMap map);
AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands); AffineValueMap(AffineMap map, ArrayRef<Value *> operands);
~AffineValueMap(); ~AffineValueMap();
@ -155,13 +154,13 @@ public:
// substitutions). // substitutions).
// Resets this AffineValueMap with 'map' and 'operands'. // Resets this AffineValueMap with 'map' and 'operands'.
void reset(AffineMap map, ArrayRef<MLValue *> operands); void reset(AffineMap map, ArrayRef<Value *> operands);
/// Return true if the idx^th result can be proved to be a multiple of /// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise. /// 'factor', false otherwise.
inline bool isMultipleOf(unsigned idx, int64_t factor) const; inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the idx^th result depends on 'value', false otherwise. /// Return true if the idx^th result depends on 'value', false otherwise.
bool isFunctionOf(unsigned idx, MLValue *value) const; bool isFunctionOf(unsigned idx, Value *value) const;
/// Return true if the result at 'idx' is a constant, false /// Return true if the result at 'idx' is a constant, false
/// otherwise. /// otherwise.
@ -175,8 +174,8 @@ public:
inline unsigned getNumSymbols() const { return map.getNumSymbols(); } inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
inline unsigned getNumResults() const { return map.getNumResults(); } inline unsigned getNumResults() const { return map.getNumResults(); }
SSAValue *getOperand(unsigned i) const; Value *getOperand(unsigned i) const;
ArrayRef<MLValue *> getOperands() const; ArrayRef<Value *> getOperands() const;
AffineMap getAffineMap() const; AffineMap getAffineMap() const;
private: private:
@ -187,9 +186,9 @@ private:
// TODO: make these trailing objects? // TODO: make these trailing objects?
/// The SSA operands binding to the dim's and symbols of 'map'. /// The SSA operands binding to the dim's and symbols of 'map'.
SmallVector<MLValue *, 4> operands; SmallVector<Value *, 4> operands;
/// The SSA results binding to the results of 'map'. /// The SSA results binding to the results of 'map'.
SmallVector<MLValue *, 4> results; SmallVector<Value *, 4> results;
}; };
/// An IntegerValueSet is an integer set plus its operands. /// An IntegerValueSet is an integer set plus its operands.
@ -218,7 +217,7 @@ private:
// 'AffineCondition'. // 'AffineCondition'.
MutableIntegerSet set; MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'. /// The SSA operands binding to the dim's and symbols of 'set'.
SmallVector<MLValue *, 4> operands; SmallVector<Value *, 4> operands;
}; };
/// A flat list of affine equalities and inequalities in the form. /// A flat list of affine equalities and inequalities in the form.
@ -250,7 +249,7 @@ public:
unsigned numReservedEqualities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims = 0, unsigned numReservedCols, unsigned numDims = 0,
unsigned numSymbols = 0, unsigned numLocals = 0, unsigned numSymbols = 0, unsigned numLocals = 0,
ArrayRef<Optional<MLValue *>> idArgs = {}) ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numReservedCols), numDims(numDims), : numReservedCols(numReservedCols), numDims(numDims),
numSymbols(numSymbols) { numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1); assert(numReservedCols >= numDims + numSymbols + 1);
@ -269,7 +268,7 @@ public:
/// dimensions and symbols. /// dimensions and symbols.
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0, FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0, unsigned numLocals = 0,
ArrayRef<Optional<MLValue *>> idArgs = {}) ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims), : numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
numSymbols(numSymbols) { numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1); assert(numReservedCols >= numDims + numSymbols + 1);
@ -309,10 +308,10 @@ public:
// Clears any existing data and reserves memory for the specified constraints. // Clears any existing data and reserves memory for the specified constraints.
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities, void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims, unsigned numSymbols, unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {}); unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
void reset(unsigned numDims = 0, unsigned numSymbols = 0, void reset(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {}); unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
/// Appends constraints from 'other' into this. This is equivalent to an /// Appends constraints from 'other' into this. This is equivalent to an
/// intersection with no simplification of any sort attempted. /// intersection with no simplification of any sort attempted.
@ -393,7 +392,7 @@ public:
// Returns AffineMap::Null on error (i.e. if coefficient is zero or does // Returns AffineMap::Null on error (i.e. if coefficient is zero or does
// not divide other coefficients in the equality constraint). // not divide other coefficients in the equality constraint).
// TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this // TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this
// API when we can manage the mapping of MLValues and ids in the constraint // API when we can manage the mapping of Values and ids in the constraint
// system. // system.
AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context, AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context,
SmallVectorImpl<unsigned> *nonZeroDimIds, SmallVectorImpl<unsigned> *nonZeroDimIds,
@ -413,10 +412,10 @@ public:
void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb); void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
/// Adds constraints (lower and upper bounds) for the specified 'for' /// Adds constraints (lower and upper bounds) for the specified 'for'
/// statement's MLValue using IR information stored in its bound maps. The /// statement's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forStmt's MLValue. Returns /// right identifier is first looked up using forStmt's Value. Returns
/// false for the yet unimplemented/unsupported cases, and true if the /// false for the yet unimplemented/unsupported cases, and true if the
/// information is succesfully added. Asserts if the MLValue corresponding to /// information is succesfully added. Asserts if the Value corresponding to
/// the 'for' statement isn't found in the constraint system. Any new /// the 'for' statement isn't found in the constraint system. Any new
/// identifiers that are found in the bound operands of the 'for' statement /// identifiers that are found in the bound operands of the 'for' statement
/// are added as trailing identifiers (either dimensional or symbolic /// are added as trailing identifiers (either dimensional or symbolic
@ -435,28 +434,28 @@ public:
/// Sets the identifier at the specified position to a constant. /// Sets the identifier at the specified position to a constant.
void setIdToConstant(unsigned pos, int64_t val); void setIdToConstant(unsigned pos, int64_t val);
/// Sets the identifier corresponding to the specified MLValue id to a /// Sets the identifier corresponding to the specified Value id to a
/// constant. Asserts if the 'id' is not found. /// constant. Asserts if the 'id' is not found.
void setIdToConstant(const MLValue &id, int64_t val); void setIdToConstant(const Value &id, int64_t val);
/// Looks up the identifier with the specified MLValue. Returns false if not /// Looks up the identifier with the specified Value. Returns false if not
/// found, true if found. pos is set to the (column) position of the /// found, true if found. pos is set to the (column) position of the
/// identifier. /// identifier.
bool findId(const MLValue &id, unsigned *pos) const; bool findId(const Value &id, unsigned *pos) const;
// Add identifiers of the specified kind - specified positions are relative to // Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. 'id' is the MLValue corresponding to the // the kind of identifier. 'id' is the Value corresponding to the
// identifier that can optionally be provided. // identifier that can optionally be provided.
void addDimId(unsigned pos, MLValue *id = nullptr); void addDimId(unsigned pos, Value *id = nullptr);
void addSymbolId(unsigned pos, MLValue *id = nullptr); void addSymbolId(unsigned pos, Value *id = nullptr);
void addLocalId(unsigned pos); void addLocalId(unsigned pos);
void addId(IdKind kind, unsigned pos, MLValue *id = nullptr); void addId(IdKind kind, unsigned pos, Value *id = nullptr);
/// Composes the affine value map with this FlatAffineConstrains, adding the /// Composes the affine value map with this FlatAffineConstrains, adding the
/// results of the map as dimensions at the front [0, vMap->getNumResults()) /// results of the map as dimensions at the front [0, vMap->getNumResults())
/// and with the dimensions set to the equalities specified by the value map. /// and with the dimensions set to the equalities specified by the value map.
/// Returns false if the composition fails (when vMap is a semi-affine map). /// Returns false if the composition fails (when vMap is a semi-affine map).
/// The vMap's operand MLValue's are used to look up the right positions in /// The vMap's operand Value's are used to look up the right positions in
/// the FlatAffineConstraints with which to associate. The dimensional and /// the FlatAffineConstraints with which to associate. The dimensional and
/// symbolic operands of vMap should match 1:1 (in the same order) with those /// symbolic operands of vMap should match 1:1 (in the same order) with those
/// of this constraint system, but the latter could have additional trailing /// of this constraint system, but the latter could have additional trailing
@ -471,8 +470,8 @@ public:
void projectOut(unsigned pos, unsigned num); void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); } inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
/// Projects out the identifier that is associate with MLValue *. /// Projects out the identifier that is associate with Value *.
void projectOut(MLValue *id); void projectOut(Value *id);
void removeId(IdKind idKind, unsigned pos); void removeId(IdKind idKind, unsigned pos);
void removeId(unsigned pos); void removeId(unsigned pos);
@ -510,24 +509,24 @@ public:
return numIds - numDims - numSymbols; return numIds - numDims - numSymbols;
} }
inline ArrayRef<Optional<MLValue *>> getIds() const { inline ArrayRef<Optional<Value *>> getIds() const {
return {ids.data(), ids.size()}; return {ids.data(), ids.size()};
} }
/// Returns the MLValue's associated with the identifiers. Asserts if /// Returns the Value's associated with the identifiers. Asserts if
/// no MLValue was associated with an identifier. /// no Value was associated with an identifier.
inline void getIdValues(SmallVectorImpl<MLValue *> *values) const { inline void getIdValues(SmallVectorImpl<Value *> *values) const {
values->clear(); values->clear();
values->reserve(numIds); values->reserve(numIds);
for (unsigned i = 0; i < numIds; i++) { for (unsigned i = 0; i < numIds; i++) {
assert(ids[i].hasValue() && "identifier's MLValue not set"); assert(ids[i].hasValue() && "identifier's Value not set");
values->push_back(ids[i].getValue()); values->push_back(ids[i].getValue());
} }
} }
/// Returns the MLValue associated with the pos^th identifier. Asserts if /// Returns the Value associated with the pos^th identifier. Asserts if
/// no MLValue identifier was associated. /// no Value identifier was associated.
inline MLValue *getIdValue(unsigned pos) const { inline Value *getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's ML Value not set"); assert(ids[pos].hasValue() && "identifier's ML Value not set");
return ids[pos].getValue(); return ids[pos].getValue();
} }
@ -630,11 +629,11 @@ private:
/// analysis). /// analysis).
unsigned numSymbols; unsigned numSymbols;
/// MLValues corresponding to the (column) identifiers of this constraint /// Values corresponding to the (column) identifiers of this constraint
/// system appearing in the order the identifiers correspond to columns. /// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any MLValue are to be /// Temporary ones or those that aren't associated to any Value are to be
/// set to None. /// set to None.
SmallVector<Optional<MLValue *>, 8> ids; SmallVector<Optional<Value *>, 8> ids;
}; };
} // end namespace mlir. } // end namespace mlir.

View File

@ -58,10 +58,10 @@ public:
} }
/// Return true if value A properly dominates instruction B. /// Return true if value A properly dominates instruction B.
bool properlyDominates(const SSAValue *a, const Instruction *b); bool properlyDominates(const Value *a, const Instruction *b);
/// Return true if instruction A dominates instruction B. /// Return true if instruction A dominates instruction B.
bool dominates(const SSAValue *a, const Instruction *b) { bool dominates(const Value *a, const Instruction *b) {
return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b); return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b);
} }

View File

@ -38,10 +38,10 @@ class AffineCondition;
class AffineMap; class AffineMap;
class IntegerSet; class IntegerSet;
class MLIRContext; class MLIRContext;
class MLValue;
class MutableIntegerSet; class MutableIntegerSet;
class FlatAffineConstraints; class FlatAffineConstraints;
class HyperRectangleList; class HyperRectangleList;
class Value;
/// A list of affine bounds. /// A list of affine bounds.
// Not using a MutableAffineMap here since numSymbols is the same as the // Not using a MutableAffineMap here since numSymbols is the same as the
@ -152,8 +152,8 @@ private:
// expressions. // expressions.
std::vector<AffineBoundExprList> upperBounds; std::vector<AffineBoundExprList> upperBounds;
Optional<SmallVector<MLValue *, 8>> dims = None; Optional<SmallVector<Value *, 8>> dims = None;
Optional<SmallVector<MLValue *, 4>> symbols = None; Optional<SmallVector<Value *, 4>> symbols = None;
/// Number of real dimensions. /// Number of real dimensions.
unsigned numDims; unsigned numDims;

View File

@ -23,7 +23,6 @@
#define MLIR_ANALYSIS_LOOP_ANALYSIS_H #define MLIR_ANALYSIS_LOOP_ANALYSIS_H
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h" #include "llvm/ADT/Optional.h"
@ -33,8 +32,8 @@ class AffineExpr;
class AffineMap; class AffineMap;
class ForStmt; class ForStmt;
class MemRefType; class MemRefType;
class MLValue;
class OperationStmt; class OperationStmt;
class Value;
/// Returns the trip count of the loop as an affine expression if the latter is /// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count /// expressible as an affine expression, and nullptr otherwise. The trip count
@ -66,7 +65,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
/// ///
/// Returns false in cases with more than one AffineApplyOp, this is /// Returns false in cases with more than one AffineApplyOp, this is
/// conservative. /// conservative.
bool isAccessInvariant(const MLValue &iv, const MLValue &index); bool isAccessInvariant(const Value &iv, const Value &index);
/// Given an induction variable `iv` of type ForStmt and `indices` of type /// Given an induction variable `iv` of type ForStmt and `indices` of type
/// IndexType, returns the set of `indices` that are independent of `iv`. /// IndexType, returns the set of `indices` that are independent of `iv`.
@ -77,9 +76,8 @@ bool isAccessInvariant(const MLValue &iv, const MLValue &index);
/// ///
/// Returns false in cases with more than one AffineApplyOp, this is /// Returns false in cases with more than one AffineApplyOp, this is
/// conservative. /// conservative.
llvm::DenseSet<const MLValue *, llvm::DenseMapInfo<const MLValue *>> llvm::DenseSet<const Value *, llvm::DenseMapInfo<const Value *>>
getInvariantAccesses(const MLValue &iv, getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
llvm::ArrayRef<const MLValue *> indices);
/// Checks whether the loop is structurally vectorizable; i.e.: /// Checks whether the loop is structurally vectorizable; i.e.:
/// 1. the loop has proper dependence semantics (parallel, reduction, etc); /// 1. the loop has proper dependence semantics (parallel, reduction, etc);

View File

@ -127,10 +127,10 @@ void getBackwardSlice(
/// **includes** the original statement. /// **includes** the original statement.
/// ///
/// This allows building a slice (i.e. multi-root DAG where everything /// This allows building a slice (i.e. multi-root DAG where everything
/// that is reachable from an SSAValue in forward and backward direction is /// that is reachable from an Value in forward and backward direction is
/// contained in the slice). /// contained in the slice).
/// This is the abstraction we need to materialize all the instructions for /// This is the abstraction we need to materialize all the instructions for
/// supervectorization without worrying about orderings and SSAValue /// supervectorization without worrying about orderings and Value
/// replacements. /// replacements.
/// ///
/// Example starting from any node /// Example starting from any node

View File

@ -34,10 +34,10 @@ namespace mlir {
class FlatAffineConstraints; class FlatAffineConstraints;
class ForStmt; class ForStmt;
class MLValue;
class MemRefAccess; class MemRefAccess;
class OperationStmt; class OperationStmt;
class Statement; class Statement;
class Value;
/// Returns true if statement 'a' dominates statement b. /// Returns true if statement 'a' dominates statement b.
bool dominates(const Statement &a, const Statement &b); bool dominates(const Statement &a, const Statement &b);
@ -92,7 +92,7 @@ struct MemRefRegion {
unsigned getRank() const; unsigned getRank() const;
/// Memref that this region corresponds to. /// Memref that this region corresponds to.
MLValue *memref; Value *memref;
private: private:
/// Read or write. /// Read or write.

View File

@ -408,8 +408,8 @@ public:
} }
// Creates a for statement. When step is not specified, it is set to 1. // Creates a for statement. When step is not specified, it is set to 1.
ForStmt *createFor(Location location, ArrayRef<MLValue *> lbOperands, ForStmt *createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands, AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step = 1); AffineMap ubMap, int64_t step = 1);
// Creates a for statement with known (constant) lower and upper bounds. // Creates a for statement with known (constant) lower and upper bounds.
@ -417,7 +417,7 @@ public:
ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
/// Creates if statement. /// Creates if statement.
IfStmt *createIf(Location location, ArrayRef<MLValue *> operands, IfStmt *createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set); IntegerSet set);
private: private:

View File

@ -30,7 +30,6 @@
namespace mlir { namespace mlir {
class Builder; class Builder;
class MLValue;
class BuiltinDialect : public Dialect { class BuiltinDialect : public Dialect {
public: public:
@ -57,7 +56,7 @@ class AffineApplyOp
public: public:
/// Builds an affine apply op with the specified map and operands. /// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map, static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<SSAValue *> operands); ArrayRef<Value *> operands);
/// Returns the affine map to be applied by this operation. /// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const { AffineMap getAffineMap() const {
@ -101,7 +100,7 @@ public:
static StringRef getOperationName() { return "br"; } static StringRef getOperationName() { return "br"; }
static void build(Builder *builder, OperationState *result, BasicBlock *dest, static void build(Builder *builder, OperationState *result, BasicBlock *dest,
ArrayRef<SSAValue *> operands = {}); ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
@ -144,10 +143,9 @@ class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
public: public:
static StringRef getOperationName() { return "cond_br"; } static StringRef getOperationName() { return "cond_br"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *condition,
SSAValue *condition, BasicBlock *trueDest, BasicBlock *trueDest, ArrayRef<Value *> trueOperands,
ArrayRef<SSAValue *> trueOperands, BasicBlock *falseDest, BasicBlock *falseDest, ArrayRef<Value *> falseOperands);
ArrayRef<SSAValue *> falseOperands);
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
@ -155,8 +153,8 @@ public:
bool verify() const; bool verify() const;
// The condition operand is the first operand in the list. // The condition operand is the first operand in the list.
SSAValue *getCondition() { return getOperand(0); } Value *getCondition() { return getOperand(0); }
const SSAValue *getCondition() const { return getOperand(0); } const Value *getCondition() const { return getOperand(0); }
/// Return the destination if the condition is true. /// Return the destination if the condition is true.
BasicBlock *getTrueDest() const; BasicBlock *getTrueDest() const;
@ -165,14 +163,14 @@ public:
BasicBlock *getFalseDest() const; BasicBlock *getFalseDest() const;
// Accessors for operands to the 'true' destination. // Accessors for operands to the 'true' destination.
SSAValue *getTrueOperand(unsigned idx) { Value *getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands()); assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx); return getOperand(getTrueDestOperandIndex() + idx);
} }
const SSAValue *getTrueOperand(unsigned idx) const { const Value *getTrueOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getTrueOperand(idx); return const_cast<CondBranchOp *>(this)->getTrueOperand(idx);
} }
void setTrueOperand(unsigned idx, SSAValue *value) { void setTrueOperand(unsigned idx, Value *value) {
assert(idx < getNumTrueOperands()); assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value); setOperand(getTrueDestOperandIndex() + idx, value);
} }
@ -199,14 +197,14 @@ public:
void eraseTrueOperand(unsigned index); void eraseTrueOperand(unsigned index);
// Accessors for operands to the 'false' destination. // Accessors for operands to the 'false' destination.
SSAValue *getFalseOperand(unsigned idx) { Value *getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands()); assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx); return getOperand(getFalseDestOperandIndex() + idx);
} }
const SSAValue *getFalseOperand(unsigned idx) const { const Value *getFalseOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx); return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
} }
void setFalseOperand(unsigned idx, SSAValue *value) { void setFalseOperand(unsigned idx, Value *value) {
assert(idx < getNumFalseOperands()); assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value); setOperand(getFalseDestOperandIndex() + idx, value);
} }
@ -361,7 +359,7 @@ public:
static StringRef getOperationName() { return "return"; } static StringRef getOperationName() { return "return"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
ArrayRef<SSAValue *> results = {}); ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
@ -380,7 +378,7 @@ void printDimAndSymbolList(Operation::const_operand_iterator begin,
// Parses dimension and symbol list and returns true if parsing failed. // Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser, bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<SSAValue *, 4> &operands, SmallVector<Value *, 4> &operands,
unsigned &numDims); unsigned &numDims);
} // end namespace mlir } // end namespace mlir

View File

@ -27,7 +27,6 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/StmtBlock.h" #include "mlir/IR/StmtBlock.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"

View File

@ -1,133 +0,0 @@
//===- MLValue.h - MLValue base class and SSA type decls ------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines SSA manipulation implementations for ML functions.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLVALUE_H
#define MLIR_IR_MLVALUE_H
#include "mlir/IR/SSAValue.h"
namespace mlir {
class ForStmt;
class MLValue;
using MLFunction = Function;
class Statement;
class StmtBlock;
/// This enum contains all of the SSA value kinds that are valid in an ML
/// function. This should be kept as a proper subtype of SSAValueKind,
/// including having all of the values of the enumerators align.
enum class MLValueKind {
BlockArgument = (int)SSAValueKind::BlockArgument,
StmtResult = (int)SSAValueKind::StmtResult,
ForStmt = (int)SSAValueKind::ForStmt,
};
/// The operand of ML function statement contains an MLValue.
using StmtOperand = IROperandImpl<MLValue, Statement>;
/// MLValue is the base class for SSA values in ML functions.
class MLValue : public SSAValueImpl<StmtOperand, Statement, MLValueKind> {
public:
/// Returns true if the given MLValue can be used as a dimension id.
bool isValidDim() const;
/// Returns true if the given MLValue can be used as a symbol.
bool isValidSymbol() const;
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
return true;
}
}
/// Return the function that this MLValue is defined in.
MLFunction *getFunction();
/// Return the function that this MLValue is defined in.
const MLFunction *getFunction() const {
return const_cast<MLValue *>(this)->getFunction();
}
protected:
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
};
/// Block arguments are ML Values.
class BlockArgument : public MLValue {
public:
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::BlockArgument;
}
/// Return the function that this argument is defined in.
MLFunction *getFunction();
const MLFunction *getFunction() const {
return const_cast<BlockArgument *>(this)->getFunction();
}
StmtBlock *getOwner() { return owner; }
const StmtBlock *getOwner() const { return owner; }
private:
friend class StmtBlock; // For access to private constructor.
BlockArgument(Type type, StmtBlock *owner)
: MLValue(MLValueKind::BlockArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
StmtBlock *const owner;
};
/// This is a value defined by a result of an operation instruction.
class StmtResult : public MLValue {
public:
StmtResult(Type type, OperationStmt *owner)
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::StmtResult;
}
OperationStmt *getOwner() { return owner; }
const OperationStmt *getOwner() const { return owner; }
/// Returns the number of this result.
unsigned getResultNumber() const;
private:
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
OperationStmt *const owner;
};
// TODO(clattner) clean all this up.
using CFGValue = MLValue;
using BBArgument = BlockArgument;
using InstResult = StmtResult;
} // namespace mlir
#endif

View File

@ -27,8 +27,8 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <type_traits> #include <type_traits>
namespace mlir { namespace mlir {
@ -107,7 +107,7 @@ template <typename OpClass> struct op_matcher {
/// Entry point for matching a pattern over an SSAValue. /// Entry point for matching a pattern over an SSAValue.
template <typename Pattern> template <typename Pattern>
inline bool matchPattern(SSAValue *value, const Pattern &pattern) { inline bool matchPattern(Value *value, const Pattern &pattern) {
// TODO: handle other cases // TODO: handle other cases
if (auto *op = value->getDefiningOperation()) if (auto *op = value->getDefiningOperation())
return const_cast<Pattern &>(pattern).match(op); return const_cast<Pattern &>(pattern).match(op);

View File

@ -29,7 +29,7 @@
#define MLIR_IR_OPDEFINITION_H #define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/SSAValue.h" #include "mlir/IR/Value.h"
#include <type_traits> #include <type_traits>
namespace mlir { namespace mlir {
@ -78,11 +78,11 @@ public:
} }
/// If the OpType operation includes the OneResult trait, then OpPointer can /// If the OpType operation includes the OneResult trait, then OpPointer can
/// be implicitly converted to an SSAValue*. This yields the value of the /// be implicitly converted to an Value*. This yields the value of the
/// only result. /// only result.
template <typename SFINAE = OpType> template <typename SFINAE = OpType>
operator typename std::enable_if<IsSingleResult<SFINAE>::value, operator typename std::enable_if<IsSingleResult<SFINAE>::value,
SSAValue *>::type() { Value *>::type() {
return value.getResult(); return value.getResult();
} }
@ -114,14 +114,14 @@ public:
} }
/// If the OpType operation includes the OneResult trait, then OpPointer can /// If the OpType operation includes the OneResult trait, then OpPointer can
/// be implicitly converted to an const SSAValue*. This yields the value of /// be implicitly converted to an const Value*. This yields the value of
/// the only result. /// the only result.
template <typename SFINAE = OpType> template <typename SFINAE = OpType>
operator typename std::enable_if< operator typename std::enable_if<
std::is_convertible< std::is_convertible<
SFINAE *, SFINAE *,
OpTrait::OneResult<typename SFINAE::ConcreteOpType> *>::value, OpTrait::OneResult<typename SFINAE::ConcreteOpType> *>::value,
const SSAValue *>::type() const { const Value *>::type() const {
return value.getResult(); return value.getResult();
} }
@ -346,15 +346,13 @@ private:
template <typename ConcreteType> template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> { class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public: public:
const SSAValue *getOperand() const { const Value *getOperand() const {
return this->getOperation()->getOperand(0); return this->getOperation()->getOperand(0);
} }
SSAValue *getOperand() { return this->getOperation()->getOperand(0); } Value *getOperand() { return this->getOperation()->getOperand(0); }
void setOperand(SSAValue *value) { void setOperand(Value *value) { this->getOperation()->setOperand(0, value); }
this->getOperation()->setOperand(0, value);
}
static bool verifyTrait(const Operation *op) { static bool verifyTrait(const Operation *op) {
return impl::verifyOneOperand(op); return impl::verifyOneOperand(op);
@ -371,15 +369,15 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> { class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
public: public:
const SSAValue *getOperand(unsigned i) const { const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
} }
SSAValue *getOperand(unsigned i) { Value *getOperand(unsigned i) {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
} }
void setOperand(unsigned i, SSAValue *value) { void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value); this->getOperation()->setOperand(i, value);
} }
@ -402,15 +400,15 @@ public:
unsigned getNumOperands() const { unsigned getNumOperands() const {
return this->getOperation()->getNumOperands(); return this->getOperation()->getNumOperands();
} }
const SSAValue *getOperand(unsigned i) const { const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
} }
SSAValue *getOperand(unsigned i) { Value *getOperand(unsigned i) {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
} }
void setOperand(unsigned i, SSAValue *value) { void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value); this->getOperation()->setOperand(i, value);
} }
@ -453,15 +451,13 @@ public:
return this->getOperation()->getNumOperands(); return this->getOperation()->getNumOperands();
} }
const SSAValue *getOperand(unsigned i) const { const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
} }
SSAValue *getOperand(unsigned i) { Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
return this->getOperation()->getOperand(i);
}
void setOperand(unsigned i, SSAValue *value) { void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value); this->getOperation()->setOperand(i, value);
} }
@ -503,17 +499,15 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> { class OneResult : public TraitBase<ConcreteType, OneResult> {
public: public:
SSAValue *getResult() { return this->getOperation()->getResult(0); } Value *getResult() { return this->getOperation()->getResult(0); }
const SSAValue *getResult() const { const Value *getResult() const { return this->getOperation()->getResult(0); }
return this->getOperation()->getResult(0);
}
Type getType() const { return getResult()->getType(); } Type getType() const { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in /// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns /// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'. /// there are zero uses of 'this'.
void replaceAllUsesWith(SSAValue *newValue) { void replaceAllUsesWith(Value *newValue) {
getResult()->replaceAllUsesWith(newValue); getResult()->replaceAllUsesWith(newValue);
} }
@ -548,13 +542,11 @@ public:
public: public:
static unsigned getNumResults() { return N; } static unsigned getNumResults() { return N; }
const SSAValue *getResult(unsigned i) const { const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
SSAValue *getResult(unsigned i) { Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
return this->getOperation()->getResult(i);
}
Type getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) const { return getResult(i)->getType(); }
@ -574,13 +566,11 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> { class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> {
public: public:
const SSAValue *getResult(unsigned i) const { const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
SSAValue *getResult(unsigned i) { Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
return this->getOperation()->getResult(i);
}
Type getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) const { return getResult(i)->getType(); }
@ -599,13 +589,13 @@ public:
return this->getOperation()->getNumResults(); return this->getOperation()->getNumResults();
} }
const SSAValue *getResult(unsigned i) const { const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); } Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
void setResult(unsigned i, SSAValue *value) { void setResult(unsigned i, Value *value) {
this->getOperation()->setResult(i, value); this->getOperation()->setResult(i, value);
} }
@ -762,10 +752,10 @@ public:
return this->getOperation()->setSuccessor(block, index); return this->getOperation()->setSuccessor(block, index);
} }
void addSuccessorOperand(unsigned index, SSAValue *value) { void addSuccessorOperand(unsigned index, Value *value) {
return this->getOperation()->addSuccessorOperand(index, value); return this->getOperation()->addSuccessorOperand(index, value);
} }
void addSuccessorOperands(unsigned index, ArrayRef<SSAValue *> values) { void addSuccessorOperands(unsigned index, ArrayRef<Value *> values) {
return this->getOperation()->addSuccessorOperand(index, values); return this->getOperation()->addSuccessorOperand(index, values);
} }
}; };
@ -889,8 +879,8 @@ private:
// These functions are out-of-line implementations of the methods in BinaryOp, // These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated. // which avoids them being template instantiated/duplicated.
namespace impl { namespace impl {
void buildBinaryOp(Builder *builder, OperationState *result, SSAValue *lhs, void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
SSAValue *rhs); Value *rhs);
bool parseBinaryOp(OpAsmParser *parser, OperationState *result); bool parseBinaryOp(OpAsmParser *parser, OperationState *result);
void printBinaryOp(const Operation *op, OpAsmPrinter *p); void printBinaryOp(const Operation *op, OpAsmPrinter *p);
} // namespace impl } // namespace impl
@ -906,8 +896,8 @@ class BinaryOp
: public Op<ConcreteType, OpTrait::NOperands<2>::Impl, OpTrait::OneResult, : public Op<ConcreteType, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
OpTrait::SameOperandsAndResultType, Traits...> { OpTrait::SameOperandsAndResultType, Traits...> {
public: public:
static void build(Builder *builder, OperationState *result, SSAValue *lhs, static void build(Builder *builder, OperationState *result, Value *lhs,
SSAValue *rhs) { Value *rhs) {
impl::buildBinaryOp(builder, result, lhs, rhs); impl::buildBinaryOp(builder, result, lhs, rhs);
} }
static bool parse(OpAsmParser *parser, OperationState *result) { static bool parse(OpAsmParser *parser, OperationState *result) {
@ -926,7 +916,7 @@ protected:
// These functions are out-of-line implementations of the methods in CastOp, // These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated. // which avoids them being template instantiated/duplicated.
namespace impl { namespace impl {
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source, void buildCastOp(Builder *builder, OperationState *result, Value *source,
Type destType); Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result); bool parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(const Operation *op, OpAsmPrinter *p); void printCastOp(const Operation *op, OpAsmPrinter *p);
@ -942,7 +932,7 @@ template <typename ConcreteType, template <typename T> class... Traits>
class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult, class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...> { OpTrait::HasNoSideEffect, Traits...> {
public: public:
static void build(Builder *builder, OperationState *result, SSAValue *source, static void build(Builder *builder, OperationState *result, Value *source,
Type destType) { Type destType) {
impl::buildCastOp(builder, result, source, destType); impl::buildCastOp(builder, result, source, destType);
} }

View File

@ -48,7 +48,7 @@ public:
virtual raw_ostream &getStream() const = 0; virtual raw_ostream &getStream() const = 0;
/// Print implementations for various things an operation contains. /// Print implementations for various things an operation contains.
virtual void printOperand(const SSAValue *value) = 0; virtual void printOperand(const Value *value) = 0;
/// Print a comma separated list of operands. /// Print a comma separated list of operands.
template <typename ContainerType> template <typename ContainerType>
@ -95,7 +95,7 @@ private:
}; };
// Make the implementations convenient to use. // Make the implementations convenient to use.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) { inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) {
p.printOperand(&value); p.printOperand(&value);
return p; return p;
} }
@ -119,7 +119,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) {
// even if it isn't exactly one of them. For example, we want to print // even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type& version above, not have it match this. // FunctionType with the Type& version above, not have it match this.
template <typename T, typename std::enable_if< template <typename T, typename std::enable_if<
!std::is_convertible<T &, SSAValue &>::value && !std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Type &>::value && !std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value && !std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, AffineMap &>::value, !std::is_convertible<T &, AffineMap &>::value,
@ -264,9 +264,8 @@ public:
virtual bool parseOperand(OperandType &result) = 0; virtual bool parseOperand(OperandType &result) = 0;
/// Parse a single operation successor and it's operand list. /// Parse a single operation successor and it's operand list.
virtual bool virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
parseSuccessorAndUseList(BasicBlock *&dest, SmallVectorImpl<Value *> &operands) = 0;
SmallVectorImpl<SSAValue *> &operands) = 0;
/// These are the supported delimiters around operand lists, used by /// These are the supported delimiters around operand lists, used by
/// parseOperandList. /// parseOperandList.
@ -311,13 +310,13 @@ public:
/// Resolve an operand to an SSA value, emitting an error and returning true /// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure. /// on failure.
virtual bool resolveOperand(const OperandType &operand, Type type, virtual bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) = 0; SmallVectorImpl<Value *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning /// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success. /// true on failure, or appending the results to the list on success.
/// This method should be used when all operands have the same type. /// This method should be used when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type, virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<SSAValue *> &result) { SmallVectorImpl<Value *> &result) {
for (auto elt : operands) for (auto elt : operands)
if (resolveOperand(elt, type, result)) if (resolveOperand(elt, type, result))
return true; return true;
@ -329,7 +328,7 @@ public:
/// to the list on success. /// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands, virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc, ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<SSAValue *> &result) { SmallVectorImpl<Value *> &result) {
if (operands.size() != types.size()) if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) + return emitError(loc, Twine(operands.size()) +
" operands present, but expected " + " operands present, but expected " +

View File

@ -64,21 +64,20 @@ public:
/// Return the number of operands this operation has. /// Return the number of operands this operation has.
unsigned getNumOperands() const; unsigned getNumOperands() const;
SSAValue *getOperand(unsigned idx); Value *getOperand(unsigned idx);
const SSAValue *getOperand(unsigned idx) const { const Value *getOperand(unsigned idx) const {
return const_cast<Operation *>(this)->getOperand(idx); return const_cast<Operation *>(this)->getOperand(idx);
} }
void setOperand(unsigned idx, SSAValue *value); void setOperand(unsigned idx, Value *value);
// Support non-const operand iteration. // Support non-const operand iteration.
using operand_iterator = OperandIterator<Operation, SSAValue>; using operand_iterator = OperandIterator<Operation, Value>;
operand_iterator operand_begin(); operand_iterator operand_begin();
operand_iterator operand_end(); operand_iterator operand_end();
llvm::iterator_range<operand_iterator> getOperands(); llvm::iterator_range<operand_iterator> getOperands();
// Support const operand iteration. // Support const operand iteration.
using const_operand_iterator = using const_operand_iterator = OperandIterator<const Operation, const Value>;
OperandIterator<const Operation, const SSAValue>;
const_operand_iterator operand_begin() const; const_operand_iterator operand_begin() const;
const_operand_iterator operand_end() const; const_operand_iterator operand_end() const;
llvm::iterator_range<const_operand_iterator> getOperands() const; llvm::iterator_range<const_operand_iterator> getOperands() const;
@ -87,26 +86,25 @@ public:
unsigned getNumResults() const; unsigned getNumResults() const;
/// Return the indicated result. /// Return the indicated result.
SSAValue *getResult(unsigned idx); Value *getResult(unsigned idx);
const SSAValue *getResult(unsigned idx) const { const Value *getResult(unsigned idx) const {
return const_cast<Operation *>(this)->getResult(idx); return const_cast<Operation *>(this)->getResult(idx);
} }
// Support non-const result iteration. // Support non-const result iteration.
using result_iterator = ResultIterator<Operation, SSAValue>; using result_iterator = ResultIterator<Operation, Value>;
result_iterator result_begin(); result_iterator result_begin();
result_iterator result_end(); result_iterator result_end();
llvm::iterator_range<result_iterator> getResults(); llvm::iterator_range<result_iterator> getResults();
// Support const result iteration. // Support const result iteration.
using const_result_iterator = ResultIterator<const Operation, const SSAValue>; using const_result_iterator = ResultIterator<const Operation, const Value>;
const_result_iterator result_begin() const; const_result_iterator result_begin() const;
const_result_iterator result_end() const; const_result_iterator result_end() const;
llvm::iterator_range<const_result_iterator> getResults() const; llvm::iterator_range<const_result_iterator> getResults() const;
// Support for result type iteration. // Support for result type iteration.
using result_type_iterator = using result_type_iterator = ResultTypeIterator<const Operation, const Value>;
ResultTypeIterator<const Operation, const SSAValue>;
result_type_iterator result_type_begin() const; result_type_iterator result_type_begin() const;
result_type_iterator result_type_end() const; result_type_iterator result_type_end() const;
llvm::iterator_range<result_type_iterator> getResultTypes() const; llvm::iterator_range<result_type_iterator> getResultTypes() const;

View File

@ -40,8 +40,8 @@ class OpAsmPrinter;
class Pattern; class Pattern;
class RewritePattern; class RewritePattern;
class StmtBlock; class StmtBlock;
class SSAValue;
class Type; class Type;
class Value;
using BasicBlock = StmtBlock; using BasicBlock = StmtBlock;
/// This is a vector that owns the patterns inside of it. /// This is a vector that owns the patterns inside of it.
@ -203,7 +203,7 @@ struct OperationState {
MLIRContext *const context; MLIRContext *const context;
Location location; Location location;
OperationName name; OperationName name;
SmallVector<SSAValue *, 4> operands; SmallVector<Value *, 4> operands;
/// Types of the results of this operation. /// Types of the results of this operation.
SmallVector<Type, 4> types; SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes; SmallVector<NamedAttribute, 4> attributes;
@ -218,7 +218,7 @@ public:
: context(context), location(location), name(name) {} : context(context), location(location), name(name) {}
OperationState(MLIRContext *context, Location location, StringRef name, OperationState(MLIRContext *context, Location location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type> types, ArrayRef<Value *> operands, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors = {}) ArrayRef<StmtBlock *> successors = {})
: context(context), location(location), name(name, context), : context(context), location(location), name(name, context),
@ -227,7 +227,7 @@ public:
attributes(attributes.begin(), attributes.end()), attributes(attributes.begin(), attributes.end()),
successors(successors.begin(), successors.end()) {} successors(successors.begin(), successors.end()) {}
void addOperands(ArrayRef<SSAValue *> newOperands) { void addOperands(ArrayRef<Value *> newOperands) {
assert(successors.empty() && assert(successors.empty() &&
"Non successor operands should be added first."); "Non successor operands should be added first.");
operands.append(newOperands.begin(), newOperands.end()); operands.append(newOperands.begin(), newOperands.end());
@ -247,7 +247,7 @@ public:
attributes.push_back({name, attr}); attributes.push_back({name, attr});
} }
void addSuccessor(StmtBlock *successor, ArrayRef<SSAValue *> succOperands) { void addSuccessor(StmtBlock *successor, ArrayRef<Value *> succOperands) {
successors.push_back(successor); successors.push_back(successor);
// Insert a sentinal operand to mark a barrier between successor operands. // Insert a sentinal operand to mark a barrier between successor operands.
operands.push_back(nullptr); operands.push_back(nullptr);

View File

@ -222,8 +222,8 @@ public:
/// clients can specify a list of other nodes that this replacement may make /// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will /// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well. /// remove them as well.
void replaceOp(Operation *op, ArrayRef<SSAValue *> newValues, void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<SSAValue *> valuesToRemoveIfDead = {}); ArrayRef<Value *> valuesToRemoveIfDead = {});
/// Replaces the result op with a new op that is created without verification. /// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types. /// The result values of the two ops must be the same types.
@ -237,8 +237,7 @@ public:
/// The result values of the two ops must be the same types. This allows /// The result values of the two ops must be the same types. This allows
/// specifying a list of ops that may be removed if dead. /// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
void replaceOpWithNewOp(Operation *op, void replaceOpWithNewOp(Operation *op, ArrayRef<Value *> valuesToRemoveIfDead,
ArrayRef<SSAValue *> valuesToRemoveIfDead,
Args... args) { Args... args) {
auto newOp = create<OpTy>(op->getLoc(), args...); auto newOp = create<OpTy>(op->getLoc(), args...);
replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(), replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(),
@ -254,7 +253,7 @@ public:
/// rewriter should remove if they are dead at this point. /// rewriter should remove if they are dead at this point.
/// ///
void updatedRootInPlace(Operation *op, void updatedRootInPlace(Operation *op,
ArrayRef<SSAValue *> valuesToRemoveIfDead = {}); ArrayRef<Value *> valuesToRemoveIfDead = {});
protected: protected:
PatternRewriter(MLIRContext *context) : Builder(context) {} PatternRewriter(MLIRContext *context) : Builder(context) {}
@ -284,9 +283,8 @@ protected:
private: private:
/// op and newOp are known to have the same number of results, replace the /// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp /// uses of op with uses of newOp
void void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead);
ArrayRef<SSAValue *> valuesToRemoveIfDead);
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,154 +0,0 @@
//===- SSAValue.h - Base of the value hierarchy -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines generic SSAValue type and manipulation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SSAVALUE_H
#define MLIR_IR_SSAVALUE_H
#include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Function;
class OperationStmt;
class Operation;
class Statement;
using Instruction = Statement;
using OperationInst = OperationStmt;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
BlockArgument, // Block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};
/// This is the common base class for all values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
class SSAValue : public IRObjectWithUseList {
public:
~SSAValue() {}
SSAValueKind getKind() const { return typeAndKind.getInt(); }
Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
void replaceAllUsesWith(SSAValue *newValue) {
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// Return the function that this SSAValue is defined in.
Function *getFunction();
/// Return the function that this SSAValue is defined in.
const Function *getFunction() const {
return const_cast<SSAValue *>(this)->getFunction();
}
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
OperationInst *getDefiningInst();
const OperationInst *getDefiningInst() const {
return const_cast<SSAValue *>(this)->getDefiningInst();
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *getDefiningStmt();
const OperationStmt *getDefiningStmt() const {
return const_cast<SSAValue *>(this)->getDefiningStmt();
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *getDefiningOperation();
const Operation *getDefiningOperation() const {
return const_cast<SSAValue *>(this)->getDefiningOperation();
}
void print(raw_ostream &os) const;
void dump() const;
protected:
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
private:
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
};
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
value.print(os);
return os;
}
/// This template unifies the implementation logic for CFGValue and MLValue
/// while providing more type-specific APIs when walking use lists etc.
///
/// IROperandTy is the concrete instance of IROperand to use (including
/// substituted template arguments).
/// IROwnerTy is the type of the owner of an IROperandTy type.
/// KindTy is the enum 'kind' discriminator that subclasses want to use.
///
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
class SSAValueImpl : public SSAValue {
public:
// Provide more specific implementations of the base class functionality.
KindTy getKind() const { return (KindTy)SSAValue::getKind(); }
using use_iterator = SSAValueUseIterator<IROperandTy, IROwnerTy>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;
inline use_iterator use_end() const;
/// Returns a range of all uses, which is useful for iterating over all uses.
inline use_range getUses() const;
protected:
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
};
// Utility functions for iterating through SSAValue uses.
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_begin() const
-> use_iterator {
return use_iterator((IROperandTy *)getFirstUse());
}
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_end() const
-> use_iterator {
return use_iterator(nullptr);
}
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::getUses() const
-> llvm::iterator_range<use_iterator> {
return {use_begin(), use_end()};
}
} // namespace mlir
#endif

View File

@ -22,8 +22,8 @@
#ifndef MLIR_IR_STATEMENT_H #ifndef MLIR_IR_STATEMENT_H
#define MLIR_IR_STATEMENT_H #define MLIR_IR_STATEMENT_H
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h" #include "llvm/ADT/ilist_node.h"
@ -84,8 +84,8 @@ public:
// This is a verbose type used by the clone method below. // This is a verbose type used by the clone method below.
using OperandMapTy = using OperandMapTy =
DenseMap<const MLValue *, MLValue *, llvm::DenseMapInfo<const MLValue *>, DenseMap<const Value *, Value *, llvm::DenseMapInfo<const Value *>,
llvm::detail::DenseMapPair<const MLValue *, MLValue *>>; llvm::detail::DenseMapPair<const Value *, Value *>>;
/// Create a deep copy of this statement, remapping any operands that use /// Create a deep copy of this statement, remapping any operands that use
/// values outside of the statement using the map that is provided (leaving /// values outside of the statement using the map that is provided (leaving
@ -136,12 +136,12 @@ public:
unsigned getNumOperands() const; unsigned getNumOperands() const;
MLValue *getOperand(unsigned idx); Value *getOperand(unsigned idx);
const MLValue *getOperand(unsigned idx) const; const Value *getOperand(unsigned idx) const;
void setOperand(unsigned idx, MLValue *value); void setOperand(unsigned idx, Value *value);
// Support non-const operand iteration. // Support non-const operand iteration.
using operand_iterator = OperandIterator<Statement, MLValue>; using operand_iterator = OperandIterator<Statement, Value>;
operand_iterator operand_begin() { return operand_iterator(this, 0); } operand_iterator operand_begin() { return operand_iterator(this, 0); }
@ -149,14 +149,13 @@ public:
return operand_iterator(this, getNumOperands()); return operand_iterator(this, getNumOperands());
} }
/// Returns an iterator on the underlying MLValue's (MLValue *). /// Returns an iterator on the underlying Value's (Value *).
llvm::iterator_range<operand_iterator> getOperands() { llvm::iterator_range<operand_iterator> getOperands() {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }
// Support const operand iteration. // Support const operand iteration.
using const_operand_iterator = using const_operand_iterator = OperandIterator<const Statement, const Value>;
OperandIterator<const Statement, const MLValue>;
const_operand_iterator operand_begin() const { const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0); return const_operand_iterator(this, 0);
@ -166,7 +165,7 @@ public:
return const_operand_iterator(this, getNumOperands()); return const_operand_iterator(this, getNumOperands());
} }
/// Returns a const iterator on the underlying MLValue's (MLValue *). /// Returns a const iterator on the underlying Value's (Value *).
llvm::iterator_range<const_operand_iterator> getOperands() const { llvm::iterator_range<const_operand_iterator> getOperands() const {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }

View File

@ -43,7 +43,7 @@ class OperationStmt final
public: public:
/// Create a new OperationStmt with the specific fields. /// Create a new OperationStmt with the specific fields.
static OperationStmt * static OperationStmt *
create(Location location, OperationName name, ArrayRef<MLValue *> operands, create(Location location, OperationName name, ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes, ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors, MLIRContext *context); ArrayRef<StmtBlock *> successors, MLIRContext *context);
@ -69,16 +69,16 @@ public:
unsigned getNumOperands() const { return numOperands; } unsigned getNumOperands() const { return numOperands; }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const { const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get(); return getStmtOperand(idx).get();
} }
void setOperand(unsigned idx, MLValue *value) { void setOperand(unsigned idx, Value *value) {
return getStmtOperand(idx).set(value); return getStmtOperand(idx).set(value);
} }
// Support non-const operand iteration. // Support non-const operand iteration.
using operand_iterator = OperandIterator<OperationStmt, MLValue>; using operand_iterator = OperandIterator<OperationStmt, Value>;
operand_iterator operand_begin() { return operand_iterator(this, 0); } operand_iterator operand_begin() { return operand_iterator(this, 0); }
@ -86,14 +86,14 @@ public:
return operand_iterator(this, getNumOperands()); return operand_iterator(this, getNumOperands());
} }
/// Returns an iterator on the underlying MLValue's (MLValue *). /// Returns an iterator on the underlying Value's (Value *).
llvm::iterator_range<operand_iterator> getOperands() { llvm::iterator_range<operand_iterator> getOperands() {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }
// Support const operand iteration. // Support const operand iteration.
using const_operand_iterator = using const_operand_iterator =
OperandIterator<const OperationStmt, const MLValue>; OperandIterator<const OperationStmt, const Value>;
const_operand_iterator operand_begin() const { const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0); return const_operand_iterator(this, 0);
@ -103,7 +103,7 @@ public:
return const_operand_iterator(this, getNumOperands()); return const_operand_iterator(this, getNumOperands());
} }
/// Returns a const iterator on the underlying MLValue's (MLValue *). /// Returns a const iterator on the underlying Value's (Value *).
llvm::iterator_range<const_operand_iterator> getOperands() const { llvm::iterator_range<const_operand_iterator> getOperands() const {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }
@ -126,11 +126,11 @@ public:
unsigned getNumResults() const { return numResults; } unsigned getNumResults() const { return numResults; }
MLValue *getResult(unsigned idx) { return &getStmtResult(idx); } Value *getResult(unsigned idx) { return &getStmtResult(idx); }
const MLValue *getResult(unsigned idx) const { return &getStmtResult(idx); } const Value *getResult(unsigned idx) const { return &getStmtResult(idx); }
// Support non-const result iteration. // Support non-const result iteration.
using result_iterator = ResultIterator<OperationStmt, MLValue>; using result_iterator = ResultIterator<OperationStmt, Value>;
result_iterator result_begin() { return result_iterator(this, 0); } result_iterator result_begin() { return result_iterator(this, 0); }
result_iterator result_end() { result_iterator result_end() {
return result_iterator(this, getNumResults()); return result_iterator(this, getNumResults());
@ -141,7 +141,7 @@ public:
// Support const result iteration. // Support const result iteration.
using const_result_iterator = using const_result_iterator =
ResultIterator<const OperationStmt, const MLValue>; ResultIterator<const OperationStmt, const Value>;
const_result_iterator result_begin() const { const_result_iterator result_begin() const {
return const_result_iterator(this, 0); return const_result_iterator(this, 0);
} }
@ -170,7 +170,7 @@ public:
// Support result type iteration. // Support result type iteration.
using result_type_iterator = using result_type_iterator =
ResultTypeIterator<const OperationStmt, const MLValue>; ResultTypeIterator<const OperationStmt, const Value>;
result_type_iterator result_type_begin() const { result_type_iterator result_type_begin() const {
return result_type_iterator(this, 0); return result_type_iterator(this, 0);
} }
@ -290,15 +290,15 @@ private:
}; };
/// For statement represents an affine loop nest. /// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue { class ForStmt : public Statement, public Value {
public: public:
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands, static ForStmt *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands, AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step); AffineMap ubMap, int64_t step);
~ForStmt() { ~ForStmt() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor // Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before the MLValue that this // since child statements need to be destroyed before the Value that this
// for stmt represents is destroyed. Affine maps are immortal objects and // for stmt represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted. // don't need to be deleted.
getBody()->clear(); getBody()->clear();
@ -308,8 +308,8 @@ public:
using Statement::getFunction; using Statement::getFunction;
/// Operand iterators. /// Operand iterators.
using operand_iterator = OperandIterator<ForStmt, MLValue>; using operand_iterator = OperandIterator<ForStmt, Value>;
using const_operand_iterator = OperandIterator<const ForStmt, const MLValue>; using const_operand_iterator = OperandIterator<const ForStmt, const Value>;
/// Operand iterator range. /// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>; using operand_range = llvm::iterator_range<operand_iterator>;
@ -340,9 +340,9 @@ public:
AffineMap getUpperBoundMap() const { return ubMap; } AffineMap getUpperBoundMap() const { return ubMap; }
/// Set lower bound. /// Set lower bound.
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap map); void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
/// Set upper bound. /// Set upper bound.
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap map); void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
/// Set the lower bound map without changing operands. /// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap map); void setLowerBoundMap(AffineMap map);
@ -385,11 +385,11 @@ public:
unsigned getNumOperands() const { return operands.size(); } unsigned getNumOperands() const { return operands.size(); }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const { const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get(); return getStmtOperand(idx).get();
} }
void setOperand(unsigned idx, MLValue *value) { void setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value); getStmtOperand(idx).set(value);
} }
@ -439,10 +439,10 @@ public:
} }
// For statement represents implicitly represents induction variable by // For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop // inheriting from Value class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself. // induction variable, just use the for statement itself.
static bool classof(const SSAValue *value) { static bool classof(const Value *value) {
return value->getKind() == SSAValueKind::ForStmt; return value->getKind() == Value::Kind::ForStmt;
} }
private: private:
@ -475,7 +475,7 @@ public:
AffineMap getMap() const { return map; } AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; } unsigned getNumOperands() const { return opEnd - opStart; }
const MLValue *getOperand(unsigned idx) const { const Value *getOperand(unsigned idx) const {
return stmt.getOperand(opStart + idx); return stmt.getOperand(opStart + idx);
} }
const StmtOperand &getStmtOperand(unsigned idx) const { const StmtOperand &getStmtOperand(unsigned idx) const {
@ -486,15 +486,15 @@ public:
using operand_range = ForStmt::operand_range; using operand_range = ForStmt::operand_range;
operand_iterator operand_begin() const { operand_iterator operand_begin() const {
// These are iterators over MLValue *. Not casting away const'ness would // These are iterators over Value *. Not casting away const'ness would
// require the caller to use const MLValue *. // require the caller to use const Value *.
return operand_iterator(const_cast<ForStmt *>(&stmt), opStart); return operand_iterator(const_cast<ForStmt *>(&stmt), opStart);
} }
operand_iterator operand_end() const { operand_iterator operand_end() const {
return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd); return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd);
} }
/// Returns an iterator on the underlying MLValue's (MLValue *). /// Returns an iterator on the underlying Value's (Value *).
operand_range getOperands() const { return {operand_begin(), operand_end()}; } operand_range getOperands() const { return {operand_begin(), operand_end()}; }
ArrayRef<StmtOperand> getStmtOperands() const { ArrayRef<StmtOperand> getStmtOperands() const {
auto ops = stmt.getStmtOperands(); auto ops = stmt.getStmtOperands();
@ -520,7 +520,7 @@ private:
/// If statement restricts execution to a subset of the loop iteration space. /// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement { class IfStmt : public Statement {
public: public:
static IfStmt *create(Location location, ArrayRef<MLValue *> operands, static IfStmt *create(Location location, ArrayRef<Value *> operands,
IntegerSet set); IntegerSet set);
~IfStmt(); ~IfStmt();
@ -556,8 +556,8 @@ public:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Operand iterators. /// Operand iterators.
using operand_iterator = OperandIterator<IfStmt, MLValue>; using operand_iterator = OperandIterator<IfStmt, Value>;
using const_operand_iterator = OperandIterator<const IfStmt, const MLValue>; using const_operand_iterator = OperandIterator<const IfStmt, const Value>;
/// Operand iterator range. /// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>; using operand_range = llvm::iterator_range<operand_iterator>;
@ -565,11 +565,11 @@ public:
unsigned getNumOperands() const { return operands.size(); } unsigned getNumOperands() const { return operands.size(); }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); } Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const { const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get(); return getStmtOperand(idx).get();
} }
void setOperand(unsigned idx, MLValue *value) { void setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value); getStmtOperand(idx).set(value);
} }

View File

@ -26,7 +26,6 @@
namespace mlir { namespace mlir {
class IfStmt; class IfStmt;
class MLValue;
class StmtBlockList; class StmtBlockList;
using CFGFunction = Function; using CFGFunction = Function;
using MLFunction = Function; using MLFunction = Function;
@ -412,7 +411,7 @@ public:
} }
private: private:
using BBUseIterator = SSAValueUseIterator<StmtBlockOperand, OperationStmt>; using BBUseIterator = ValueUseIterator<StmtBlockOperand, OperationStmt>;
BBUseIterator bbUseIterator; BBUseIterator bbUseIterator;
}; };

View File

@ -30,7 +30,7 @@ namespace mlir {
class IROperand; class IROperand;
class IROperandOwner; class IROperandOwner;
template <typename OperandType, typename OwnerType> class SSAValueUseIterator; template <typename OperandType, typename OwnerType> class ValueUseIterator;
class IRObjectWithUseList { class IRObjectWithUseList {
public: public:
@ -44,7 +44,7 @@ public:
/// Returns true if this value has exactly one use. /// Returns true if this value has exactly one use.
inline bool hasOneUse() const; inline bool hasOneUse() const;
using use_iterator = SSAValueUseIterator<IROperand, IROperandOwner>; using use_iterator = ValueUseIterator<IROperand, IROperandOwner>;
using use_range = llvm::iterator_range<use_iterator>; using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const; inline use_iterator use_begin() const;
@ -228,33 +228,33 @@ public:
/// An iterator over all uses of a ValueBase. /// An iterator over all uses of a ValueBase.
template <typename OperandType, typename OwnerType> template <typename OperandType, typename OwnerType>
class SSAValueUseIterator class ValueUseIterator
: public std::iterator<std::forward_iterator_tag, IROperand> { : public std::iterator<std::forward_iterator_tag, IROperand> {
public: public:
SSAValueUseIterator() = default; ValueUseIterator() = default;
explicit SSAValueUseIterator(OperandType *current) : current(current) {} explicit ValueUseIterator(OperandType *current) : current(current) {}
OperandType *operator->() const { return current; } OperandType *operator->() const { return current; }
OperandType &operator*() const { return *current; } OperandType &operator*() const { return *current; }
OwnerType *getUser() const { return current->getOwner(); } OwnerType *getUser() const { return current->getOwner(); }
SSAValueUseIterator &operator++() { ValueUseIterator &operator++() {
assert(current && "incrementing past end()!"); assert(current && "incrementing past end()!");
current = (OperandType *)current->getNextOperandUsingThisValue(); current = (OperandType *)current->getNextOperandUsingThisValue();
return *this; return *this;
} }
SSAValueUseIterator operator++(int unused) { ValueUseIterator operator++(int unused) {
SSAValueUseIterator copy = *this; ValueUseIterator copy = *this;
++*this; ++*this;
return copy; return copy;
} }
friend bool operator==(SSAValueUseIterator lhs, SSAValueUseIterator rhs) { friend bool operator==(ValueUseIterator lhs, ValueUseIterator rhs) {
return lhs.current == rhs.current; return lhs.current == rhs.current;
} }
friend bool operator!=(SSAValueUseIterator lhs, SSAValueUseIterator rhs) { friend bool operator!=(ValueUseIterator lhs, ValueUseIterator rhs) {
return !(lhs == rhs); return !(lhs == rhs);
} }

View File

@ -0,0 +1,198 @@
//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines generic Value type and manipulation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_VALUE_H
#define MLIR_IR_VALUE_H
#include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Function;
class OperationStmt;
class Operation;
class Statement;
class StmtBlock;
class Value;
using Instruction = Statement;
using OperationInst = OperationStmt;
/// The operand of ML function statement contains a Value.
using StmtOperand = IROperandImpl<Value, Statement>;
/// This is the common base class for all values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
class Value : public IRObjectWithUseList {
public:
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class Kind {
BlockArgument, // block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};
~Value() {}
Kind getKind() const { return typeAndKind.getInt(); }
Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
void replaceAllUsesWith(Value *newValue) {
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// TODO: move isValidDim/isValidSymbol to a utility library specific to the
/// polyhedral operations.
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim() const;
/// Returns true if the given Value can be used as a symbol.
bool isValidSymbol() const;
/// Return the function that this Value is defined in.
Function *getFunction();
/// Return the function that this Value is defined in.
const Function *getFunction() const {
return const_cast<Value *>(this)->getFunction();
}
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
OperationInst *getDefiningInst();
const OperationInst *getDefiningInst() const {
return const_cast<Value *>(this)->getDefiningInst();
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *getDefiningStmt();
const OperationStmt *getDefiningStmt() const {
return const_cast<Value *>(this)->getDefiningStmt();
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *getDefiningOperation();
const Operation *getDefiningOperation() const {
return const_cast<Value *>(this)->getDefiningOperation();
}
using use_iterator = ValueUseIterator<StmtOperand, Statement>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;
inline use_iterator use_end() const;
/// Returns a range of all uses, which is useful for iterating over all uses.
inline use_range getUses() const;
void print(raw_ostream &os) const;
void dump() const;
protected:
Value(Kind kind, Type type) : typeAndKind(type, kind) {}
private:
const llvm::PointerIntPair<Type, 3, Kind> typeAndKind;
};
inline raw_ostream &operator<<(raw_ostream &os, const Value &value) {
value.print(os);
return os;
}
// Utility functions for iterating through Value uses.
inline auto Value::use_begin() const -> use_iterator {
return use_iterator((StmtOperand *)getFirstUse());
}
inline auto Value::use_end() const -> use_iterator {
return use_iterator(nullptr);
}
inline auto Value::getUses() const -> llvm::iterator_range<use_iterator> {
return {use_begin(), use_end()};
}
/// Block arguments are values.
class BlockArgument : public Value {
public:
static bool classof(const Value *value) {
return value->getKind() == Kind::BlockArgument;
}
/// Return the function that this argument is defined in.
Function *getFunction();
const Function *getFunction() const {
return const_cast<BlockArgument *>(this)->getFunction();
}
StmtBlock *getOwner() { return owner; }
const StmtBlock *getOwner() const { return owner; }
private:
friend class StmtBlock; // For access to private constructor.
BlockArgument(Type type, StmtBlock *owner)
: Value(Value::Kind::BlockArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
StmtBlock *const owner;
};
/// This is a value defined by a result of an operation instruction.
class StmtResult : public Value {
public:
StmtResult(Type type, OperationStmt *owner)
: Value(Value::Kind::StmtResult, type), owner(owner) {}
static bool classof(const Value *value) {
return value->getKind() == Kind::StmtResult;
}
OperationStmt *getOwner() { return owner; }
const OperationStmt *getOwner() const { return owner; }
/// Returns the number of this result.
unsigned getResultNumber() const;
private:
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
OperationStmt *const owner;
};
// TODO(clattner) clean all this up.
using BBArgument = BlockArgument;
using InstResult = StmtResult;
} // namespace mlir
#endif

View File

@ -157,14 +157,14 @@ class Op<string mnemonic, list<OpProperty> props = []> {
// //
// static void build(Builder* builder, OperationState* result, // static void build(Builder* builder, OperationState* result,
// Type resultType0, Type resultType1, ..., // Type resultType0, Type resultType1, ...,
// SSAValue* arg0, SSAValue* arg1, ..., // Value arg0, Value arg1, ...,
// Attribute <attr0-name>, Attribute <attr1-name>, ...); // Attribute <attr0-name>, Attribute <attr1-name>, ...);
// //
// * where the attributes follow the same declaration order as in the op. // * where the attributes follow the same declaration order as in the op.
// //
// static void build(Builder* builder, OperationState* result, // static void build(Builder* builder, OperationState* result,
// ArrayRef<Type> resultTypes, // ArrayRef<Type> resultTypes,
// ArrayRef<SSAValue*> args, // ArrayRef<Value> args,
// ArrayRef<NamedAttribute> attributes); // ArrayRef<NamedAttribute> attributes);
code builder = ?; code builder = ?;

View File

@ -30,7 +30,6 @@
namespace mlir { namespace mlir {
class AffineMap; class AffineMap;
class Builder; class Builder;
class MLValue;
class StandardOpsDialect : public Dialect { class StandardOpsDialect : public Dialect {
public: public:
@ -48,8 +47,8 @@ class AddFOp
: public BinaryOp<AddFOp, OpTrait::ResultsAreFloatLike, : public BinaryOp<AddFOp, OpTrait::ResultsAreFloatLike,
OpTrait::IsCommutative, OpTrait::HasNoSideEffect> { OpTrait::IsCommutative, OpTrait::HasNoSideEffect> {
public: public:
static void build(Builder *builder, OperationState *result, SSAValue *lhs, static void build(Builder *builder, OperationState *result, Value *lhs,
SSAValue *rhs); Value *rhs);
static StringRef getOperationName() { return "addf"; } static StringRef getOperationName() { return "addf"; }
@ -116,7 +115,7 @@ public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<SSAValue *> operands = {}); MemRefType memrefType, ArrayRef<Value *> operands = {});
bool verify() const; bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
@ -140,7 +139,7 @@ public:
static StringRef getOperationName() { return "call"; } static StringRef getOperationName() { return "call"; }
static void build(Builder *builder, OperationState *result, Function *callee, static void build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands); ArrayRef<Value *> operands);
Function *getCallee() const { Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee").getValue(); return getAttrOfType<FunctionAttr>("callee").getValue();
@ -169,11 +168,11 @@ class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
public: public:
static StringRef getOperationName() { return "call_indirect"; } static StringRef getOperationName() { return "call_indirect"; }
static void build(Builder *builder, OperationState *result, SSAValue *callee, static void build(Builder *builder, OperationState *result, Value *callee,
ArrayRef<SSAValue *> operands); ArrayRef<Value *> operands);
const SSAValue *getCallee() const { return getOperand(0); } const Value *getCallee() const { return getOperand(0); }
SSAValue *getCallee() { return getOperand(0); } Value *getCallee() { return getOperand(0); }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
@ -240,7 +239,7 @@ public:
static CmpIPredicate getPredicateByName(StringRef name); static CmpIPredicate getPredicateByName(StringRef name);
static void build(Builder *builder, OperationState *result, CmpIPredicate, static void build(Builder *builder, OperationState *result, CmpIPredicate,
SSAValue *lhs, SSAValue *rhs); Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
bool verify() const; bool verify() const;
@ -263,14 +262,14 @@ private:
class DeallocOp class DeallocOp
: public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> { : public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public: public:
SSAValue *getMemRef() { return getOperand(); } Value *getMemRef() { return getOperand(); }
const SSAValue *getMemRef() const { return getOperand(); } const Value *getMemRef() const { return getOperand(); }
void setMemRef(SSAValue *value) { setOperand(value); } void setMemRef(Value *value) { setOperand(value); }
static StringRef getOperationName() { return "dealloc"; } static StringRef getOperationName() { return "dealloc"; }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, SSAValue *memref); static void build(Builder *builder, OperationState *result, Value *memref);
bool verify() const; bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
@ -292,7 +291,7 @@ class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect> { OpTrait::HasNoSideEffect> {
public: public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index); Value *memrefOrTensor, unsigned index);
Attribute constantFold(ArrayRef<Attribute> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
@ -354,15 +353,15 @@ private:
class DmaStartOp class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public: public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *srcMemRef,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices, ArrayRef<Value *> srcIndices, Value *destMemRef,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices, ArrayRef<Value *> destIndices, Value *numElements,
SSAValue *numElements, SSAValue *tagMemRef, Value *tagMemRef, ArrayRef<Value *> tagIndices,
ArrayRef<SSAValue *> tagIndices, SSAValue *stride = nullptr, Value *stride = nullptr,
SSAValue *elementsPerStride = nullptr); Value *elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation. // Returns the source MemRefType for this DMA operation.
const SSAValue *getSrcMemRef() const { return getOperand(0); } const Value *getSrcMemRef() const { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType. // Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() const { unsigned getSrcMemRefRank() const {
return getSrcMemRef()->getType().cast<MemRefType>().getRank(); return getSrcMemRef()->getType().cast<MemRefType>().getRank();
@ -375,7 +374,7 @@ public:
} }
// Returns the destination MemRefType for this DMA operations. // Returns the destination MemRefType for this DMA operations.
const SSAValue *getDstMemRef() const { const Value *getDstMemRef() const {
return getOperand(1 + getSrcMemRefRank()); return getOperand(1 + getSrcMemRefRank());
} }
// Returns the rank (number of indices) of the destination MemRefType. // Returns the rank (number of indices) of the destination MemRefType.
@ -398,12 +397,12 @@ public:
} }
// Returns the number of elements being transferred by this DMA operation. // Returns the number of elements being transferred by this DMA operation.
const SSAValue *getNumElements() const { const Value *getNumElements() const {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
} }
// Returns the Tag MemRef for this DMA operation. // Returns the Tag MemRef for this DMA operation.
const SSAValue *getTagMemRef() const { const Value *getTagMemRef() const {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
} }
// Returns the rank (number of indices) of the tag MemRefType. // Returns the rank (number of indices) of the tag MemRefType.
@ -453,21 +452,21 @@ public:
1 + 1 + getTagMemRefRank(); 1 + 1 + getTagMemRefRank();
} }
SSAValue *getStride() { Value *getStride() {
if (!isStrided()) if (!isStrided())
return nullptr; return nullptr;
return getOperand(getNumOperands() - 1 - 1); return getOperand(getNumOperands() - 1 - 1);
} }
const SSAValue *getStride() const { const Value *getStride() const {
return const_cast<DmaStartOp *>(this)->getStride(); return const_cast<DmaStartOp *>(this)->getStride();
} }
SSAValue *getNumElementsPerStride() { Value *getNumElementsPerStride() {
if (!isStrided()) if (!isStrided())
return nullptr; return nullptr;
return getOperand(getNumOperands() - 1); return getOperand(getNumOperands() - 1);
} }
const SSAValue *getNumElementsPerStride() const { const Value *getNumElementsPerStride() const {
return const_cast<DmaStartOp *>(this)->getNumElementsPerStride(); return const_cast<DmaStartOp *>(this)->getNumElementsPerStride();
} }
@ -493,15 +492,14 @@ protected:
class DmaWaitOp class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> { : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public: public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *tagMemRef,
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices, ArrayRef<Value *> tagIndices, Value *numElements);
SSAValue *numElements);
static StringRef getOperationName() { return "dma_wait"; } static StringRef getOperationName() { return "dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on. // Returns the Tag MemRef associated with the DMA operation being waited on.
const SSAValue *getTagMemRef() const { return getOperand(0); } const Value *getTagMemRef() const { return getOperand(0); }
SSAValue *getTagMemRef() { return getOperand(0); } Value *getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation. // Returns the tag memref index for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator> llvm::iterator_range<Operation::const_operand_iterator>
@ -516,7 +514,7 @@ public:
} }
// Returns the number of elements transferred in the associated DMA operation. // Returns the number of elements transferred in the associated DMA operation.
const SSAValue *getNumElements() const { const Value *getNumElements() const {
return getOperand(1 + getTagMemRefRank()); return getOperand(1 + getTagMemRefRank());
} }
@ -545,11 +543,11 @@ class ExtractElementOp
: public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult, : public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult,
OpTrait::HasNoSideEffect> { OpTrait::HasNoSideEffect> {
public: public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *aggregate,
SSAValue *aggregate, ArrayRef<SSAValue *> indices = {}); ArrayRef<Value *> indices = {});
SSAValue *getAggregate() { return getOperand(0); } Value *getAggregate() { return getOperand(0); }
const SSAValue *getAggregate() const { return getOperand(0); } const Value *getAggregate() const { return getOperand(0); }
llvm::iterator_range<Operation::operand_iterator> getIndices() { llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
@ -583,12 +581,12 @@ class LoadOp
: public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> { : public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public: public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, SSAValue *memref, static void build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<SSAValue *> indices = {}); ArrayRef<Value *> indices = {});
SSAValue *getMemRef() { return getOperand(0); } Value *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); } const Value *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); } void setMemRef(Value *value) { setOperand(0, value); }
MemRefType getMemRefType() const { MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>(); return getMemRef()->getType().cast<MemRefType>();
} }
@ -705,19 +703,18 @@ class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> { OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public: public:
static StringRef getOperationName() { return "select"; } static StringRef getOperationName() { return "select"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *condition,
SSAValue *condition, SSAValue *trueValue, Value *trueValue, Value *falseValue);
SSAValue *falseValue);
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
bool verify() const; bool verify() const;
SSAValue *getCondition() { return getOperand(0); } Value *getCondition() { return getOperand(0); }
const SSAValue *getCondition() const { return getOperand(0); } const Value *getCondition() const { return getOperand(0); }
SSAValue *getTrueValue() { return getOperand(1); } Value *getTrueValue() { return getOperand(1); }
const SSAValue *getTrueValue() const { return getOperand(1); } const Value *getTrueValue() const { return getOperand(1); }
SSAValue *getFalseValue() { return getOperand(2); } Value *getFalseValue() { return getOperand(2); }
const SSAValue *getFalseValue() const { return getOperand(2); } const Value *getFalseValue() const { return getOperand(2); }
Attribute constantFold(ArrayRef<Attribute> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
@ -742,15 +739,15 @@ class StoreOp
public: public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
SSAValue *valueToStore, SSAValue *memref, Value *valueToStore, Value *memref,
ArrayRef<SSAValue *> indices = {}); ArrayRef<Value *> indices = {});
SSAValue *getValueToStore() { return getOperand(0); } Value *getValueToStore() { return getOperand(0); }
const SSAValue *getValueToStore() const { return getOperand(0); } const Value *getValueToStore() const { return getOperand(0); }
SSAValue *getMemRef() { return getOperand(1); } Value *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); } const Value *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); } void setMemRef(Value *value) { setOperand(1, value); }
MemRefType getMemRefType() const { MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>(); return getMemRef()->getType().cast<MemRefType>();
} }

View File

@ -98,26 +98,24 @@ public:
static StringRef getOperationName() { return "vector_transfer_read"; } static StringRef getOperationName() { return "vector_transfer_read"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
VectorType vectorType, SSAValue *srcMemRef, VectorType vectorType, Value *srcMemRef,
ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap, ArrayRef<Value *> srcIndices, AffineMap permutationMap,
Optional<SSAValue *> paddingValue = None); Optional<Value *> paddingValue = None);
VectorType getResultType() const { VectorType getResultType() const {
return getResult()->getType().cast<VectorType>(); return getResult()->getType().cast<VectorType>();
} }
SSAValue *getVector() { return getResult(); } Value *getVector() { return getResult(); }
const SSAValue *getVector() const { return getResult(); } const Value *getVector() const { return getResult(); }
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const SSAValue *getMemRef() const { const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
return getOperand(Offsets::MemRefOffset);
}
VectorType getVectorType() const { return getResultType(); } VectorType getVectorType() const { return getResultType(); }
MemRefType getMemRefType() const { MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>(); return getMemRef()->getType().cast<MemRefType>();
} }
llvm::iterator_range<Operation::operand_iterator> getIndices(); llvm::iterator_range<Operation::operand_iterator> getIndices();
llvm::iterator_range<Operation::const_operand_iterator> getIndices() const; llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
Optional<SSAValue *> getPaddingValue(); Optional<Value *> getPaddingValue();
Optional<const SSAValue *> getPaddingValue() const; Optional<const Value *> getPaddingValue() const;
AffineMap getPermutationMap() const; AffineMap getPermutationMap() const;
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
@ -169,20 +167,16 @@ class VectorTransferWriteOp
public: public:
static StringRef getOperationName() { return "vector_transfer_write"; } static StringRef getOperationName() { return "vector_transfer_write"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; } static StringRef getPermutationMapAttrName() { return "permutation_map"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *srcVector,
SSAValue *srcVector, SSAValue *dstMemRef, Value *dstMemRef, ArrayRef<Value *> dstIndices,
ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap); AffineMap permutationMap);
SSAValue *getVector() { return getOperand(Offsets::VectorOffset); } Value *getVector() { return getOperand(Offsets::VectorOffset); }
const SSAValue *getVector() const { const Value *getVector() const { return getOperand(Offsets::VectorOffset); }
return getOperand(Offsets::VectorOffset);
}
VectorType getVectorType() const { VectorType getVectorType() const {
return getVector()->getType().cast<VectorType>(); return getVector()->getType().cast<VectorType>();
} }
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); } Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const SSAValue *getMemRef() const { const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
return getOperand(Offsets::MemRefOffset);
}
MemRefType getMemRefType() const { MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>(); return getMemRef()->getType().cast<MemRefType>();
} }
@ -212,8 +206,8 @@ class VectorTypeCastOp
: public Op<VectorTypeCastOp, OpTrait::OneOperand, OpTrait::OneResult> { : public Op<VectorTypeCastOp, OpTrait::OneOperand, OpTrait::OneResult> {
public: public:
static StringRef getOperationName() { return "vector_type_cast"; } static StringRef getOperationName() { return "vector_type_cast"; }
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result, Value *srcVector,
SSAValue *srcVector, Type dstType); Type dstType);
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
bool verify() const; bool verify() const;

View File

@ -35,10 +35,9 @@ namespace mlir {
class ForStmt; class ForStmt;
class FuncBuilder; class FuncBuilder;
class Location; class Location;
class MLValue;
class Module; class Module;
class OperationStmt; class OperationStmt;
class SSAValue;
class Function; class Function;
using CFGFunction = Function; using CFGFunction = Function;
@ -52,12 +51,12 @@ using CFGFunction = Function;
/// Returns true on success and false if the replacement is not possible /// Returns true on success and false if the replacement is not possible
/// (whenever a memref is used as an operand in a non-deferencing scenario). See /// (whenever a memref is used as an operand in a non-deferencing scenario). See
/// comments at function definition for an example. /// comments at function definition for an example.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily // TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily
// extended to add additional indices at any position. // extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef, bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<SSAValue *> extraIndices = {}, ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap::Null(), AffineMap indexRemap = AffineMap::Null(),
ArrayRef<SSAValue *> extraOperands = {}, ArrayRef<Value *> extraOperands = {},
const Statement *domStmtFilter = nullptr); const Statement *domStmtFilter = nullptr);
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
@ -69,9 +68,9 @@ bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
/// parameter 'results'. Returns the affine apply op created. /// parameter 'results'. Returns the affine apply op created.
OperationStmt * OperationStmt *
createComposedAffineApplyOp(FuncBuilder *builder, Location loc, createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<MLValue *> operands, ArrayRef<Value *> operands,
ArrayRef<OperationStmt *> affineApplyOps, ArrayRef<OperationStmt *> affineApplyOps,
SmallVectorImpl<SSAValue *> *results); SmallVectorImpl<Value *> *results);
/// Given an operation statement, inserts a new single affine apply operation, /// Given an operation statement, inserts a new single affine apply operation,
/// that is exclusively used by this operation statement, and that provides all /// that is exclusively used by this operation statement, and that provides all
@ -104,7 +103,7 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt);
/// Forward substitutes results from 'AffineApplyOp' into any users which /// Forward substitutes results from 'AffineApplyOp' into any users which
/// are also AffineApplyOps. /// are also AffineApplyOps.
// NOTE: This method may modify users of results of this operation. // NOTE: This method may modify users of results of this operation.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. // TODO(mlir-team): extend this for Value/ CFGFunctions.
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp); void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible. /// Folds the lower and upper bounds of a 'for' stmt to constants if possible.

View File

@ -489,11 +489,11 @@ bool mlir::getFlattenedAffineExprs(
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
// the AffineApplyOp into any user AffineApplyOps. // the AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps( void mlir::getReachableAffineApplyOps(
ArrayRef<MLValue *> operands, ArrayRef<Value *> operands,
SmallVectorImpl<OperationStmt *> &affineApplyOps) { SmallVectorImpl<OperationStmt *> &affineApplyOps) {
struct State { struct State {
// The ssa value for this node in the DFS traversal. // The ssa value for this node in the DFS traversal.
MLValue *value; Value *value;
// The operand index of 'value' to explore next during DFS traversal. // The operand index of 'value' to explore next during DFS traversal.
unsigned operandIndex; unsigned operandIndex;
}; };
@ -557,8 +557,8 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// setExprStride(ArrayRef<int64_t> expr, int64_t stride) // setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts, bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain) { FlatAffineConstraints *domain) {
SmallVector<MLValue *, 4> indices(forStmts.begin(), forStmts.end()); SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
// Reset while associated MLValues in 'indices' to the domain. // Reset while associated Values in 'indices' to the domain.
domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forStmt : forStmts) { for (auto *forStmt : forStmts) {
// Add constraints from forStmt's bounds. // Add constraints from forStmt's bounds.
@ -583,10 +583,10 @@ static bool getStmtIndexSet(const Statement *stmt,
return getIndexSet(loops, indexSet); return getIndexSet(loops, indexSet);
} }
// ValuePositionMap manages the mapping from MLValues which represent dimension // ValuePositionMap manages the mapping from Values which represent dimension
// and symbol identifiers from 'src' and 'dst' access functions to positions // and symbol identifiers from 'src' and 'dst' access functions to positions
// in new space where some MLValues are kept separate (using addSrc/DstValue) // in new space where some Values are kept separate (using addSrc/DstValue)
// and some MLValues are merged (addSymbolValue). // and some Values are merged (addSymbolValue).
// Position lookups return the absolute position in the new space which // Position lookups return the absolute position in the new space which
// has the following format: // has the following format:
// //
@ -595,7 +595,7 @@ static bool getStmtIndexSet(const Statement *stmt,
// Note: access function non-IV dimension identifiers (that have 'dimension' // Note: access function non-IV dimension identifiers (that have 'dimension'
// positions in the access function position space) are assigned as symbols // positions in the access function position space) are assigned as symbols
// in the output position space. Convienience access functions which lookup // in the output position space. Convienience access functions which lookup
// an MLValue in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle // an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
// the common case of resolving positions for all access function operands. // the common case of resolving positions for all access function operands.
// //
// TODO(andydavis) Generalize this: could take a template parameter for // TODO(andydavis) Generalize this: could take a template parameter for
@ -603,25 +603,25 @@ static bool getStmtIndexSet(const Statement *stmt,
// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
class ValuePositionMap { class ValuePositionMap {
public: public:
void addSrcValue(const MLValue *value) { void addSrcValue(const Value *value) {
if (addValueAt(value, &srcDimPosMap, numSrcDims)) if (addValueAt(value, &srcDimPosMap, numSrcDims))
++numSrcDims; ++numSrcDims;
} }
void addDstValue(const MLValue *value) { void addDstValue(const Value *value) {
if (addValueAt(value, &dstDimPosMap, numDstDims)) if (addValueAt(value, &dstDimPosMap, numDstDims))
++numDstDims; ++numDstDims;
} }
void addSymbolValue(const MLValue *value) { void addSymbolValue(const Value *value) {
if (addValueAt(value, &symbolPosMap, numSymbols)) if (addValueAt(value, &symbolPosMap, numSymbols))
++numSymbols; ++numSymbols;
} }
unsigned getSrcDimOrSymPos(const MLValue *value) const { unsigned getSrcDimOrSymPos(const Value *value) const {
return getDimOrSymPos(value, srcDimPosMap, 0); return getDimOrSymPos(value, srcDimPosMap, 0);
} }
unsigned getDstDimOrSymPos(const MLValue *value) const { unsigned getDstDimOrSymPos(const Value *value) const {
return getDimOrSymPos(value, dstDimPosMap, numSrcDims); return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
} }
unsigned getSymPos(const MLValue *value) const { unsigned getSymPos(const Value *value) const {
auto it = symbolPosMap.find(value); auto it = symbolPosMap.find(value);
assert(it != symbolPosMap.end()); assert(it != symbolPosMap.end());
return numSrcDims + numDstDims + it->second; return numSrcDims + numDstDims + it->second;
@ -633,8 +633,7 @@ public:
unsigned getNumSymbols() const { return numSymbols; } unsigned getNumSymbols() const { return numSymbols; }
private: private:
bool addValueAt(const MLValue *value, bool addValueAt(const Value *value, DenseMap<const Value *, unsigned> *posMap,
DenseMap<const MLValue *, unsigned> *posMap,
unsigned position) { unsigned position) {
auto it = posMap->find(value); auto it = posMap->find(value);
if (it == posMap->end()) { if (it == posMap->end()) {
@ -643,8 +642,8 @@ private:
} }
return false; return false;
} }
unsigned getDimOrSymPos(const MLValue *value, unsigned getDimOrSymPos(const Value *value,
const DenseMap<const MLValue *, unsigned> &dimPosMap, const DenseMap<const Value *, unsigned> &dimPosMap,
unsigned dimPosOffset) const { unsigned dimPosOffset) const {
auto it = dimPosMap.find(value); auto it = dimPosMap.find(value);
if (it != dimPosMap.end()) { if (it != dimPosMap.end()) {
@ -658,25 +657,25 @@ private:
unsigned numSrcDims = 0; unsigned numSrcDims = 0;
unsigned numDstDims = 0; unsigned numDstDims = 0;
unsigned numSymbols = 0; unsigned numSymbols = 0;
DenseMap<const MLValue *, unsigned> srcDimPosMap; DenseMap<const Value *, unsigned> srcDimPosMap;
DenseMap<const MLValue *, unsigned> dstDimPosMap; DenseMap<const Value *, unsigned> dstDimPosMap;
DenseMap<const MLValue *, unsigned> symbolPosMap; DenseMap<const Value *, unsigned> symbolPosMap;
}; };
// Builds a map from MLValue to identifier position in a new merged identifier // Builds a map from Value to identifier position in a new merged identifier
// list, which is the result of merging dim/symbol lists from src/dst // list, which is the result of merging dim/symbol lists from src/dst
// iteration domains. The format of the new merged list is as follows: // iteration domains. The format of the new merged list is as follows:
// //
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers] // [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
// //
// This method populates 'valuePosMap' with mappings from operand MLValues in // This method populates 'valuePosMap' with mappings from operand Values in
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
// to the position of these values in the merged list. // to the position of these values in the merged list.
static void buildDimAndSymbolPositionMaps( static void buildDimAndSymbolPositionMaps(
const FlatAffineConstraints &srcDomain, const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) { const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
auto updateValuePosMap = [&](ArrayRef<MLValue *> values, bool isSrc) { auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) { for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i]; auto *value = values[i];
if (!isa<ForStmt>(values[i])) if (!isa<ForStmt>(values[i]))
@ -688,7 +687,7 @@ static void buildDimAndSymbolPositionMaps(
} }
}; };
SmallVector<MLValue *, 4> srcValues, destValues; SmallVector<Value *, 4> srcValues, destValues;
srcDomain.getIdValues(&srcValues); srcDomain.getIdValues(&srcValues);
dstDomain.getIdValues(&destValues); dstDomain.getIdValues(&destValues);
@ -702,17 +701,10 @@ static void buildDimAndSymbolPositionMaps(
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
} }
static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
const MLValue *value) {
auto it = posMap.find(value);
assert(it != posMap.end());
return it->second;
}
// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
// 'dependenceDomain'. // 'dependenceDomain'.
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
// srcDomain/dstDomain MLValue maps. // srcDomain/dstDomain Value maps.
static void addDomainConstraints(const FlatAffineConstraints &srcDomain, static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, const FlatAffineConstraints &dstDomain,
const ValuePositionMap &valuePosMap, const ValuePositionMap &valuePosMap,
@ -790,10 +782,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
unsigned numResults = srcMap.getNumResults(); unsigned numResults = srcMap.getNumResults();
unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands(); ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands(); ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
std::vector<SmallVector<int64_t, 8>> srcFlatExprs; std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
std::vector<SmallVector<int64_t, 8>> destFlatExprs; std::vector<SmallVector<int64_t, 8>> destFlatExprs;
@ -848,7 +840,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
} }
// Add equality constraints for any operands that are defined by constant ops. // Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const MLValue *> operands) { auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) { for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (isa<ForStmt>(operands[i])) if (isa<ForStmt>(operands[i]))
continue; continue;
@ -1095,7 +1087,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// upper/lower loop bounds for each ForStmt in the loop nest associated // upper/lower loop bounds for each ForStmt in the loop nest associated
// with each access. // with each access.
// *) Build dimension and symbol position maps for each access, which map // *) Build dimension and symbol position maps for each access, which map
// MLValues from access functions and iteration domains to their position // Values from access functions and iteration domains to their position
// in the merged constraint system built by this method. // in the merged constraint system built by this method.
// //
// This method builds a constraint system with the following column format: // This method builds a constraint system with the following column format:
@ -1202,7 +1194,7 @@ bool mlir::checkMemrefAccessDependence(
return false; return false;
} }
// Build dim and symbol position maps for each access from access operand // Build dim and symbol position maps for each access from access operand
// MLValue to position in merged contstraint system. // Value to position in merged contstraint system.
ValuePositionMap valuePosMap; ValuePositionMap valuePosMap;
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
dstAccessMap, &valuePosMap); dstAccessMap, &valuePosMap);

View File

@ -25,7 +25,6 @@
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Statements.h" #include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
@ -238,23 +237,23 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
AffineValueMap::AffineValueMap(const AffineApplyOp &op) AffineValueMap::AffineValueMap(const AffineApplyOp &op)
: map(op.getAffineMap()) { : map(op.getAffineMap()) {
for (auto *operand : op.getOperands()) for (auto *operand : op.getOperands())
operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand))); operands.push_back(const_cast<Value *>(operand));
for (unsigned i = 0, e = op.getNumResults(); i < e; i++) for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i)))); results.push_back(const_cast<Value *>(op.getResult(i)));
} }
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands) AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands)
: map(map) { : map(map) {
for (MLValue *operand : operands) { for (Value *operand : operands) {
this->operands.push_back(operand); this->operands.push_back(operand);
} }
} }
void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) { void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands) {
this->operands.clear(); this->operands.clear();
this->results.clear(); this->results.clear();
this->map.reset(map); this->map.reset(map);
for (MLValue *operand : operands) { for (Value *operand : operands) {
this->operands.push_back(operand); this->operands.push_back(operand);
} }
} }
@ -275,7 +274,7 @@ void AffineValueMap::forwardSubstituteSingle(const AffineApplyOp &inputOp,
// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
static bool findIndex(MLValue *valueToMatch, ArrayRef<MLValue *> valuesToSearch, static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
unsigned indexStart, unsigned *indexOfMatch) { unsigned indexStart, unsigned *indexOfMatch) {
unsigned size = valuesToSearch.size(); unsigned size = valuesToSearch.size();
for (unsigned i = indexStart; i < size; ++i) { for (unsigned i = indexStart; i < size; ++i) {
@ -324,8 +323,7 @@ void AffineValueMap::forwardSubstitute(
for (unsigned j = 0; j < inputNumResults; ++j) { for (unsigned j = 0; j < inputNumResults; ++j) {
if (!inputResultsToSubstitute[j]) if (!inputResultsToSubstitute[j])
continue; continue;
if (operands[i] == if (operands[i] == const_cast<Value *>(inputOp.getResult(j))) {
cast<MLValue>(const_cast<SSAValue *>(inputOp.getResult(j)))) {
currOperandToInputResult[i] = j; currOperandToInputResult[i] = j;
inputResultsUsed.insert(j); inputResultsUsed.insert(j);
} }
@ -365,7 +363,7 @@ void AffineValueMap::forwardSubstitute(
} }
// Build new output operands list and map update. // Build new output operands list and map update.
SmallVector<MLValue *, 4> outputOperands; SmallVector<Value *, 4> outputOperands;
unsigned outputOperandPosition = 0; unsigned outputOperandPosition = 0;
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults()); AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
@ -385,8 +383,7 @@ void AffineValueMap::forwardSubstitute(
if (inputPositionsUsed.count(i) == 0) if (inputPositionsUsed.count(i) == 0)
continue; continue;
// Check if input operand has a dup in current operand list. // Check if input operand has a dup in current operand list.
auto *inputOperand = auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
unsigned outputIndex; unsigned outputIndex;
if (findIndex(inputOperand, outputOperands, /*indexStart=*/0, if (findIndex(inputOperand, outputOperands, /*indexStart=*/0,
&outputIndex)) { &outputIndex)) {
@ -418,8 +415,7 @@ void AffineValueMap::forwardSubstitute(
continue; continue;
unsigned inputSymbolPosition = i - inputNumDims; unsigned inputSymbolPosition = i - inputNumDims;
// Check if input operand has a dup in current operand list. // Check if input operand has a dup in current operand list.
auto *inputOperand = auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
// Find output operand index of 'inputOperand' dup. // Find output operand index of 'inputOperand' dup.
unsigned outputIndex; unsigned outputIndex;
// Start at index 'outputNumDims' so that only symbol operands are searched. // Start at index 'outputNumDims' so that only symbol operands are searched.
@ -451,7 +447,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
/// This method uses the invariant that operands are always positionally aligned /// This method uses the invariant that operands are always positionally aligned
/// with the AffineDimExpr in the underlying AffineMap. /// with the AffineDimExpr in the underlying AffineMap.
bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const { bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
unsigned index; unsigned index;
findIndex(value, operands, /*indexStart=*/0, &index); findIndex(value, operands, /*indexStart=*/0, &index);
auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx); auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
@ -460,12 +456,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
return expr.isFunctionOfDim(index); return expr.isFunctionOfDim(index);
} }
SSAValue *AffineValueMap::getOperand(unsigned i) const { Value *AffineValueMap::getOperand(unsigned i) const {
return static_cast<SSAValue *>(operands[i]); return static_cast<Value *>(operands[i]);
} }
ArrayRef<MLValue *> AffineValueMap::getOperands() const { ArrayRef<Value *> AffineValueMap::getOperands() const {
return ArrayRef<MLValue *>(operands); return ArrayRef<Value *>(operands);
} }
AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
@ -546,7 +542,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
unsigned newNumReservedCols, unsigned newNumReservedCols,
unsigned newNumDims, unsigned newNumSymbols, unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals, unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) { ArrayRef<Value *> idArgs) {
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
"minimum 1 column"); "minimum 1 column");
numReservedCols = newNumReservedCols; numReservedCols = newNumReservedCols;
@ -570,7 +566,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals, unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) { ArrayRef<Value *> idArgs) {
reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
newNumSymbols, newNumLocals, idArgs); newNumSymbols, newNumLocals, idArgs);
} }
@ -597,17 +593,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) {
addId(IdKind::Local, pos); addId(IdKind::Local, pos);
} }
void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) { void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
addId(IdKind::Dimension, pos, id); addId(IdKind::Dimension, pos, id);
} }
void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) { void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
addId(IdKind::Symbol, pos, id); addId(IdKind::Symbol, pos, id);
} }
/// Adds a dimensional identifier. The added column is initialized to /// Adds a dimensional identifier. The added column is initialized to
/// zero. /// zero.
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) { void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
if (kind == IdKind::Dimension) { if (kind == IdKind::Dimension) {
assert(pos <= getNumDimIds()); assert(pos <= getNumDimIds());
} else if (kind == IdKind::Symbol) { } else if (kind == IdKind::Symbol) {
@ -755,7 +751,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
// Dims and symbols. // Dims and symbols.
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
unsigned loc; unsigned loc;
bool ret = findId(*cast<MLValue>(vMap->getOperand(i)), &loc); bool ret = findId(*vMap->getOperand(i), &loc);
assert(ret && "value map's id can't be found"); assert(ret && "value map's id can't be found");
(void)ret; (void)ret;
// We need to negate 'eq[r]' since the newly added dimension is going to // We need to negate 'eq[r]' since the newly added dimension is going to
@ -1231,7 +1227,7 @@ void FlatAffineConstraints::addUpperBound(ArrayRef<int64_t> expr,
} }
} }
bool FlatAffineConstraints::findId(const MLValue &id, unsigned *pos) const { bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
unsigned i = 0; unsigned i = 0;
for (const auto &mayBeId : ids) { for (const auto &mayBeId : ids) {
if (mayBeId.hasValue() && mayBeId.getValue() == &id) { if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
@ -1253,8 +1249,8 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) { bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
unsigned pos; unsigned pos;
// Pre-condition for this method. // Pre-condition for this method.
if (!findId(*cast<MLValue>(&forStmt), &pos)) { if (!findId(forStmt, &pos)) {
assert(0 && "MLValue not found"); assert(0 && "Value not found");
return false; return false;
} }
@ -1270,7 +1266,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
unsigned loc; unsigned loc;
if (!findId(*operand, &loc)) { if (!findId(*operand, &loc)) {
if (operand->isValidSymbol()) { if (operand->isValidSymbol()) {
addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand)); addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1; loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant. // Check if the symbol is a constant.
if (auto *opStmt = operand->getDefiningStmt()) { if (auto *opStmt = operand->getDefiningStmt()) {
@ -1279,7 +1275,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
} }
} }
} else { } else {
addDimId(getNumDimIds(), const_cast<MLValue *>(operand)); addDimId(getNumDimIds(), const_cast<Value *>(operand));
loc = getNumDimIds() - 1; loc = getNumDimIds() - 1;
} }
} }
@ -1352,7 +1348,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
/// Sets the specified identifer to a constant value; asserts if the id is not /// Sets the specified identifer to a constant value; asserts if the id is not
/// found. /// found.
void FlatAffineConstraints::setIdToConstant(const MLValue &id, int64_t val) { void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
unsigned pos; unsigned pos;
if (!findId(id, &pos)) if (!findId(id, &pos))
// This is a pre-condition for this method. // This is a pre-condition for this method.
@ -1572,7 +1568,7 @@ void FlatAffineConstraints::print(raw_ostream &os) const {
if (ids[i] == None) if (ids[i] == None)
os << "None "; os << "None ";
else else
os << "MLValue "; os << "Value ";
} }
os << " const)\n"; os << " const)\n";
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
@ -1779,7 +1775,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
unsigned newNumDims = dimsSymbols.first; unsigned newNumDims = dimsSymbols.first;
unsigned newNumSymbols = dimsSymbols.second; unsigned newNumSymbols = dimsSymbols.second;
SmallVector<Optional<MLValue *>, 8> newIds; SmallVector<Optional<Value *>, 8> newIds;
newIds.reserve(numIds - 1); newIds.reserve(numIds - 1);
newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos); newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos);
newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end()); newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end());
@ -1942,7 +1938,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
normalizeConstraintsByGCD(); normalizeConstraintsByGCD();
} }
void FlatAffineConstraints::projectOut(MLValue *id) { void FlatAffineConstraints::projectOut(Value *id) {
unsigned pos; unsigned pos;
bool ret = findId(*id, &pos); bool ret = findId(*id, &pos);
assert(ret); assert(ret);

View File

@ -70,7 +70,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
} }
/// Return true if value A properly dominates instruction B. /// Return true if value A properly dominates instruction B.
bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) { bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
if (auto *aInst = a->getDefiningInst()) if (auto *aInst = a->getDefiningInst())
return properlyDominates(aInst, b); return properlyDominates(aInst, b);

View File

@ -124,14 +124,14 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
return tripCountExpr.getLargestKnownDivisor(); return tripCountExpr.getLargestKnownDivisor();
} }
bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) { bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isa<ForStmt>(iv) && "iv must be a ForStmt"); assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
assert(index.getType().isa<IndexType>() && "index must be of IndexType"); assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationStmt *, 4> affineApplyOps; SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<MLValue *>(&index)}, affineApplyOps); getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
if (affineApplyOps.empty()) { if (affineApplyOps.empty()) {
// Pointer equality test because of MLValue pointer semantics. // Pointer equality test because of Value pointer semantics.
return &index != &iv; return &index != &iv;
} }
@ -155,13 +155,13 @@ bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
} }
assert(idx < std::numeric_limits<unsigned>::max()); assert(idx < std::numeric_limits<unsigned>::max());
return !AffineValueMap(*composeOp) return !AffineValueMap(*composeOp)
.isFunctionOf(idx, &const_cast<MLValue &>(iv)); .isFunctionOf(idx, &const_cast<Value &>(iv));
} }
llvm::DenseSet<const MLValue *> llvm::DenseSet<const Value *>
mlir::getInvariantAccesses(const MLValue &iv, mlir::getInvariantAccesses(const Value &iv,
llvm::ArrayRef<const MLValue *> indices) { llvm::ArrayRef<const Value *> indices) {
llvm::DenseSet<const MLValue *> res; llvm::DenseSet<const Value *> res;
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
auto *val = indices[idx]; auto *val = indices[idx];
if (isAccessInvariant(iv, *val)) { if (isAccessInvariant(iv, *val)) {
@ -191,7 +191,7 @@ mlir::getInvariantAccesses(const MLValue &iv,
/// ///
// TODO(ntv): check strides. // TODO(ntv): check strides.
template <typename LoadOrStoreOp> template <typename LoadOrStoreOp>
static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp, static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp,
unsigned fastestVaryingDim) { unsigned fastestVaryingDim) {
static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value || static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value ||
std::is_same<LoadOrStoreOp, StoreOp>::value, std::is_same<LoadOrStoreOp, StoreOp>::value,
@ -220,7 +220,7 @@ static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
if (fastestVaryingDim == (numIndices - 1) - d++) { if (fastestVaryingDim == (numIndices - 1) - d++) {
continue; continue;
} }
if (!isAccessInvariant(iv, cast<MLValue>(*index))) { if (!isAccessInvariant(iv, *index)) {
return false; return false;
} }
} }
@ -316,7 +316,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
// outside). // outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) { if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) { for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const MLValue *result = opStmt->getResult(i); const Value *result = opStmt->getResult(i);
for (const StmtOperand &use : result->getUses()) { for (const StmtOperand &use : result->getUses()) {
// If an ancestor statement doesn't lie in the block of forStmt, there // If an ancestor statement doesn't lie in the block of forStmt, there
// is no shift to check. // is no shift to check.

View File

@ -70,7 +70,7 @@ static void addMemRefAccessIndices(
MemRefType memrefType, MemRefAccess *access) { MemRefType memrefType, MemRefAccess *access) {
access->indices.reserve(memrefType.getRank()); access->indices.reserve(memrefType.getRank());
for (auto *index : opIndices) { for (auto *index : opIndices) {
access->indices.push_back(cast<MLValue>(const_cast<SSAValue *>(index))); access->indices.push_back(const_cast<mlir::Value *>(index));
} }
} }
@ -79,13 +79,13 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) { MemRefAccess *access) {
access->opStmt = loadOrStoreOpStmt; access->opStmt = loadOrStoreOpStmt;
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = cast<MLValue>(loadOp->getMemRef()); access->memref = loadOp->getMemRef();
addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(), addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
access); access);
} else { } else {
assert(loadOrStoreOpStmt->isa<StoreOp>()); assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>(); auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->memref = cast<MLValue>(storeOp->getMemRef()); access->memref = storeOp->getMemRef();
addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(), addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
access); access);
} }

View File

@ -150,21 +150,21 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
OpPointer<LoadOp> loadOp; OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp; OpPointer<StoreOp> storeOp;
unsigned rank; unsigned rank;
SmallVector<MLValue *, 4> indices; SmallVector<Value *, 4> indices;
if ((loadOp = opStmt->dyn_cast<LoadOp>())) { if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank(); rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) { for (auto *index : loadOp->getIndices()) {
indices.push_back(cast<MLValue>(index)); indices.push_back(index);
} }
region->memref = cast<MLValue>(loadOp->getMemRef()); region->memref = loadOp->getMemRef();
region->setWrite(false); region->setWrite(false);
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) { } else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank(); rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) { for (auto *index : storeOp->getIndices()) {
indices.push_back(cast<MLValue>(index)); indices.push_back(index);
} }
region->memref = cast<MLValue>(storeOp->getMemRef()); region->memref = storeOp->getMemRef();
region->setWrite(true); region->setWrite(true);
} else { } else {
return false; return false;
@ -201,7 +201,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
return false; return false;
} else { } else {
// Has to be a valid symbol. // Has to be a valid symbol.
auto *symbol = cast<MLValue>(accessValueMap.getOperand(i)); auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol()); assert(symbol->isValidSymbol());
// Check if the symbol is a constant. // Check if the symbol is a constant.
if (auto *opStmt = symbol->getDefiningStmt()) { if (auto *opStmt = symbol->getDefiningStmt()) {
@ -405,7 +405,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// Solve for src IVs in terms of dst IVs, symbols and constants. // Solve for src IVs in terms of dst IVs, symbols and constants.
SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null()); SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
std::vector<SmallVector<MLValue *, 2>> srcIvOperands(srcLoopNestSize); std::vector<SmallVector<Value *, 2>> srcIvOperands(srcLoopNestSize);
for (unsigned i = 0; i < srcLoopNestSize; ++i) { for (unsigned i = 0; i < srcLoopNestSize; ++i) {
// Skip IVs which are greater than requested loop depth. // Skip IVs which are greater than requested loop depth.
if (i >= srcLoopDepth) { if (i >= srcLoopDepth) {
@ -442,7 +442,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
srcIvOperands[i].push_back(dstLoopNest[dimId - 1]); srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
} }
// TODO(andydavis) Add symbols from the access function. Ideally, we // TODO(andydavis) Add symbols from the access function. Ideally, we
// should be able to query the constaint system for the MLValue associated // should be able to query the constaint system for the Value associated
// with a symbol identifiers in 'nonZeroSymbolIds'. // with a symbol identifiers in 'nonZeroSymbolIds'.
} }
@ -454,7 +454,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// of the loop at 'dstLoopDepth' in 'dstLoopNest'. // of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1]; auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin()); MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const MLValue *, MLValue *> operandMap; DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap)); auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'. // Lookup stmt in cloned 'sliceLoopNest' at 'positions'.

View File

@ -108,7 +108,7 @@ static AffineMap makePermutationMap(
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) { const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster; using functional::makePtrDynCaster;
using functional::map; using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<SSAValue, MLValue>(), indices); auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(), SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
getAffineConstantExpr(0, context)); getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) { for (auto kvp : enclosingLoopToVectorDim) {

View File

@ -277,7 +277,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
/// Walk all of the code in this MLFunc and verify that the operands of any /// Walk all of the code in this MLFunc and verify that the operands of any
/// operations are properly dominated by their definitions. /// operations are properly dominated by their definitions.
bool MLFuncVerifier::verifyDominance() { bool MLFuncVerifier::verifyDominance() {
using HashTable = llvm::ScopedHashTable<const SSAValue *, bool>; using HashTable = llvm::ScopedHashTable<const Value *, bool>;
HashTable liveValues; HashTable liveValues;
HashTable::ScopeTy topScope(liveValues); HashTable::ScopeTy topScope(liveValues);

View File

@ -38,7 +38,6 @@
#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); } void Identifier::print(raw_ostream &os) const { os << str(); }
@ -967,7 +966,7 @@ public:
void printFunctionAttributes(const Function *func) { void printFunctionAttributes(const Function *func) {
return ModulePrinter::printFunctionAttributes(func); return ModulePrinter::printFunctionAttributes(func);
} }
void printOperand(const SSAValue *value) { printValueID(value); } void printOperand(const Value *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) { ArrayRef<const char *> elidedAttrs = {}) {
@ -977,7 +976,7 @@ public:
enum { nameSentinel = ~0U }; enum { nameSentinel = ~0U };
protected: protected:
void numberValueID(const SSAValue *value) { void numberValueID(const Value *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times"); assert(!valueIDs.count(value) && "Value numbered multiple times");
SmallString<32> specialNameBuffer; SmallString<32> specialNameBuffer;
@ -1004,7 +1003,7 @@ protected:
if (specialNameBuffer.empty()) { if (specialNameBuffer.empty()) {
switch (value->getKind()) { switch (value->getKind()) {
case SSAValueKind::BlockArgument: case Value::Kind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name. // If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner()) if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->getFunction()) if (auto *fn = block->getFunction())
@ -1015,12 +1014,12 @@ protected:
// Otherwise number it normally. // Otherwise number it normally.
valueIDs[value] = nextValueID++; valueIDs[value] = nextValueID++;
return; return;
case SSAValueKind::StmtResult: case Value::Kind::StmtResult:
// This is an uninteresting result, give it a boring number and be // This is an uninteresting result, give it a boring number and be
// done with it. // done with it.
valueIDs[value] = nextValueID++; valueIDs[value] = nextValueID++;
return; return;
case SSAValueKind::ForStmt: case Value::Kind::ForStmt:
specialName << 'i' << nextLoopID++; specialName << 'i' << nextLoopID++;
break; break;
} }
@ -1052,7 +1051,7 @@ protected:
} }
} }
void printValueID(const SSAValue *value, bool printResultNo = true) const { void printValueID(const Value *value, bool printResultNo = true) const {
int resultNo = -1; int resultNo = -1;
auto lookupValue = value; auto lookupValue = value;
@ -1093,8 +1092,8 @@ protected:
private: private:
/// This is the value ID for each SSA value in the current function. If this /// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames. /// returns ~0, then the valueID has an entry in valueNames.
DenseMap<const SSAValue *, unsigned> valueIDs; DenseMap<const Value *, unsigned> valueIDs;
DenseMap<const SSAValue *, StringRef> valueNames; DenseMap<const Value *, StringRef> valueNames;
/// This keeps track of all of the non-numeric names that are in flight, /// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates. /// allowing us to check for duplicates.
@ -1135,7 +1134,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
os << "\"("; os << "\"(";
interleaveComma(op->getOperands(), interleaveComma(op->getOperands(),
[&](const SSAValue *value) { printValueID(value); }); [&](const Value *value) { printValueID(value); });
os << ')'; os << ')';
auto attrs = op->getAttrs(); auto attrs = op->getAttrs();
@ -1144,16 +1143,15 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
// Print the type signature of the operation. // Print the type signature of the operation.
os << " : ("; os << " : (";
interleaveComma(op->getOperands(), interleaveComma(op->getOperands(),
[&](const SSAValue *value) { printType(value->getType()); }); [&](const Value *value) { printType(value->getType()); });
os << ") -> "; os << ") -> ";
if (op->getNumResults() == 1) { if (op->getNumResults() == 1) {
printType(op->getResult(0)->getType()); printType(op->getResult(0)->getType());
} else { } else {
os << '('; os << '(';
interleaveComma(op->getResults(), [&](const SSAValue *result) { interleaveComma(op->getResults(),
printType(result->getType()); [&](const Value *result) { printType(result->getType()); });
});
os << ')'; os << ')';
} }
} }
@ -1297,11 +1295,10 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) {
os << '('; os << '(';
interleaveComma(range, interleaveComma(range,
[this](const SSAValue *operand) { printValueID(operand); }); [this](const Value *operand) { printValueID(operand); });
os << " : "; os << " : ";
interleaveComma(range, [this](const SSAValue *operand) { interleaveComma(
printType(operand->getType()); range, [this](const Value *operand) { printType(operand->getType()); });
});
os << ')'; os << ')';
} }
@ -1576,20 +1573,20 @@ void IntegerSet::print(raw_ostream &os) const {
ModulePrinter(os, state).printIntegerSet(*this); ModulePrinter(os, state).printIntegerSet(*this);
} }
void SSAValue::print(raw_ostream &os) const { void Value::print(raw_ostream &os) const {
switch (getKind()) { switch (getKind()) {
case SSAValueKind::BlockArgument: case Value::Kind::BlockArgument:
// TODO: Improve this. // TODO: Improve this.
os << "<block argument>\n"; os << "<block argument>\n";
return; return;
case SSAValueKind::StmtResult: case Value::Kind::StmtResult:
return getDefiningStmt()->print(os); return getDefiningStmt()->print(os);
case SSAValueKind::ForStmt: case Value::Kind::ForStmt:
return cast<ForStmt>(this)->print(os); return cast<ForStmt>(this)->print(os);
} }
} }
void SSAValue::dump() const { print(llvm::errs()); } void Value::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const { void Instruction::print(raw_ostream &os) const {
auto *function = getFunction(); auto *function = getFunction();

View File

@ -281,7 +281,7 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
// If we are supposed to insert before a specific block, do so, otherwise add // If we are supposed to insert before a specific block, do so, otherwise add
// the block to the end of the function. // the block to the end of the function.
if (insertBefore) if (insertBefore)
function->getBlocks().insert(CFGFunction::iterator(insertBefore), b); function->getBlocks().insert(Function::iterator(insertBefore), b);
else else
function->push_back(b); function->push_back(b);
@ -291,16 +291,9 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
/// Create an operation given the fields represented as an OperationState. /// Create an operation given the fields represented as an OperationState.
OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) { OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
SmallVector<CFGValue *, 8> operands; auto *op = OperationInst::create(state.location, state.name, state.operands,
operands.reserve(state.operands.size()); state.types, state.attributes,
// Allow null operands as they act as sentinal barriers between successor state.successors, context);
// operand lists.
for (auto elt : state.operands)
operands.push_back(cast_or_null<CFGValue>(elt));
auto *op =
OperationInst::create(state.location, state.name, operands, state.types,
state.attributes, state.successors, context);
block->getStatements().insert(insertPoint, op); block->getStatements().insert(insertPoint, op);
return op; return op;
} }
@ -311,23 +304,17 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
/// Create an operation given the fields represented as an OperationState. /// Create an operation given the fields represented as an OperationState.
OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) { OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
SmallVector<MLValue *, 8> operands; auto *op = OperationStmt::create(state.location, state.name, state.operands,
operands.reserve(state.operands.size()); state.types, state.attributes,
for (auto elt : state.operands) state.successors, context);
operands.push_back(cast<MLValue>(elt));
auto *op =
OperationStmt::create(state.location, state.name, operands, state.types,
state.attributes, state.successors, context);
block->getStatements().insert(insertPoint, op); block->getStatements().insert(insertPoint, op);
return op; return op;
} }
ForStmt *MLFuncBuilder::createFor(Location location, ForStmt *MLFuncBuilder::createFor(Location location,
ArrayRef<MLValue *> lbOperands, ArrayRef<Value *> lbOperands, AffineMap lbMap,
AffineMap lbMap, ArrayRef<Value *> ubOperands, AffineMap ubMap,
ArrayRef<MLValue *> ubOperands, int64_t step) {
AffineMap ubMap, int64_t step) {
auto *stmt = auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step); ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getStatements().insert(insertPoint, stmt); block->getStatements().insert(insertPoint, stmt);
@ -341,7 +328,7 @@ ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
return createFor(location, {}, lbMap, {}, ubMap, step); return createFor(location, {}, lbMap, {}, ubMap, step);
} }
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<MLValue *> operands, IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) { IntegerSet set) {
auto *stmt = IfStmt::create(location, operands, set); auto *stmt = IfStmt::create(location, operands, set);
block->getStatements().insert(insertPoint, stmt); block->getStatements().insert(insertPoint, stmt);

View File

@ -20,8 +20,8 @@
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@ -54,7 +54,7 @@ void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
// dimension operands parsed. // dimension operands parsed.
// Returns 'false' on success and 'true' on error. // Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser, bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<SSAValue *, 4> &operands, SmallVector<Value *, 4> &operands,
unsigned &numDims) { unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos; SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
@ -76,7 +76,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AffineApplyOp::build(Builder *builder, OperationState *result, void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap map, ArrayRef<SSAValue *> operands) { AffineMap map, ArrayRef<Value *> operands) {
result->addOperands(operands); result->addOperands(operands);
result->types.append(map.getNumResults(), builder->getIndexType()); result->types.append(map.getNumResults(), builder->getIndexType());
result->addAttribute("map", builder->getAffineMapAttr(map)); result->addAttribute("map", builder->getAffineMapAttr(map));
@ -133,23 +133,21 @@ bool AffineApplyOp::verify() const {
} }
// The result of the affine apply operation can be used as a dimension id if it // The result of the affine apply operation can be used as a dimension id if it
// is a CFG value or if it is an MLValue, and all the operands are valid // is a CFG value or if it is an Value, and all the operands are valid
// dimension ids. // dimension ids.
bool AffineApplyOp::isValidDim() const { bool AffineApplyOp::isValidDim() const {
for (auto *op : getOperands()) { for (auto *op : getOperands()) {
if (auto *v = dyn_cast<MLValue>(op)) if (!op->isValidDim())
if (!v->isValidDim())
return false; return false;
} }
return true; return true;
} }
// The result of the affine apply operation can be used as a symbol if it is // The result of the affine apply operation can be used as a symbol if it is
// a CFG value or if it is an MLValue, and all the operands are symbols. // a CFG value or if it is an Value, and all the operands are symbols.
bool AffineApplyOp::isValidSymbol() const { bool AffineApplyOp::isValidSymbol() const {
for (auto *op : getOperands()) { for (auto *op : getOperands()) {
if (auto *v = dyn_cast<MLValue>(op)) if (!op->isValidSymbol())
if (!v->isValidSymbol())
return false; return false;
} }
return true; return true;
@ -170,13 +168,13 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest, void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest,
ArrayRef<SSAValue *> operands) { ArrayRef<Value *> operands) {
result->addSuccessor(dest, operands); result->addSuccessor(dest, operands);
} }
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
BasicBlock *dest; BasicBlock *dest;
SmallVector<SSAValue *, 4> destOperands; SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands)) if (parser->parseSuccessorAndUseList(dest, destOperands))
return true; return true;
result->addSuccessor(dest, destOperands); result->addSuccessor(dest, destOperands);
@ -212,17 +210,16 @@ void BranchOp::eraseOperand(unsigned index) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void CondBranchOp::build(Builder *builder, OperationState *result, void CondBranchOp::build(Builder *builder, OperationState *result,
SSAValue *condition, BasicBlock *trueDest, Value *condition, BasicBlock *trueDest,
ArrayRef<SSAValue *> trueOperands, ArrayRef<Value *> trueOperands, BasicBlock *falseDest,
BasicBlock *falseDest, ArrayRef<Value *> falseOperands) {
ArrayRef<SSAValue *> falseOperands) {
result->addOperands(condition); result->addOperands(condition);
result->addSuccessor(trueDest, trueOperands); result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands); result->addSuccessor(falseDest, falseOperands);
} }
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<SSAValue *, 4> destOperands; SmallVector<Value *, 4> destOperands;
BasicBlock *dest; BasicBlock *dest;
OpAsmParser::OperandType condInfo; OpAsmParser::OperandType condInfo;
@ -446,7 +443,7 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void ReturnOp::build(Builder *builder, OperationState *result, void ReturnOp::build(Builder *builder, OperationState *result,
ArrayRef<SSAValue *> results) { ArrayRef<Value *> results) {
result->addOperands(results); result->addOperands(results);
} }
@ -465,8 +462,9 @@ void ReturnOp::print(OpAsmPrinter *p) const {
*p << ' '; *p << ' ';
p->printOperands(operand_begin(), operand_end()); p->printOperands(operand_begin(), operand_end());
*p << " : "; *p << " : ";
interleave(operand_begin(), operand_end(), interleave(
[&](const SSAValue *e) { p->printType(e->getType()); }, operand_begin(), operand_end(),
[&](const Value *e) { p->printType(e->getType()); },
[&]() { *p << ", "; }); [&]() { *p << ", "; });
} }
} }

View File

@ -23,7 +23,6 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h" #include "mlir/IR/Statements.h"
using namespace mlir; using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is /// Form the OperationName for an op with the specified string. This either is
@ -96,13 +95,13 @@ unsigned Operation::getNumOperands() const {
return llvm::cast<OperationStmt>(this)->getNumOperands(); return llvm::cast<OperationStmt>(this)->getNumOperands();
} }
SSAValue *Operation::getOperand(unsigned idx) { Value *Operation::getOperand(unsigned idx) {
return llvm::cast<OperationStmt>(this)->getOperand(idx); return llvm::cast<OperationStmt>(this)->getOperand(idx);
} }
void Operation::setOperand(unsigned idx, SSAValue *value) { void Operation::setOperand(unsigned idx, Value *value) {
auto *stmt = llvm::cast<OperationStmt>(this); auto *stmt = llvm::cast<OperationStmt>(this);
stmt->setOperand(idx, llvm::cast<MLValue>(value)); stmt->setOperand(idx, value);
} }
/// Return the number of results this operation has. /// Return the number of results this operation has.
@ -111,7 +110,7 @@ unsigned Operation::getNumResults() const {
} }
/// Return the indicated result. /// Return the indicated result.
SSAValue *Operation::getResult(unsigned idx) { Value *Operation::getResult(unsigned idx) {
return llvm::cast<OperationStmt>(this)->getResult(idx); return llvm::cast<OperationStmt>(this)->getResult(idx);
} }
@ -585,8 +584,8 @@ bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
// These functions are out-of-line implementations of the methods in BinaryOp, // These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated. // which avoids them being template instantiated/duplicated.
void impl::buildBinaryOp(Builder *builder, OperationState *result, void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
SSAValue *lhs, SSAValue *rhs) { Value *rhs) {
assert(lhs->getType() == rhs->getType()); assert(lhs->getType() == rhs->getType());
result->addOperands({lhs, rhs}); result->addOperands({lhs, rhs});
result->types.push_back(lhs->getType()); result->types.push_back(lhs->getType());
@ -613,8 +612,8 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
// CastOp implementation // CastOp implementation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result, void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
SSAValue *source, Type destType) { Type destType) {
result->addOperands(source); result->addOperands(source);
result->addTypes(destType); result->addTypes(destType);
} }

View File

@ -16,8 +16,8 @@
// ============================================================================= // =============================================================================
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Statements.h" #include "mlir/IR/Statements.h"
#include "mlir/IR/Value.h"
using namespace mlir; using namespace mlir;
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
@ -77,8 +77,8 @@ PatternRewriter::~PatternRewriter() {
/// clients can specify a list of other nodes that this replacement may make /// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those ops are dead, this will /// (perhaps transitively) dead. If any of those ops are dead, this will
/// remove them as well. /// remove them as well.
void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues, void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<SSAValue *> valuesToRemoveIfDead) { ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root. // Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op); notifyRootReplaced(op);
@ -97,14 +97,13 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
/// op and newOp are known to have the same number of results, replace the /// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp /// uses of op with uses of newOp
void PatternRewriter::replaceOpWithResultsOfAnotherOp( void PatternRewriter::replaceOpWithResultsOfAnotherOp(
Operation *op, Operation *newOp, Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
ArrayRef<SSAValue *> valuesToRemoveIfDead) {
assert(op->getNumResults() == newOp->getNumResults() && assert(op->getNumResults() == newOp->getNumResults() &&
"replacement op doesn't match results of original op"); "replacement op doesn't match results of original op");
if (op->getNumResults() == 1) if (op->getNumResults() == 1)
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead); return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
SmallVector<SSAValue *, 8> newResults(newOp->getResults().begin(), SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
newOp->getResults().end()); newOp->getResults().end());
return replaceOp(op, newResults, valuesToRemoveIfDead); return replaceOp(op, newResults, valuesToRemoveIfDead);
} }
@ -118,7 +117,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(
/// should remove if they are dead at this point. /// should remove if they are dead at this point.
/// ///
void PatternRewriter::updatedRootInPlace( void PatternRewriter::updatedRootInPlace(
Operation *op, ArrayRef<SSAValue *> valuesToRemoveIfDead) { Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root. // Notify the rewriter subclass that we're about to replace this root.
notifyRootUpdated(op); notifyRootUpdated(op);

View File

@ -1,4 +1,4 @@
//===- SSAValue.cpp - MLIR SSAValue Classes ------------===// //===- SSAValue.cpp - MLIR ValueClasses ------------===//
// //
// Copyright 2019 The MLIR Authors. // Copyright 2019 The MLIR Authors.
// //
@ -15,15 +15,15 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h" #include "mlir/IR/Statements.h"
#include "mlir/IR/Value.h"
using namespace mlir; using namespace mlir;
/// If this value is the result of an Instruction, return the instruction /// If this value is the result of an Instruction, return the instruction
/// that defines it. /// that defines it.
OperationInst *SSAValue::getDefiningInst() { OperationInst *Value::getDefiningInst() {
if (auto *result = dyn_cast<InstResult>(this)) if (auto *result = dyn_cast<InstResult>(this))
return result->getOwner(); return result->getOwner();
return nullptr; return nullptr;
@ -31,13 +31,13 @@ OperationInst *SSAValue::getDefiningInst() {
/// If this value is the result of an OperationStmt, return the statement /// If this value is the result of an OperationStmt, return the statement
/// that defines it. /// that defines it.
OperationStmt *SSAValue::getDefiningStmt() { OperationStmt *Value::getDefiningStmt() {
if (auto *result = dyn_cast<StmtResult>(this)) if (auto *result = dyn_cast<StmtResult>(this))
return result->getOwner(); return result->getOwner();
return nullptr; return nullptr;
} }
Operation *SSAValue::getDefiningOperation() { Operation *Value::getDefiningOperation() {
if (auto *inst = getDefiningInst()) if (auto *inst = getDefiningInst())
return inst; return inst;
if (auto *stmt = getDefiningStmt()) if (auto *stmt = getDefiningStmt())
@ -45,14 +45,14 @@ Operation *SSAValue::getDefiningOperation() {
return nullptr; return nullptr;
} }
/// Return the function that this SSAValue is defined in. /// Return the function that this Valueis defined in.
Function *SSAValue::getFunction() { Function *Value::getFunction() {
switch (getKind()) { switch (getKind()) {
case SSAValueKind::BlockArgument: case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction(); return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult: case Value::Kind::StmtResult:
return getDefiningStmt()->getFunction(); return getDefiningStmt()->getFunction();
case SSAValueKind::ForStmt: case Value::Kind::ForStmt:
return cast<ForStmt>(this)->getFunction(); return cast<ForStmt>(this)->getFunction();
} }
} }
@ -89,15 +89,6 @@ MLIRContext *IROperandOwner::getContext() const {
} }
} }
//===----------------------------------------------------------------------===//
// MLValue implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this MLValue is defined in.
MLFunction *MLValue::getFunction() {
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BlockArgument implementation. // BlockArgument implementation.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -85,18 +85,16 @@ MLFunction *Statement::getFunction() const {
return block ? block->getFunction() : nullptr; return block ? block->getFunction() : nullptr;
} }
MLValue *Statement::getOperand(unsigned idx) { Value *Statement::getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const Value *Statement::getOperand(unsigned idx) const {
return getStmtOperand(idx).get(); return getStmtOperand(idx).get();
} }
const MLValue *Statement::getOperand(unsigned idx) const { // Value can be used as a dimension id if it is valid as a symbol, or
return getStmtOperand(idx).get();
}
// MLValue can be used as a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of affine apply operation // it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments. // with dimension id arguments.
bool MLValue::isValidDim() const { bool Value::isValidDim() const {
if (auto *stmt = getDefiningStmt()) { if (auto *stmt = getDefiningStmt()) {
// Top level statement or constant operation is ok. // Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@ -111,10 +109,10 @@ bool MLValue::isValidDim() const {
return true; return true;
} }
// MLValue can be used as a symbol if it is a constant, or it is defined at // Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol // the top level, or it is a result of affine apply operation with symbol
// arguments. // arguments.
bool MLValue::isValidSymbol() const { bool Value::isValidSymbol() const {
if (auto *stmt = getDefiningStmt()) { if (auto *stmt = getDefiningStmt()) {
// Top level statement or constant operation is ok. // Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>()) if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@ -129,7 +127,7 @@ bool MLValue::isValidSymbol() const {
return isa<BlockArgument>(this); return isa<BlockArgument>(this);
} }
void Statement::setOperand(unsigned idx, MLValue *value) { void Statement::setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value); getStmtOperand(idx).set(value);
} }
@ -271,7 +269,7 @@ void Statement::dropAllReferences() {
/// Create a new OperationStmt with the specific fields. /// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location location, OperationName name, OperationStmt *OperationStmt::create(Location location, OperationName name,
ArrayRef<MLValue *> operands, ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes, ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes, ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors, ArrayRef<StmtBlock *> successors,
@ -420,8 +418,8 @@ void OperationInst::eraseOperand(unsigned index) {
// ForStmt // ForStmt
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands, ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands, AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) { AffineMap ubMap, int64_t step) {
assert(lbOperands.size() == lbMap.getNumInputs() && assert(lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map"); "lower bound operand count does not match the affine map");
@ -444,8 +442,8 @@ ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap, ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step) AffineMap ubMap, int64_t step)
: Statement(Kind::For, location), : Statement(Statement::Kind::For, location),
MLValue(MLValueKind::ForStmt, Value(Value::Kind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())), Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) { body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
@ -462,11 +460,11 @@ const AffineBound ForStmt::getUpperBound() const {
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap); return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
} }
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) { void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs()); assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result"); assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands()); SmallVector<Value *, 4> ubOperands(getUpperBoundOperands());
operands.clear(); operands.clear();
operands.reserve(lbOperands.size() + ubMap.getNumInputs()); operands.reserve(lbOperands.size() + ubMap.getNumInputs());
@ -479,11 +477,11 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
this->lbMap = map; this->lbMap = map;
} }
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) { void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs()); assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result"); assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands()); SmallVector<Value *, 4> lbOperands(getLowerBoundOperands());
operands.clear(); operands.clear();
operands.reserve(lbOperands.size() + ubOperands.size()); operands.reserve(lbOperands.size() + ubOperands.size());
@ -553,7 +551,7 @@ bool ForStmt::matchingBoundOperandList() const {
unsigned numOperands = lbMap.getNumInputs(); unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare MLValue *'s. // Compare Value *'s.
if (getOperand(i) != getOperand(numOperands + i)) if (getOperand(i) != getOperand(numOperands + i))
return false; return false;
} }
@ -581,7 +579,7 @@ IfStmt::~IfStmt() {
// allocated through MLIRContext's bump pointer allocator. // allocated through MLIRContext's bump pointer allocator.
} }
IfStmt *IfStmt::create(Location location, ArrayRef<MLValue *> operands, IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
IntegerSet set) { IntegerSet set) {
unsigned numOperands = operands.size(); unsigned numOperands = operands.size();
assert(numOperands == set.getNumOperands() && assert(numOperands == set.getNumOperands() &&
@ -617,16 +615,16 @@ MLIRContext *IfStmt::getContext() const {
/// them alone if no entry is present). Replaces references to cloned /// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds /// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map. /// those mappings to the map.
Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap, Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
MLIRContext *context) const { MLIRContext *context) const {
// If the specified value is in operandMap, return the remapped value. // If the specified value is in operandMap, return the remapped value.
// Otherwise return the value itself. // Otherwise return the value itself.
auto remapOperand = [&](const MLValue *value) -> MLValue * { auto remapOperand = [&](const Value *value) -> Value * {
auto it = operandMap.find(value); auto it = operandMap.find(value);
return it != operandMap.end() ? it->second : const_cast<MLValue *>(value); return it != operandMap.end() ? it->second : const_cast<Value *>(value);
}; };
SmallVector<MLValue *, 8> operands; SmallVector<Value *, 8> operands;
SmallVector<StmtBlock *, 2> successors; SmallVector<StmtBlock *, 2> successors;
if (auto *opStmt = dyn_cast<OperationStmt>(this)) { if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
operands.reserve(getNumOperands() + opStmt->getNumSuccessors()); operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
@ -683,10 +681,9 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
auto ubMap = forStmt->getUpperBoundMap(); auto ubMap = forStmt->getUpperBoundMap();
auto *newFor = ForStmt::create( auto *newFor = ForStmt::create(
getLoc(), getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap, lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap, ubMap, forStmt->getStep());
forStmt->getStep());
// Remember the induction variable mapping. // Remember the induction variable mapping.
operandMap[forStmt] = newFor; operandMap[forStmt] = newFor;
@ -716,6 +713,6 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
} }
Statement *Statement::clone(MLIRContext *context) const { Statement *Statement::clone(MLIRContext *context) const {
DenseMap<const MLValue *, MLValue *> operandMap; DenseMap<const Value *, Value *> operandMap;
return clone(operandMap, context); return clone(operandMap, context);
} }

View File

@ -42,7 +42,6 @@
#include "llvm/Support/SMLoc.h" #include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"
#include <algorithm> #include <algorithm>
using namespace mlir; using namespace mlir;
using llvm::MemoryBuffer; using llvm::MemoryBuffer;
using llvm::SMLoc; using llvm::SMLoc;
@ -1890,10 +1889,10 @@ public:
/// Given a reference to an SSA value and its type, return a reference. This /// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure. /// returns null on failure.
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type); Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
/// Register a definition of a value with the symbol table. /// Register a definition of a value with the symbol table.
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value); ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
// SSA parsing productions. // SSA parsing productions.
ParseResult parseSSAUse(SSAUseInfo &result); ParseResult parseSSAUse(SSAUseInfo &result);
@ -1903,9 +1902,9 @@ public:
ResultType parseSSADefOrUseAndType( ResultType parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type)> &action); const std::function<ResultType(SSAUseInfo, Type)> &action);
SSAValue *parseSSAUseAndType() { Value *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>( return parseSSADefOrUseAndType<Value *>(
[&](SSAUseInfo useInfo, Type type) -> SSAValue * { [&](SSAUseInfo useInfo, Type type) -> Value * {
return resolveSSAUse(useInfo, type); return resolveSSAUse(useInfo, type);
}); });
} }
@ -1920,9 +1919,8 @@ public:
Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc); Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
/// Parse a single operation successor and it's operand list. /// Parse a single operation successor and it's operand list.
virtual bool virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
parseSuccessorAndUseList(BasicBlock *&dest, SmallVectorImpl<Value *> &operands) = 0;
SmallVectorImpl<SSAValue *> &operands) = 0;
protected: protected:
FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {} FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
@ -1934,24 +1932,23 @@ private:
Kind kind; Kind kind;
/// This keeps track of all of the SSA values we are tracking, indexed by /// This keeps track of all of the SSA values we are tracking, indexed by
/// their name. This has one entry per result number. /// their name. This has one entry per result number.
llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values; llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
/// These are all of the placeholders we've made along with the location of /// These are all of the placeholders we've made along with the location of
/// their first reference, to allow checking for use of undefined values. /// their first reference, to allow checking for use of undefined values.
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders; DenseMap<Value *, SMLoc> forwardReferencePlaceholders;
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type); Value *createForwardReferencePlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference. /// Return true if this is a forward reference.
bool isForwardReferencePlaceholder(SSAValue *value) { bool isForwardReferencePlaceholder(Value *value) {
return forwardReferencePlaceholders.count(value); return forwardReferencePlaceholders.count(value);
} }
}; };
} // end anonymous namespace } // end anonymous namespace
/// Create and remember a new placeholder for a forward reference. /// Create and remember a new placeholder for a forward reference.
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) {
Type type) {
// Forward references are always created as instructions, even in ML // Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain. // functions, because we just need something with a def/use chain.
// //
@ -1969,7 +1966,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
/// Given an unbound reference to an SSA value and its type, return the value /// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure. /// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { Value *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = values[useInfo.name]; auto &entries = values[useInfo.name];
// If we have already seen a value of this name, return it. // If we have already seen a value of this name, return it.
@ -2010,7 +2007,7 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
} }
/// Register a definition of a value with the symbol table. /// Register a definition of a value with the symbol table.
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) { ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) {
auto &entries = values[useInfo.name]; auto &entries = values[useInfo.name];
// Make sure there is a slot for this value. // Make sure there is a slot for this value.
@ -2046,7 +2043,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
// Check for any forward references that are left. If we find any, error // Check for any forward references that are left. If we find any, error
// out. // out.
if (!forwardReferencePlaceholders.empty()) { if (!forwardReferencePlaceholders.empty()) {
SmallVector<std::pair<const char *, SSAValue *>, 4> errors; SmallVector<std::pair<const char *, Value *>, 4> errors;
// Iteration over the map isn't deterministic, so sort by source location. // Iteration over the map isn't deterministic, so sort by source location.
for (auto entry : forwardReferencePlaceholders) for (auto entry : forwardReferencePlaceholders)
errors.push_back({entry.second.getPointer(), entry.first}); errors.push_back({entry.second.getPointer(), entry.first});
@ -2399,9 +2396,8 @@ public:
return false; return false;
} }
bool bool parseSuccessorAndUseList(BasicBlock *&dest,
parseSuccessorAndUseList(BasicBlock *&dest, SmallVectorImpl<Value *> &operands) override {
SmallVectorImpl<SSAValue *> &operands) override {
// Defer successor parsing to the function parsers. // Defer successor parsing to the function parsers.
return parser.parseSuccessorAndUseList(dest, operands); return parser.parseSuccessorAndUseList(dest, operands);
} }
@ -2493,7 +2489,7 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; } llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type type, bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) override { SmallVectorImpl<Value *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location}; operand.location};
if (auto *value = parser.resolveSSAUse(operandInfo, type)) { if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
@ -2573,7 +2569,7 @@ public:
ParseResult parseFunctionBody(); ParseResult parseFunctionBody();
bool parseSuccessorAndUseList(BasicBlock *&dest, bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands); SmallVectorImpl<Value *> &operands);
private: private:
CFGFunction *function; CFGFunction *function;
@ -2636,7 +2632,7 @@ private:
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
/// ///
bool CFGFunctionParser::parseSuccessorAndUseList( bool CFGFunctionParser::parseSuccessorAndUseList(
BasicBlock *&dest, SmallVectorImpl<SSAValue *> &operands) { BasicBlock *&dest, SmallVectorImpl<Value *> &operands) {
// Verify branch is identifier and get the matching block. // Verify branch is identifier and get the matching block.
if (!getToken().is(Token::bare_identifier)) if (!getToken().is(Token::bare_identifier))
return emitError("expected basic block name"); return emitError("expected basic block name");
@ -2790,10 +2786,10 @@ private:
ParseResult parseForStmt(); ParseResult parseForStmt();
ParseResult parseIntConstant(int64_t &val); ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands, ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands, unsigned numDims, unsigned numOperands,
const char *affineStructName); const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map, ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower); bool isLower);
ParseResult parseIfStmt(); ParseResult parseIfStmt();
ParseResult parseElseClause(StmtBlock *elseClause); ParseResult parseElseClause(StmtBlock *elseClause);
@ -2801,7 +2797,7 @@ private:
ParseResult parseStmtBlock(StmtBlock *block); ParseResult parseStmtBlock(StmtBlock *block);
bool parseSuccessorAndUseList(BasicBlock *&dest, bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands) { SmallVectorImpl<Value *> &operands) {
assert(false && "MLFunctions do not have terminators with successors."); assert(false && "MLFunctions do not have terminators with successors.");
return true; return true;
} }
@ -2838,7 +2834,7 @@ ParseResult MLFunctionParser::parseForStmt() {
return ParseFailure; return ParseFailure;
// Parse lower bound. // Parse lower bound.
SmallVector<MLValue *, 4> lbOperands; SmallVector<Value *, 4> lbOperands;
AffineMap lbMap; AffineMap lbMap;
if (parseBound(lbOperands, lbMap, /*isLower*/ true)) if (parseBound(lbOperands, lbMap, /*isLower*/ true))
return ParseFailure; return ParseFailure;
@ -2847,7 +2843,7 @@ ParseResult MLFunctionParser::parseForStmt() {
return ParseFailure; return ParseFailure;
// Parse upper bound. // Parse upper bound.
SmallVector<MLValue *, 4> ubOperands; SmallVector<Value *, 4> ubOperands;
AffineMap ubMap; AffineMap ubMap;
if (parseBound(ubOperands, ubMap, /*isLower*/ false)) if (parseBound(ubOperands, ubMap, /*isLower*/ false))
return ParseFailure; return ParseFailure;
@ -2913,7 +2909,7 @@ ParseResult MLFunctionParser::parseIntConstant(int64_t &val) {
/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list? /// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
/// ///
ParseResult ParseResult
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands, MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands, unsigned numDims, unsigned numOperands,
const char *affineStructName) { const char *affineStructName) {
if (parseToken(Token::l_paren, "expected '('")) if (parseToken(Token::l_paren, "expected '('"))
@ -2942,18 +2938,17 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
// Resolve SSA uses. // Resolve SSA uses.
Type indexType = builder.getIndexType(); Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) { for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], indexType); Value *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval) if (!sval)
return ParseFailure; return ParseFailure;
auto *v = cast<MLValue>(sval); if (i < numDims && !sval->isValidDim())
if (i < numDims && !v->isValidDim())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a dimension id"); "' cannot be used as a dimension id");
if (i >= numDims && !v->isValidSymbol()) if (i >= numDims && !sval->isValidSymbol())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() + return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a symbol"); "' cannot be used as a symbol");
operands.push_back(v); operands.push_back(sval);
} }
return ParseSuccess; return ParseSuccess;
@ -2965,7 +2960,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list /// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list
/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal /// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal
/// ///
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands, ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
AffineMap &map, bool isLower) { AffineMap &map, bool isLower) {
// 'min' / 'max' prefixes are syntactic sugar. Ignore them. // 'min' / 'max' prefixes are syntactic sugar. Ignore them.
if (isLower) if (isLower)
@ -3003,7 +2998,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
// TODO: improve error message when SSA value is not an affine integer. // TODO: improve error message when SSA value is not an affine integer.
// Currently it is 'use of value ... expects different type than prior uses' // Currently it is 'use of value ... expects different type than prior uses'
if (auto *value = resolveSSAUse(opInfo, builder.getIndexType())) if (auto *value = resolveSSAUse(opInfo, builder.getIndexType()))
operands.push_back(cast<MLValue>(value)); operands.push_back(value);
else else
return ParseFailure; return ParseFailure;
@ -3113,7 +3108,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
if (!set) if (!set)
return ParseFailure; return ParseFailure;
SmallVector<MLValue *, 4> operands; SmallVector<Value *, 4> operands;
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
"integer set")) "integer set"))
return ParseFailure; return ParseFailure;

View File

@ -23,8 +23,8 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
@ -78,8 +78,8 @@ struct MemRefCastFolder : public RewritePattern {
// AddFOp // AddFOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs, void AddFOp::build(Builder *builder, OperationState *result, Value *lhs,
SSAValue *rhs) { Value *rhs) {
assert(lhs->getType() == rhs->getType()); assert(lhs->getType() == rhs->getType());
result->addOperands({lhs, rhs}); result->addOperands({lhs, rhs});
result->types.push_back(lhs->getType()); result->types.push_back(lhs->getType());
@ -146,7 +146,7 @@ void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result, void AllocOp::build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<SSAValue *> operands) { MemRefType memrefType, ArrayRef<Value *> operands) {
result->addOperands(operands); result->addOperands(operands);
result->types.push_back(memrefType); result->types.push_back(memrefType);
} }
@ -247,8 +247,8 @@ struct SimplifyAllocConst : public RewritePattern {
// and keep track of the resultant memref type to build. // and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants; SmallVector<int, 4> newShapeConstants;
newShapeConstants.reserve(memrefType.getRank()); newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands; SmallVector<Value *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands; SmallVector<Value *, 4> droppedOperands;
unsigned dynamicDimPos = 0; unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@ -301,7 +301,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void CallOp::build(Builder *builder, OperationState *result, Function *callee, void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) { ArrayRef<Value *> operands) {
result->addOperands(operands); result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee)); result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults()); result->addTypes(callee->getType().getResults());
@ -370,7 +370,7 @@ bool CallOp::verify() const {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void CallIndirectOp::build(Builder *builder, OperationState *result, void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) { Value *callee, ArrayRef<Value *> operands) {
auto fnType = callee->getType().cast<FunctionType>(); auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee); result->operands.push_back(callee);
result->addOperands(operands); result->addOperands(operands);
@ -507,7 +507,7 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
} }
void CmpIOp::build(Builder *build, OperationState *result, void CmpIOp::build(Builder *build, OperationState *result,
CmpIPredicate predicate, SSAValue *lhs, SSAValue *rhs) { CmpIPredicate predicate, Value *lhs, Value *rhs) {
result->addOperands({lhs, rhs}); result->addOperands({lhs, rhs});
result->types.push_back(getI1SameShape(build, lhs->getType())); result->types.push_back(getI1SameShape(build, lhs->getType()));
result->addAttribute(getPredicateAttrName(), result->addAttribute(getPredicateAttrName(),
@ -580,8 +580,7 @@ bool CmpIOp::verify() const {
// DeallocOp // DeallocOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void DeallocOp::build(Builder *builder, OperationState *result, void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) {
SSAValue *memref) {
result->addOperands(memref); result->addOperands(memref);
} }
@ -615,7 +614,7 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void DimOp::build(Builder *builder, OperationState *result, void DimOp::build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index) { Value *memrefOrTensor, unsigned index) {
result->addOperands(memrefOrTensor); result->addOperands(memrefOrTensor);
auto type = builder->getIndexType(); auto type = builder->getIndexType();
result->addAttribute("index", builder->getIntegerAttr(type, index)); result->addAttribute("index", builder->getIntegerAttr(type, index));
@ -689,11 +688,11 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
void DmaStartOp::build(Builder *builder, OperationState *result, void DmaStartOp::build(Builder *builder, OperationState *result,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices, Value *srcMemRef, ArrayRef<Value *> srcIndices,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices, Value *destMemRef, ArrayRef<Value *> destIndices,
SSAValue *numElements, SSAValue *tagMemRef, Value *numElements, Value *tagMemRef,
ArrayRef<SSAValue *> tagIndices, SSAValue *stride, ArrayRef<Value *> tagIndices, Value *stride,
SSAValue *elementsPerStride) { Value *elementsPerStride) {
result->addOperands(srcMemRef); result->addOperands(srcMemRef);
result->addOperands(srcIndices); result->addOperands(srcIndices);
result->addOperands(destMemRef); result->addOperands(destMemRef);
@ -836,8 +835,8 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
void DmaWaitOp::build(Builder *builder, OperationState *result, void DmaWaitOp::build(Builder *builder, OperationState *result,
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices, Value *tagMemRef, ArrayRef<Value *> tagIndices,
SSAValue *numElements) { Value *numElements) {
result->addOperands(tagMemRef); result->addOperands(tagMemRef);
result->addOperands(tagIndices); result->addOperands(tagIndices);
result->addOperands(numElements); result->addOperands(numElements);
@ -896,8 +895,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void ExtractElementOp::build(Builder *builder, OperationState *result, void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate, Value *aggregate, ArrayRef<Value *> indices) {
ArrayRef<SSAValue *> indices) {
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>(); auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate); result->addOperands(aggregate);
result->addOperands(indices); result->addOperands(indices);
@ -955,8 +953,8 @@ bool ExtractElementOp::verify() const {
// LoadOp // LoadOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref, void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<SSAValue *> indices) { ArrayRef<Value *> indices) {
auto memrefType = memref->getType().cast<MemRefType>(); auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref); result->addOperands(memref);
result->addOperands(indices); result->addOperands(indices);
@ -1130,9 +1128,8 @@ void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SelectOp // SelectOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void SelectOp::build(Builder *builder, OperationState *result, void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
SSAValue *condition, SSAValue *trueValue, Value *trueValue, Value *falseValue) {
SSAValue *falseValue) {
result->addOperands({condition, trueValue, falseValue}); result->addOperands({condition, trueValue, falseValue});
result->addTypes(trueValue->getType()); result->addTypes(trueValue->getType());
} }
@ -1201,8 +1198,8 @@ Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void StoreOp::build(Builder *builder, OperationState *result, void StoreOp::build(Builder *builder, OperationState *result,
SSAValue *valueToStore, SSAValue *memref, Value *valueToStore, Value *memref,
ArrayRef<SSAValue *> indices) { ArrayRef<Value *> indices) {
result->addOperands(valueToStore); result->addOperands(valueToStore);
result->addOperands(memref); result->addOperands(memref);
result->addOperands(indices); result->addOperands(indices);

View File

@ -72,10 +72,10 @@ static bool verifyPermutationMap(AffineMap permutationMap,
} }
void VectorTransferReadOp::build(Builder *builder, OperationState *result, void VectorTransferReadOp::build(Builder *builder, OperationState *result,
VectorType vectorType, SSAValue *srcMemRef, VectorType vectorType, Value *srcMemRef,
ArrayRef<SSAValue *> srcIndices, ArrayRef<Value *> srcIndices,
AffineMap permutationMap, AffineMap permutationMap,
Optional<SSAValue *> paddingValue) { Optional<Value *> paddingValue) {
result->addOperands(srcMemRef); result->addOperands(srcMemRef);
result->addOperands(srcIndices); result->addOperands(srcIndices);
if (paddingValue) { if (paddingValue) {
@ -100,21 +100,20 @@ VectorTransferReadOp::getIndices() const {
return {begin, end}; return {begin, end};
} }
Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() { Optional<Value *> VectorTransferReadOp::getPaddingValue() {
auto memRefRank = getMemRefType().getRank(); auto memRefRank = getMemRefType().getRank();
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
return None; return None;
} }
return Optional<SSAValue *>( return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank));
getOperand(Offsets::FirstIndexOffset + memRefRank));
} }
Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const { Optional<const Value *> VectorTransferReadOp::getPaddingValue() const {
auto memRefRank = getMemRefType().getRank(); auto memRefRank = getMemRefType().getRank();
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
return None; return None;
} }
return Optional<const SSAValue *>( return Optional<const Value *>(
getOperand(Offsets::FirstIndexOffset + memRefRank)); getOperand(Offsets::FirstIndexOffset + memRefRank));
} }
@ -136,7 +135,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) const {
// Construct the FunctionType and print it. // Construct the FunctionType and print it.
llvm::SmallVector<Type, 8> inputs{getMemRefType()}; llvm::SmallVector<Type, 8> inputs{getMemRefType()};
// Must have at least one actual index, see verify. // Must have at least one actual index, see verify.
const SSAValue *firstIndex = *(getIndices().begin()); const Value *firstIndex = *(getIndices().begin());
Type indexType = firstIndex->getType(); Type indexType = firstIndex->getType();
inputs.append(getMemRefType().getRank(), indexType); inputs.append(getMemRefType().getRank(), indexType);
if (optionalPaddingValue) { if (optionalPaddingValue) {
@ -295,8 +294,8 @@ bool VectorTransferReadOp::verify() const {
// VectorTransferWriteOp // VectorTransferWriteOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void VectorTransferWriteOp::build(Builder *builder, OperationState *result, void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
SSAValue *srcVector, SSAValue *dstMemRef, Value *srcVector, Value *dstMemRef,
ArrayRef<SSAValue *> dstIndices, ArrayRef<Value *> dstIndices,
AffineMap permutationMap) { AffineMap permutationMap) {
result->addOperands({srcVector, dstMemRef}); result->addOperands({srcVector, dstMemRef});
result->addOperands(dstIndices); result->addOperands(dstIndices);
@ -457,7 +456,7 @@ bool VectorTransferWriteOp::verify() const {
// VectorTypeCastOp // VectorTypeCastOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void VectorTypeCastOp::build(Builder *builder, OperationState *result, void VectorTypeCastOp::build(Builder *builder, OperationState *result,
SSAValue *srcVector, Type dstType) { Value *srcVector, Type dstType) {
result->addOperands(srcVector); result->addOperands(srcVector);
result->addTypes(dstType); result->addTypes(dstType);
} }

View File

@ -111,7 +111,7 @@ private:
/// descriptor and get the pointer to the element indexed by the linearized /// descriptor and get the pointer to the element indexed by the linearized
/// subscript. Return nullptr on errors. /// subscript. Return nullptr on errors.
llvm::Value *emitMemRefElementAccess( llvm::Value *emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op, const Value *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices); llvm::iterator_range<Operation::const_operand_iterator> opIndices);
/// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create /// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create
@ -136,12 +136,12 @@ private:
/// Create a single LLVM value of struct type that includes the list of /// Create a single LLVM value of struct type that includes the list of
/// given MLIR values. The `values` list must contain at least 2 elements. /// given MLIR values. The `values` list must contain at least 2 elements.
llvm::Value *packValues(ArrayRef<const SSAValue *> values); llvm::Value *packValues(ArrayRef<const Value *> values);
/// Extract a list of `num` LLVM values from a `value` of struct type. /// Extract a list of `num` LLVM values from a `value` of struct type.
SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num); SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num);
llvm::DenseMap<const Function *, llvm::Function *> functionMapping; llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping; llvm::DenseMap<const Value *, llvm::Value *> valueMapping;
llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping; llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
llvm::LLVMContext &llvmContext; llvm::LLVMContext &llvmContext;
llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder; llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
@ -316,7 +316,7 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
} }
llvm::Value *ModuleLowerer::emitMemRefElementAccess( llvm::Value *ModuleLowerer::emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op, const Value *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices) { llvm::iterator_range<Operation::const_operand_iterator> opIndices) {
auto type = memRef->getType().dyn_cast<MemRefType>(); auto type = memRef->getType().dyn_cast<MemRefType>();
assert(type && "expected memRef value to have a MemRef type"); assert(type && "expected memRef value to have a MemRef type");
@ -340,7 +340,7 @@ llvm::Value *ModuleLowerer::emitMemRefElementAccess(
// Obtain the list of access subscripts as values and linearize it given the // Obtain the list of access subscripts as values and linearize it given the
// list of sizes. // list of sizes.
auto indices = functional::map( auto indices = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); }, [this](const Value *value) { return valueMapping.lookup(value); },
opIndices); opIndices);
auto subscript = linearizeSubscripts(indices, sizes); auto subscript = linearizeSubscripts(indices, sizes);
@ -460,11 +460,11 @@ llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) {
} }
// Create an undef struct value and insert individual values into it. // Create an undef struct value and insert individual values into it.
llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) { llvm::Value *ModuleLowerer::packValues(ArrayRef<const Value *> values) {
assert(values.size() > 1 && "cannot pack less than 2 values"); assert(values.size() > 1 && "cannot pack less than 2 values");
auto types = auto types =
functional::map([](const SSAValue *v) { return v->getType(); }, values); functional::map([](const Value *v) { return v->getType(); }, values);
llvm::Type *packedType = getPackedResultType(types); llvm::Type *packedType = getPackedResultType(types);
llvm::Value *packed = llvm::UndefValue::get(packedType); llvm::Value *packed = llvm::UndefValue::get(packedType);
@ -641,7 +641,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
return false; return false;
} }
if (auto dimOp = inst.dyn_cast<DimOp>()) { if (auto dimOp = inst.dyn_cast<DimOp>()) {
const SSAValue *container = dimOp->getOperand(); const Value *container = dimOp->getOperand();
MemRefType type = container->getType().dyn_cast<MemRefType>(); MemRefType type = container->getType().dyn_cast<MemRefType>();
if (!type) if (!type)
return dimOp->emitError("only memref types are supported"); return dimOp->emitError("only memref types are supported");
@ -672,7 +672,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
if (auto callOp = inst.dyn_cast<CallOp>()) { if (auto callOp = inst.dyn_cast<CallOp>()) {
auto operands = functional::map( auto operands = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); }, [this](const Value *value) { return valueMapping.lookup(value); },
callOp->getOperands()); callOp->getOperands());
auto numResults = callOp->getNumResults(); auto numResults = callOp->getNumResults();
llvm::Value *result = llvm::Value *result =
@ -779,10 +779,9 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
// Get the SSA value passed to the current block from the terminator instruction // Get the SSA value passed to the current block from the terminator instruction
// of its predecessor. // of its predecessor.
static const SSAValue *getPHISourceValue(const BasicBlock *current, static const Value *getPHISourceValue(const BasicBlock *current,
const BasicBlock *pred, const BasicBlock *pred,
unsigned numArguments, unsigned numArguments, unsigned index) {
unsigned index) {
auto &terminator = *pred->getTerminator(); auto &terminator = *pred->getTerminator();
if (terminator.isa<BranchOp>()) { if (terminator.isa<BranchOp>()) {
return terminator.getOperand(index); return terminator.getOperand(index);

View File

@ -30,13 +30,12 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
ConstantFold() : FunctionPass(&ConstantFold::passID) {} ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding. // All constants in the function post folding.
SmallVector<SSAValue *, 8> existingConstants; SmallVector<Value *, 8> existingConstants;
// Operation statements that were folded and that need to be erased. // Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase; std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>; using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
bool foldOperation(Operation *op, bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants,
SmallVectorImpl<SSAValue *> &existingConstants,
ConstantFactoryType constantFactory); ConstantFactoryType constantFactory);
void visitOperationStmt(OperationStmt *stmt); void visitOperationStmt(OperationStmt *stmt);
void visitForStmt(ForStmt *stmt); void visitForStmt(ForStmt *stmt);
@ -54,9 +53,8 @@ char ConstantFold::passID = 0;
/// ///
/// This returns false if the operation was successfully folded. /// This returns false if the operation was successfully folded.
bool ConstantFold::foldOperation(Operation *op, bool ConstantFold::foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants, SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory) { ConstantFactoryType constantFactory) {
// If this operation is already a constant, just remember it for cleanup // If this operation is already a constant, just remember it for cleanup
// later, and don't try to fold it. // later, and don't try to fold it.
if (auto constant = op->dyn_cast<ConstantOp>()) { if (auto constant = op->dyn_cast<ConstantOp>()) {
@ -114,7 +112,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
if (!inst) if (!inst)
continue; continue;
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type type) -> Value * {
builder.setInsertionPoint(inst); builder.setInsertionPoint(inst);
return builder.create<ConstantOp>(inst->getLoc(), value, type); return builder.create<ConstantOp>(inst->getLoc(), value, type);
}; };
@ -142,7 +140,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding. // Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) { void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type type) -> Value * {
MLFuncBuilder builder(stmt); MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type); return builder.create<ConstantOp>(stmt->getLoc(), value, type);
}; };

View File

@ -50,28 +50,28 @@ public:
void visitOperationStmt(OperationStmt *opStmt); void visitOperationStmt(OperationStmt *opStmt);
private: private:
CFGValue *getConstantIndexValue(int64_t value); Value *getConstantIndexValue(int64_t value);
void visitStmtBlock(StmtBlock *stmtBlock); void visitStmtBlock(StmtBlock *stmtBlock);
CFGValue *buildMinMaxReductionSeq( Value *buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate, Location loc, CmpIPredicate predicate,
llvm::iterator_range<Operation::result_iterator> values); llvm::iterator_range<Operation::result_iterator> values);
CFGFunction *cfgFunc; CFGFunction *cfgFunc;
CFGFuncBuilder builder; CFGFuncBuilder builder;
// Mapping between original MLValues and lowered CFGValues. // Mapping between original Values and lowered Values.
llvm::DenseMap<const MLValue *, CFGValue *> valueRemapping; llvm::DenseMap<const Value *, Value *> valueRemapping;
}; };
} // end anonymous namespace } // end anonymous namespace
// Return a vector of OperationStmt's arguments as SSAValues. For each // Return a vector of OperationStmt's arguments as Values. For each
// statement operands, represented as MLValue, lookup its CFGValue conterpart in // statement operands, represented as Value, lookup its Value conterpart in
// the valueRemapping table. // the valueRemapping table.
static llvm::SmallVector<SSAValue *, 4> static llvm::SmallVector<mlir::Value *, 4>
operandsAs(Statement *opStmt, operandsAs(Statement *opStmt,
const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) { const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
llvm::SmallVector<SSAValue *, 4> operands; llvm::SmallVector<Value *, 4> operands;
for (const MLValue *operand : opStmt->getOperands()) { for (const Value *operand : opStmt->getOperands()) {
assert(valueRemapping.count(operand) != 0 && "operand is not defined"); assert(valueRemapping.count(operand) != 0 && "operand is not defined");
operands.push_back(valueRemapping.lookup(operand)); operands.push_back(valueRemapping.lookup(operand));
} }
@ -81,8 +81,8 @@ operandsAs(Statement *opStmt,
// Convert an operation statement into an operation instruction. // Convert an operation statement into an operation instruction.
// //
// The operation description (name, number and types of operands or results) // The operation description (name, number and types of operands or results)
// remains the same but the values must be updated to be CFGValues. Update the // remains the same but the values must be updated to be Values. Update the
// mapping MLValue->CFGValue as the conversion is performed. The operation // mapping Value->Value as the conversion is performed. The operation
// instruction is appended to current block (end of SESE region). // instruction is appended to current block (end of SESE region).
void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) { void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
// Set up basic operation state (context, name, operands). // Set up basic operation state (context, name, operands).
@ -90,11 +90,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
opStmt->getName()); opStmt->getName());
state.addOperands(operandsAs(opStmt, valueRemapping)); state.addOperands(operandsAs(opStmt, valueRemapping));
// Set up operation return types. The corresponding SSAValues will become // Set up operation return types. The corresponding Values will become
// available after the operation is created. // available after the operation is created.
state.addTypes( state.addTypes(functional::map(
functional::map([](SSAValue *result) { return result->getType(); }, [](Value *result) { return result->getType(); }, opStmt->getResults()));
opStmt->getResults()));
// Copy attributes. // Copy attributes.
for (auto attr : opStmt->getAttrs()) { for (auto attr : opStmt->getAttrs()) {
@ -112,10 +111,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
} }
} }
// Create a CFGValue for the given integer constant of index type. // Create a Value for the given integer constant of index type.
CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) { Value *FunctionConverter::getConstantIndexValue(int64_t value) {
auto op = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value); auto op = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value);
return cast<CFGValue>(op->getResult()); return op->getResult();
} }
// Visit all statements in the given statement block. // Visit all statements in the given statement block.
@ -135,18 +134,18 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) {
// Multiple values are scanned in a linear sequence. This creates a data // Multiple values are scanned in a linear sequence. This creates a data
// dependences that wouldn't exist in a tree reduction, but is easier to // dependences that wouldn't exist in a tree reduction, but is easier to
// recognize as a reduction by the subsequent passes. // recognize as a reduction by the subsequent passes.
CFGValue *FunctionConverter::buildMinMaxReductionSeq( Value *FunctionConverter::buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate, Location loc, CmpIPredicate predicate,
llvm::iterator_range<Operation::result_iterator> values) { llvm::iterator_range<Operation::result_iterator> values) {
assert(!llvm::empty(values) && "empty min/max chain"); assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin(); auto valueIt = values.begin();
CFGValue *value = cast<CFGValue>(*valueIt++); Value *value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) { for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt); auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
auto selectOp = auto selectOp =
builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt); builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt);
value = cast<CFGValue>(selectOp->getResult()); value = selectOp->getResult();
} }
return value; return value;
@ -231,9 +230,9 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// The loop condition block has an argument for loop induction variable. // The loop condition block has an argument for loop induction variable.
// Create it upfront and make the loop induction variable -> basic block // Create it upfront and make the loop induction variable -> basic block
// argument remapping available to the following instructions. ForStatement // argument remapping available to the following instructions. ForStatement
// is-a MLValue corresponding to the loop induction variable. // is-a Value corresponding to the loop induction variable.
builder.setInsertionPoint(loopConditionBlock); builder.setInsertionPoint(loopConditionBlock);
CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType()); Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
valueRemapping.insert(std::make_pair(forStmt, iv)); valueRemapping.insert(std::make_pair(forStmt, iv));
// Recursively construct loop body region. // Recursively construct loop body region.
@ -251,7 +250,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp = auto stepOp =
builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv); builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult(0)); Value *nextIvValue = stepOp->getResult(0);
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
nextIvValue); nextIvValue);
@ -260,20 +259,19 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPoint(loopInitBlock); builder.setInsertionPoint(loopInitBlock);
// Compute loop bounds using affine_apply after remapping its operands. // Compute loop bounds using affine_apply after remapping its operands.
auto remapOperands = [this](const SSAValue *value) -> SSAValue * { auto remapOperands = [this](const Value *value) -> Value * {
const MLValue *mlValue = dyn_cast<MLValue>(value); return valueRemapping.lookup(value);
return valueRemapping.lookup(mlValue);
}; };
auto operands = auto operands =
functional::map(remapOperands, forStmt->getLowerBoundOperands()); functional::map(remapOperands, forStmt->getLowerBoundOperands());
auto lbAffineApply = builder.create<AffineApplyOp>( auto lbAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
CFGValue *lowerBound = buildMinMaxReductionSeq( Value *lowerBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
auto ubAffineApply = builder.create<AffineApplyOp>( auto ubAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
CFGValue *upperBound = buildMinMaxReductionSeq( Value *upperBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock, builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound); lowerBound);
@ -281,10 +279,10 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPoint(loopConditionBlock); builder.setInsertionPoint(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>( auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = cast<CFGValue>(comparisonOp->getResult()); auto comparisonResult = comparisonOp->getResult();
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult, builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, ArrayRef<SSAValue *>(), loopBodyFirstBlock, ArrayRef<Value *>(),
postLoopBlock, ArrayRef<SSAValue *>()); postLoopBlock, ArrayRef<Value *>());
// Finally, make sure building can continue by setting the post-loop block // Finally, make sure building can continue by setting the post-loop block
// (end of loop SESE region) as the insertion point. // (end of loop SESE region) as the insertion point.
@ -401,7 +399,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// If the test succeeds, jump to the next block testing testing the next // If the test succeeds, jump to the next block testing testing the next
// conjunct of the condition in the similar way. When all conjuncts have been // conjunct of the condition in the similar way. When all conjuncts have been
// handled, jump to the 'then' block instead. // handled, jump to the 'then' block instead.
SSAValue *zeroConstant = getConstantIndexValue(0); Value *zeroConstant = getConstantIndexValue(0);
ifConditionExtraBlocks.push_back(thenBlock); ifConditionExtraBlocks.push_back(thenBlock);
for (auto tuple : for (auto tuple :
llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(), llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(),
@ -416,16 +414,16 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
integerSet.getNumSymbols(), constraintExpr, {}); integerSet.getNumSymbols(), constraintExpr, {});
auto affineApplyOp = builder.create<AffineApplyOp>( auto affineApplyOp = builder.create<AffineApplyOp>(
ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
SSAValue *affResult = affineApplyOp->getResult(0); Value *affResult = affineApplyOp->getResult(0);
// Compare the result of the apply and branch. // Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>( auto comparisonOp = builder.create<CmpIOp>(
ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
affResult, zeroConstant); affResult, zeroConstant);
builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(), builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
nextBlock, /*trueArgs*/ ArrayRef<SSAValue *>(), nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock, elseBlock,
/*falseArgs*/ ArrayRef<SSAValue *>()); /*falseArgs*/ ArrayRef<Value *>());
builder.setInsertionPoint(nextBlock); builder.setInsertionPoint(nextBlock);
} }
ifConditionExtraBlocks.pop_back(); ifConditionExtraBlocks.pop_back();
@ -468,10 +466,10 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// of the current region. The SESE invariant allows us to easily handle nested // of the current region. The SESE invariant allows us to easily handle nested
// structures of arbitrary complexity. // structures of arbitrary complexity.
// //
// During the conversion, we maintain a mapping between the MLValues present in // During the conversion, we maintain a mapping between the Values present in
// the original function and their CFGValue images in the function under // the original function and their Value images in the function under
// construction. When an MLValue is used, it gets replaced with the // construction. When an Value is used, it gets replaced with the
// corresponding CFGValue that has been defined previously. The value flow // corresponding Value that has been defined previously. The value flow
// starts with function arguments converted to basic block arguments. // starts with function arguments converted to basic block arguments.
CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) { CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
auto outerBlock = builder.createBlock(); auto outerBlock = builder.createBlock();
@ -482,8 +480,8 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
outerBlock->addArguments(mlFunc->getType().getInputs()); outerBlock->addArguments(mlFunc->getType().getInputs());
assert(mlFunc->getNumArguments() == outerBlock->getNumArguments()); assert(mlFunc->getNumArguments() == outerBlock->getNumArguments());
for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) { for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) {
const MLValue *mlArgument = mlFunc->getArgument(i); const Value *mlArgument = mlFunc->getArgument(i);
CFGValue *cfgArgument = outerBlock->getArgument(i); Value *cfgArgument = outerBlock->getArgument(i);
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
} }

View File

@ -76,7 +76,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Map from original memref's to the DMA buffers that their accesses are // Map from original memref's to the DMA buffers that their accesses are
// replaced with. // replaced with.
DenseMap<SSAValue *, SSAValue *> fastBufferMap; DenseMap<Value *, Value *> fastBufferMap;
// Slow memory space associated with DMAs. // Slow memory space associated with DMAs.
const unsigned slowMemorySpace; const unsigned slowMemorySpace;
@ -195,11 +195,11 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// Indices to use for the DmaStart op. // Indices to use for the DmaStart op.
// Indices for the original memref being DMAed from/to. // Indices for the original memref being DMAed from/to.
SmallVector<SSAValue *, 4> memIndices; SmallVector<Value *, 4> memIndices;
// Indices for the faster buffer being DMAed into/from. // Indices for the faster buffer being DMAed into/from.
SmallVector<SSAValue *, 4> bufIndices; SmallVector<Value *, 4> bufIndices;
SSAValue *zeroIndex = top.create<ConstantIndexOp>(loc, 0); Value *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
unsigned rank = memRefType.getRank(); unsigned rank = memRefType.getRank();
SmallVector<int, 4> fastBufferShape; SmallVector<int, 4> fastBufferShape;
@ -226,10 +226,10 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// DMA generation is being done. // DMA generation is being done.
const FlatAffineConstraints *cst = region.getConstraints(); const FlatAffineConstraints *cst = region.getConstraints();
auto ids = cst->getIds(); auto ids = cst->getIds();
SmallVector<SSAValue *, 8> outerIVs; SmallVector<Value *, 8> outerIVs;
for (unsigned i = rank, e = ids.size(); i < e; i++) { for (unsigned i = rank, e = ids.size(); i < e; i++) {
auto id = cst->getIds()[i]; auto id = cst->getIds()[i];
assert(id.hasValue() && "MLValue id expected"); assert(id.hasValue() && "Value id expected");
outerIVs.push_back(id.getValue()); outerIVs.push_back(id.getValue());
} }
@ -253,15 +253,15 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// Set DMA start location for this dimension in the lower memory space // Set DMA start location for this dimension in the lower memory space
// memref. // memref.
if (auto caf = offset.dyn_cast<AffineConstantExpr>()) { if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
memIndices.push_back(cast<MLValue>( memIndices.push_back(
top.create<ConstantIndexOp>(loc, caf.getValue())->getResult())); top.create<ConstantIndexOp>(loc, caf.getValue())->getResult());
} else { } else {
// The coordinate for the start location is just the lower bound along the // The coordinate for the start location is just the lower bound along the
// corresponding dimension on the memory region (stored in 'offset'). // corresponding dimension on the memory region (stored in 'offset').
auto map = top.getAffineMap( auto map = top.getAffineMap(
cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {}); cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {});
memIndices.push_back(cast<MLValue>( memIndices.push_back(
b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0))); b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0));
} }
// The fast buffer is DMAed into at location zero; addressing is relative. // The fast buffer is DMAed into at location zero; addressing is relative.
bufIndices.push_back(zeroIndex); bufIndices.push_back(zeroIndex);
@ -272,7 +272,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
} }
// The faster memory space buffer. // The faster memory space buffer.
SSAValue *fastMemRef; Value *fastMemRef;
// Check if a buffer was already created. // Check if a buffer was already created.
// TODO(bondhugula): union across all memory op's per buffer. For now assuming // TODO(bondhugula): union across all memory op's per buffer. For now assuming
@ -321,8 +321,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
return false; return false;
} }
SSAValue *stride = nullptr; Value *stride = nullptr;
SSAValue *numEltPerStride = nullptr; Value *numEltPerStride = nullptr;
if (!strideInfos.empty()) { if (!strideInfos.empty()) {
stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride); stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
numEltPerStride = numEltPerStride =
@ -362,7 +362,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
} }
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {}); auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
// *Only* those uses within the body of 'forStmt' are replaced. // *Only* those uses within the body of 'forStmt' are replaced.
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef), replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap, /*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs, /*extraOperands=*/outerIVs,
/*domStmtFilter=*/&*forStmt->getBody()->begin()); /*domStmtFilter=*/&*forStmt->getBody()->begin());

View File

@ -83,22 +83,22 @@ FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt, static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) { MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) { if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = cast<MLValue>(loadOp->getMemRef()); access->memref = loadOp->getMemRef();
access->opStmt = loadOrStoreOpStmt; access->opStmt = loadOrStoreOpStmt;
auto loadMemrefType = loadOp->getMemRefType(); auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank()); access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) { for (auto *index : loadOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index)); access->indices.push_back(index);
} }
} else { } else {
assert(loadOrStoreOpStmt->isa<StoreOp>()); assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>(); auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->opStmt = loadOrStoreOpStmt; access->opStmt = loadOrStoreOpStmt;
access->memref = cast<MLValue>(storeOp->getMemRef()); access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType(); auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank()); access->indices.reserve(storeMemrefType.getRank());
for (auto *index : storeOp->getIndices()) { for (auto *index : storeOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index)); access->indices.push_back(index);
} }
} }
} }
@ -178,20 +178,20 @@ public:
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {} Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
// Returns the load op count for 'memref'. // Returns the load op count for 'memref'.
unsigned getLoadOpCount(MLValue *memref) { unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0; unsigned loadOpCount = 0;
for (auto *loadOpStmt : loads) { for (auto *loadOpStmt : loads) {
if (memref == cast<MLValue>(loadOpStmt->cast<LoadOp>()->getMemRef())) if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
++loadOpCount; ++loadOpCount;
} }
return loadOpCount; return loadOpCount;
} }
// Returns the store op count for 'memref'. // Returns the store op count for 'memref'.
unsigned getStoreOpCount(MLValue *memref) { unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0; unsigned storeOpCount = 0;
for (auto *storeOpStmt : stores) { for (auto *storeOpStmt : stores) {
if (memref == cast<MLValue>(storeOpStmt->cast<StoreOp>()->getMemRef())) if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
++storeOpCount; ++storeOpCount;
} }
return storeOpCount; return storeOpCount;
@ -203,7 +203,7 @@ public:
// The id of the node at the other end of the edge. // The id of the node at the other end of the edge.
unsigned id; unsigned id;
// The memref on which this edge represents a dependence. // The memref on which this edge represents a dependence.
MLValue *memref; Value *memref;
}; };
// Map from node id to Node. // Map from node id to Node.
@ -227,13 +227,13 @@ public:
} }
// Adds an edge from node 'srcId' to node 'dstId' for 'memref'. // Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) { void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
outEdges[srcId].push_back({dstId, memref}); outEdges[srcId].push_back({dstId, memref});
inEdges[dstId].push_back({srcId, memref}); inEdges[dstId].push_back({srcId, memref});
} }
// Removes an edge from node 'srcId' to node 'dstId' for 'memref'. // Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) { void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
assert(inEdges.count(dstId) > 0); assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0); assert(outEdges.count(srcId) > 0);
// Remove 'srcId' from 'inEdges[dstId]'. // Remove 'srcId' from 'inEdges[dstId]'.
@ -253,7 +253,7 @@ public:
} }
// Returns the input edge count for node 'id' and 'memref'. // Returns the input edge count for node 'id' and 'memref'.
unsigned getInEdgeCount(unsigned id, MLValue *memref) { unsigned getInEdgeCount(unsigned id, Value *memref) {
unsigned inEdgeCount = 0; unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0) if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id]) for (auto &inEdge : inEdges[id])
@ -263,7 +263,7 @@ public:
} }
// Returns the output edge count for node 'id' and 'memref'. // Returns the output edge count for node 'id' and 'memref'.
unsigned getOutEdgeCount(unsigned id, MLValue *memref) { unsigned getOutEdgeCount(unsigned id, Value *memref) {
unsigned outEdgeCount = 0; unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0) if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id]) for (auto &outEdge : outEdges[id])
@ -347,7 +347,7 @@ public:
// dependence graph at a different depth. // dependence graph at a different depth.
bool MemRefDependenceGraph::init(MLFunction *f) { bool MemRefDependenceGraph::init(MLFunction *f) {
unsigned id = 0; unsigned id = 0;
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses; DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
for (auto &stmt : *f->getBody()) { for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) { if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record // Create graph node 'id' to represent top-level 'forStmt' and record
@ -360,12 +360,12 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
Node node(id++, &stmt); Node node(id++, &stmt);
for (auto *opStmt : collector.loadOpStmts) { for (auto *opStmt : collector.loadOpStmts) {
node.loads.push_back(opStmt); node.loads.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef()); auto *memref = opStmt->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id); memrefAccesses[memref].insert(node.id);
} }
for (auto *opStmt : collector.storeOpStmts) { for (auto *opStmt : collector.storeOpStmts) {
node.stores.push_back(opStmt); node.stores.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef()); auto *memref = opStmt->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id); memrefAccesses[memref].insert(node.id);
} }
nodes.insert({node.id, node}); nodes.insert({node.id, node});
@ -375,7 +375,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
// Create graph node for top-level load op. // Create graph node for top-level load op.
Node node(id++, &stmt); Node node(id++, &stmt);
node.loads.push_back(opStmt); node.loads.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef()); auto *memref = opStmt->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id); memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node}); nodes.insert({node.id, node});
} }
@ -383,7 +383,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
// Create graph node for top-level store op. // Create graph node for top-level store op.
Node node(id++, &stmt); Node node(id++, &stmt);
node.stores.push_back(opStmt); node.stores.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef()); auto *memref = opStmt->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id); memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node}); nodes.insert({node.id, node});
} }
@ -477,8 +477,7 @@ public:
SmallVector<OperationStmt *, 4> loads = dstNode->loads; SmallVector<OperationStmt *, 4> loads = dstNode->loads;
while (!loads.empty()) { while (!loads.empty()) {
auto *dstLoadOpStmt = loads.pop_back_val(); auto *dstLoadOpStmt = loads.pop_back_val();
auto *memref = auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
cast<MLValue>(dstLoadOpStmt->cast<LoadOp>()->getMemRef());
// Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'. // Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1) if (dstNode->getLoadOpCount(memref) != 1)
continue; continue;

View File

@ -85,10 +85,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
for (unsigned i = 0; i < width; i++) { for (unsigned i = 0; i < width; i++) {
auto lbOperands = origLoops[i]->getLowerBoundOperands(); auto lbOperands = origLoops[i]->getLowerBoundOperands();
auto ubOperands = origLoops[i]->getUpperBoundOperands(); auto ubOperands = origLoops[i]->getUpperBoundOperands();
SmallVector<MLValue *, 4> newLbOperands(lbOperands.begin(), SmallVector<Value *, 4> newLbOperands(lbOperands.begin(), lbOperands.end());
lbOperands.end()); SmallVector<Value *, 4> newUbOperands(ubOperands.begin(), ubOperands.end());
SmallVector<MLValue *, 4> newUbOperands(ubOperands.begin(),
ubOperands.end());
newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap()); newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap());
newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap()); newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap());
newLoops[i]->setStep(tileSizes[i]); newLoops[i]->setStep(tileSizes[i]);
@ -112,8 +110,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
// Construct the upper bound map; the operands are the original operands // Construct the upper bound map; the operands are the original operands
// with 'i' (tile-space loop) appended to it. The new upper bound map is // with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended. // the original one with an additional expression i + tileSize appended.
SmallVector<MLValue *, 4> ubOperands( SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
origLoops[i]->getUpperBoundOperands());
ubOperands.push_back(newLoops[i]); ubOperands.push_back(newLoops[i]);
auto origUbMap = origLoops[i]->getUpperBoundMap(); auto origUbMap = origLoops[i]->getUpperBoundMap();
@ -191,8 +188,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// Move the loop body of the original nest to the new one. // Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
SmallVector<MLValue *, 6> origLoopIVs(band.begin(), band.end()); SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
SmallVector<Optional<MLValue *>, 6> ids(band.begin(), band.end()); SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
FlatAffineConstraints cst; FlatAffineConstraints cst;
getIndexSet(band, &cst); getIndexSet(band, &cst);

View File

@ -191,7 +191,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// unrollJamFactor. // unrollJamFactor.
if (mayBeConstantTripCount.hasValue() && if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) { mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap; DenseMap<const Value *, Value *> operandMap;
// Insert the cleanup loop right after 'forStmt'. // Insert the cleanup loop right after 'forStmt'.
MLFuncBuilder builder(forStmt->getBlock(), MLFuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt))); std::next(StmtBlock::iterator(forStmt)));
@ -219,7 +219,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Unroll and jam (appends unrollJamFactor-1 additional copies). // Unroll and jam (appends unrollJamFactor-1 additional copies).
for (unsigned i = 1; i < unrollJamFactor; i++) { for (unsigned i = 1; i < unrollJamFactor; i++) {
DenseMap<const MLValue *, MLValue *> operandMapping; DenseMap<const Value *, Value *> operandMapping;
// If the induction variable is used, create a remapping to the value for // If the induction variable is used, create a remapping to the value for
// this unrolled instance. // this unrolled instance.
@ -230,7 +230,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
auto *ivUnroll = auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0); ->getResult(0);
operandMapping[forStmt] = cast<MLValue>(ivUnroll); operandMapping[forStmt] = ivUnroll;
} }
// Clone the sub-block being unroll-jammed. // Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {

View File

@ -29,17 +29,14 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Pass.h" #include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h" #include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h" #include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/MLPatternLoweringPass.h" #include "mlir/Transforms/MLPatternLoweringPass.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
@ -62,26 +59,26 @@ using namespace mlir;
#define DEBUG_TYPE "lower-vector-transfers" #define DEBUG_TYPE "lower-vector-transfers"
/// Creates the SSAValue for the sum of `a` and `b` without building a /// Creates the Value for the sum of `a` and `b` without building a
/// full-fledged AffineMap for all indices. /// full-fledged AffineMap for all indices.
/// ///
/// Prerequisites: /// Prerequisites:
/// `a` and `b` must be of IndexType. /// `a` and `b` must be of IndexType.
static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) { static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) {
assert(v->getType().isa<IndexType>() && "v must be of IndexType"); assert(v->getType().isa<IndexType>() && "v must be of IndexType");
assert(w->getType().isa<IndexType>() && "w must be of IndexType"); assert(w->getType().isa<IndexType>() && "w must be of IndexType");
auto *context = b->getContext(); auto *context = b->getContext();
auto d0 = getAffineDimExpr(0, context); auto d0 = getAffineDimExpr(0, context);
auto d1 = getAffineDimExpr(1, context); auto d1 = getAffineDimExpr(1, context);
auto map = AffineMap::get(2, 0, {d0 + d1}, {}); auto map = AffineMap::get(2, 0, {d0 + d1}, {});
return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w}) return b->create<AffineApplyOp>(loc, map, ArrayRef<mlir::Value *>{v, w})
->getResult(0); ->getResult(0);
} }
namespace { namespace {
struct LowerVectorTransfersState : public MLFuncGlobalLoweringState { struct LowerVectorTransfersState : public MLFuncGlobalLoweringState {
// Top of the function constant zero index. // Top of the function constant zero index.
SSAValue *zero; Value *zero;
}; };
} // namespace } // namespace
@ -131,7 +128,8 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// case of GPUs. // case of GPUs.
if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) { if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
b.create<StoreOp>(vecView->getLoc(), transfer->getVector(), b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
vecView->getResult(), ArrayRef<SSAValue *>{state->zero}); vecView->getResult(),
ArrayRef<mlir::Value *>{state->zero});
} }
// 3. Emit the loop-nest. // 3. Emit the loop-nest.
@ -140,7 +138,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// TODO(ntv): Handle broadcast / slice properly. // TODO(ntv): Handle broadcast / slice properly.
auto permutationMap = transfer->getPermutationMap(); auto permutationMap = transfer->getPermutationMap();
SetVector<ForStmt *> loops; SetVector<ForStmt *> loops;
SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices()); SmallVector<Value *, 8> accessIndices(transfer->getIndices());
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) { for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
auto composed = composeWithUnboundedMap( auto composed = composeWithUnboundedMap(
getAffineDimExpr(it.index(), b.getContext()), permutationMap); getAffineDimExpr(it.index(), b.getContext()), permutationMap);
@ -168,17 +166,16 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// b. write scalar to local. // b. write scalar to local.
auto scalarLoad = b.create<LoadOp>(transfer->getLoc(), auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
transfer->getMemRef(), accessIndices); transfer->getMemRef(), accessIndices);
b.create<StoreOp>( b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
transfer->getLoc(), scalarLoad->getResult(),
tmpScalarAlloc->getResult(), tmpScalarAlloc->getResult(),
functional::map([](SSAValue *val) { return val; }, loops)); functional::map([](Value *val) { return val; }, loops));
} else { } else {
// VectorTransferWriteOp. // VectorTransferWriteOp.
// a. read scalar from local; // a. read scalar from local;
// b. write scalar to remote. // b. write scalar to remote.
auto scalarLoad = b.create<LoadOp>( auto scalarLoad = b.create<LoadOp>(
transfer->getLoc(), tmpScalarAlloc->getResult(), transfer->getLoc(), tmpScalarAlloc->getResult(),
functional::map([](SSAValue *val) { return val; }, loops)); functional::map([](Value *val) { return val; }, loops));
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(), b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
transfer->getMemRef(), accessIndices); transfer->getMemRef(), accessIndices);
} }
@ -186,11 +183,11 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// 5. Read the vector from local storage in case of a vector_transfer_read. // 5. Read the vector from local storage in case of a vector_transfer_read.
// TODO(ntv): This vector_load operation should be further lowered in the // TODO(ntv): This vector_load operation should be further lowered in the
// case of GPUs. // case of GPUs.
llvm::SmallVector<SSAValue *, 1> newResults = {}; llvm::SmallVector<Value *, 1> newResults = {};
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) { if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation())); b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(), auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(),
ArrayRef<SSAValue *>{state->zero}) ArrayRef<Value *>{state->zero})
->getResult(); ->getResult();
newResults.push_back(vector); newResults.push_back(vector);
} }

View File

@ -32,9 +32,7 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Pass.h" #include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
@ -192,7 +190,7 @@ struct MaterializationState {
VectorType superVectorType; VectorType superVectorType;
VectorType hwVectorType; VectorType hwVectorType;
SmallVector<unsigned, 8> hwVectorInstance; SmallVector<unsigned, 8> hwVectorInstance;
DenseMap<const MLValue *, MLValue *> *substitutionsMap; DenseMap<const Value *, Value *> *substitutionsMap;
}; };
struct MaterializeVectorsPass : public FunctionPass { struct MaterializeVectorsPass : public FunctionPass {
@ -250,9 +248,9 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
static OperationStmt * static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap); DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all SSAValue belong to a program slice scoped within the immediately /// Not all Values belong to a program slice scoped within the immediately
/// enclosing loop. /// enclosing loop.
/// One simple example is constants defined outside the innermost loop scope. /// One simple example is constants defined outside the innermost loop scope.
/// For such cases the substitutionsMap has no entry and we allow an additional /// For such cases the substitutionsMap has no entry and we allow an additional
@ -261,17 +259,16 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
/// indices and will need to be extended in the future. /// indices and will need to be extended in the future.
/// ///
/// If substitution fails, returns nullptr. /// If substitution fails, returns nullptr.
static MLValue * static Value *substitute(Value *v, VectorType hwVectorType,
substitute(SSAValue *v, VectorType hwVectorType, DenseMap<const Value *, Value *> *substitutionsMap) {
DenseMap<const MLValue *, MLValue *> *substitutionsMap) { auto it = substitutionsMap->find(v);
auto it = substitutionsMap->find(cast<MLValue>(v));
if (it == substitutionsMap->end()) { if (it == substitutionsMap->end()) {
auto *opStmt = cast<OperationStmt>(v->getDefiningOperation()); auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
if (opStmt->isa<ConstantOp>()) { if (opStmt->isa<ConstantOp>()) {
MLFuncBuilder b(opStmt); MLFuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap); auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
auto res = substitutionsMap->insert( auto res =
std::make_pair(cast<MLValue>(v), cast<MLValue>(inst->getResult(0)))); substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
assert(res.second && "Insertion failed"); assert(res.second && "Insertion failed");
return res.first->second; return res.first->second;
} }
@ -336,10 +333,10 @@ substitute(SSAValue *v, VectorType hwVectorType,
/// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): support a concrete AffineMap and compose with it.
/// TODO(ntv): these implementation details should be captured in a /// TODO(ntv): these implementation details should be captured in a
/// vectorization trait at the op level directly. /// vectorization trait at the op level directly.
static SmallVector<SSAValue *, 8> static SmallVector<mlir::Value *, 8>
reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance, ArrayRef<unsigned> hwVectorInstance,
ArrayRef<SSAValue *> memrefIndices) { ArrayRef<Value *> memrefIndices) {
auto vectorShape = hwVectorType.getShape(); auto vectorShape = hwVectorType.getShape();
assert(hwVectorInstance.size() >= vectorShape.size()); assert(hwVectorInstance.size() >= vectorShape.size());
@ -380,7 +377,7 @@ reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
// TODO(ntv): support a concrete map and composition. // TODO(ntv): support a concrete map and composition.
auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(), auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(),
affineMap, memrefIndices); affineMap, memrefIndices);
return SmallVector<SSAValue *, 8>{app->getResults()}; return SmallVector<mlir::Value *, 8>{app->getResults()};
} }
/// Returns attributes with the following substitutions applied: /// Returns attributes with the following substitutions applied:
@ -402,21 +399,21 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
/// Creates an instantiated version of `opStmt`. /// Creates an instantiated version of `opStmt`.
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
/// affine reindexing. Just substitute their SSAValue* operands and be done. For /// affine reindexing. Just substitute their Value operands and be done. For
/// this case the actual instance is irrelevant. Just use the SSA values in /// this case the actual instance is irrelevant. Just use the values in
/// substitutionsMap. /// substitutionsMap.
/// ///
/// If the underlying substitution fails, this fails too and returns nullptr. /// If the underlying substitution fails, this fails too and returns nullptr.
static OperationStmt * static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) { DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() && assert(!opStmt->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp"); "Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() && assert(!opStmt->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp"); "Should call the function specialized for VectorTransferWriteOp");
bool fail = false; bool fail = false;
auto operands = map( auto operands = map(
[hwVectorType, substitutionsMap, &fail](SSAValue *v) -> SSAValue * { [hwVectorType, substitutionsMap, &fail](Value *v) -> Value * {
auto *res = auto *res =
fail ? nullptr : substitute(v, hwVectorType, substitutionsMap); fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
fail |= !res; fail |= !res;
@ -481,9 +478,9 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
static OperationStmt * static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) { DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<SSAValue *, 8> indices = SmallVector<Value *, 8> indices =
map(makePtrDynCaster<SSAValue>(), read->getIndices()); map(makePtrDynCaster<Value>(), read->getIndices());
auto affineIndices = auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto cloned = b->create<VectorTransferReadOp>( auto cloned = b->create<VectorTransferReadOp>(
@ -501,9 +498,9 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
static OperationStmt * static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) { DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<SSAValue *, 8> indices = SmallVector<Value *, 8> indices =
map(makePtrDynCaster<SSAValue>(), write->getIndices()); map(makePtrDynCaster<Value>(), write->getIndices());
auto affineIndices = auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto cloned = b->create<VectorTransferWriteOp>( auto cloned = b->create<VectorTransferWriteOp>(
@ -555,8 +552,8 @@ static bool instantiateMaterialization(Statement *stmt,
} else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType, auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap); state->hwVectorInstance, state->substitutionsMap);
state->substitutionsMap->insert(std::make_pair( state->substitutionsMap->insert(
cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0)))); std::make_pair(read->getResult(), clone->getResult(0)));
return false; return false;
} }
// The only op with 0 results reaching this point must, by construction, be // The only op with 0 results reaching this point must, by construction, be
@ -571,8 +568,8 @@ static bool instantiateMaterialization(Statement *stmt,
if (!clone) { if (!clone) {
return true; return true;
} }
state->substitutionsMap->insert(std::make_pair( state->substitutionsMap->insert(
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); std::make_pair(opStmt->getResult(0), clone->getResult(0)));
return false; return false;
} }
@ -610,7 +607,7 @@ static bool emitSlice(MaterializationState *state,
// Fresh RAII instanceIndices and substitutionsMap. // Fresh RAII instanceIndices and substitutionsMap.
MaterializationState scopedState = *state; MaterializationState scopedState = *state;
scopedState.hwVectorInstance = delinearize(idx, *ratio); scopedState.hwVectorInstance = delinearize(idx, *ratio);
DenseMap<const MLValue *, MLValue *> substitutionMap; DenseMap<const Value *, Value *> substitutionMap;
scopedState.substitutionsMap = &substitutionMap; scopedState.substitutionsMap = &substitutionMap;
// slice are topologically sorted, we can just clone them in order. // slice are topologically sorted, we can just clone them in order.
for (auto *stmt : *slice) { for (auto *stmt : *slice) {

View File

@ -32,7 +32,6 @@
#include "mlir/Transforms/Utils.h" #include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#define DEBUG_TYPE "pipeline-data-transfer" #define DEBUG_TYPE "pipeline-data-transfer"
using namespace mlir; using namespace mlir;
@ -80,7 +79,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// of the old memref by the new one while indexing the newly added dimension by /// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such /// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed. /// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto *forBody = forStmt->getBody(); auto *forBody = forStmt->getBody();
MLFuncBuilder bInner(forBody, forBody->begin()); MLFuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin());
@ -103,7 +102,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// Put together alloc operands for the dynamic dimensions of the memref. // Put together alloc operands for the dynamic dimensions of the memref.
MLFuncBuilder bOuter(forStmt); MLFuncBuilder bOuter(forStmt);
SmallVector<SSAValue *, 4> allocOperands; SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0; unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) { for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1) if (dimSize == -1)
@ -114,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// Create and place the alloc right before the 'for' statement. // Create and place the alloc right before the 'for' statement.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't // TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing. // inserting a dealloc -- this isn't the right thing.
SSAValue *newMemRef = Value *newMemRef =
bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands); bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension. // Create 'iv mod 2' value to index the leading dimension.
@ -126,8 +125,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// replaceAllMemRefUsesWith will always succeed unless the forStmt body has // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
// non-deferencing uses of the memref. // non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef), if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
ivModTwoOp->getResult(0), AffineMap::Null(), {}, AffineMap::Null(), {},
&*forStmt->getBody()->begin())) { &*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";); << "memref replacement for double buffering failed\n";);
@ -225,8 +224,7 @@ static void findMatchingStartFinishStmts(
continue; continue;
// We only double buffer if the buffer is not live out of loop. // We only double buffer if the buffer is not live out of loop.
const MLValue *memref = auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false; bool escapingUses = false;
for (const auto &use : memref->getUses()) { for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) { if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
@ -280,8 +278,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// dimension. // dimension.
for (auto &pair : startWaitPairs) { for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first; auto *dmaStartStmt = pair.first;
MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( Value *oldMemRef = dmaStartStmt->getOperand(
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos())); dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forStmt)) { if (!doubleBuffer(oldMemRef, forStmt)) {
// Normally, double buffering should not fail because we already checked // Normally, double buffering should not fail because we already checked
// that there are no uses outside. // that there are no uses outside.
@ -302,8 +300,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// Double the buffers for tag memrefs. // Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) { for (auto &pair : startWaitPairs) {
auto *dmaFinishStmt = pair.second; auto *dmaFinishStmt = pair.second;
MLValue *oldTagMemRef = cast<MLValue>( Value *oldTagMemRef =
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt))); dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
if (!doubleBuffer(oldTagMemRef, forStmt)) { if (!doubleBuffer(oldTagMemRef, forStmt)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success(); return success();
@ -332,7 +330,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// If a slice wasn't created, the reachable affine_apply op's from its // If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it. // operands are the ones that go with it.
SmallVector<OperationStmt *, 4> affineApplyStmts; SmallVector<OperationStmt *, 4> affineApplyStmts;
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands()); SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts); getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) { for (const auto *stmt : affineApplyStmts) {
stmtShiftMap[stmt] = 0; stmtShiftMap[stmt] = 0;

View File

@ -217,7 +217,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// If we already have a canonicalized version of this constant, just // If we already have a canonicalized version of this constant, just
// reuse it. Otherwise create a new one. // reuse it. Otherwise create a new one.
SSAValue *cstValue; Value *cstValue;
auto it = uniquedConstants.find({resultConstants[i], res->getType()}); auto it = uniquedConstants.find({resultConstants[i], res->getType()});
if (it != uniquedConstants.end()) if (it != uniquedConstants.end())
cstValue = it->second->getResult(0); cstValue = it->second->getResult(0);

View File

@ -31,7 +31,6 @@
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#define DEBUG_TYPE "LoopUtils" #define DEBUG_TYPE "LoopUtils"
using namespace mlir; using namespace mlir;
@ -108,8 +107,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
forStmt->replaceAllUsesWith(constOp); forStmt->replaceAllUsesWith(constOp);
} else { } else {
const AffineBound lb = forStmt->getLowerBound(); const AffineBound lb = forStmt->getLowerBound();
SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(), SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
lb.operand_end());
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt)); MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
auto affineApplyOp = builder.create<AffineApplyOp>( auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands); forStmt->getLoc(), lb.getMap(), lbOperands);
@ -149,8 +147,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue, &stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) { unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands()); SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands()); SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size()); assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size()); assert(ubMap.getNumInputs() == ubOperands.size());
@ -176,7 +174,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
srcForStmt->getStep() * shift)), srcForStmt->getStep() * shift)),
loopChunk) loopChunk)
->getResult(0); ->getResult(0);
operandMap[srcForStmt] = cast<MLValue>(ivRemap); operandMap[srcForStmt] = ivRemap;
} else { } else {
operandMap[srcForStmt] = loopChunk; operandMap[srcForStmt] = loopChunk;
} }
@ -380,7 +378,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor. // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) { if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap; DenseMap<const Value *, Value *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder); auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
@ -414,7 +412,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies). // Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) { for (unsigned i = 1; i < unrollFactor; i++) {
DenseMap<const MLValue *, MLValue *> operandMap; DenseMap<const Value *, Value *> operandMap;
// If the induction variable is used, create a remapping to the value for // If the induction variable is used, create a remapping to the value for
// this unrolled instance. // this unrolled instance.
@ -425,7 +423,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
auto *ivUnroll = auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0); ->getResult(0);
operandMap[forStmt] = cast<MLValue>(ivUnroll); operandMap[forStmt] = ivUnroll;
} }
// Clone the original body of 'forStmt'. // Clone the original body of 'forStmt'.

View File

@ -32,17 +32,17 @@ using namespace mlir;
namespace { namespace {
// Visit affine expressions recursively and build the sequence of instructions // Visit affine expressions recursively and build the sequence of instructions
// that correspond to it. Visitation functions return an SSAValue of the // that correspond to it. Visitation functions return an Value of the
// expression subtree they visited or `nullptr` on error. // expression subtree they visited or `nullptr` on error.
class AffineApplyExpander class AffineApplyExpander
: public AffineExprVisitor<AffineApplyExpander, SSAValue *> { : public AffineExprVisitor<AffineApplyExpander, Value *> {
public: public:
// This internal clsas expects arguments to be non-null, checks must be // This internal clsas expects arguments to be non-null, checks must be
// performed at the call site. // performed at the call site.
AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op) AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op)
: builder(*builder), applyOp(*op), loc(op->getLoc()) {} : builder(*builder), applyOp(*op), loc(op->getLoc()) {}
template <typename OpTy> SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) { template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS()); auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS()); auto rhs = visit(expr.getRHS());
if (!lhs || !rhs) if (!lhs || !rhs)
@ -51,33 +51,33 @@ public:
return op->getResult(); return op->getResult();
} }
SSAValue *visitAddExpr(AffineBinaryOpExpr expr) { Value *visitAddExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<AddIOp>(expr); return buildBinaryExpr<AddIOp>(expr);
} }
SSAValue *visitMulExpr(AffineBinaryOpExpr expr) { Value *visitMulExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<MulIOp>(expr); return buildBinaryExpr<MulIOp>(expr);
} }
// TODO(zinenko): implement when the standard operators are made available. // TODO(zinenko): implement when the standard operators are made available.
SSAValue *visitModExpr(AffineBinaryOpExpr) { Value *visitModExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc, "unsupported binary operator: mod"); builder.getContext()->emitError(loc, "unsupported binary operator: mod");
return nullptr; return nullptr;
} }
SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) { Value *visitFloorDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc, builder.getContext()->emitError(loc,
"unsupported binary operator: floor_div"); "unsupported binary operator: floor_div");
return nullptr; return nullptr;
} }
SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) { Value *visitCeilDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc, builder.getContext()->emitError(loc,
"unsupported binary operator: ceil_div"); "unsupported binary operator: ceil_div");
return nullptr; return nullptr;
} }
SSAValue *visitConstantExpr(AffineConstantExpr expr) { Value *visitConstantExpr(AffineConstantExpr expr) {
auto valueAttr = auto valueAttr =
builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
auto op = auto op =
@ -85,7 +85,7 @@ public:
return op->getResult(); return op->getResult();
} }
SSAValue *visitDimExpr(AffineDimExpr expr) { Value *visitDimExpr(AffineDimExpr expr) {
assert(expr.getPosition() < applyOp.getNumOperands() && assert(expr.getPosition() < applyOp.getNumOperands() &&
"affine dim position out of range"); "affine dim position out of range");
// FIXME: this assumes a certain order of AffineApplyOp operands, the // FIXME: this assumes a certain order of AffineApplyOp operands, the
@ -93,7 +93,7 @@ public:
return applyOp.getOperand(expr.getPosition()); return applyOp.getOperand(expr.getPosition());
} }
SSAValue *visitSymbolExpr(AffineSymbolExpr expr) { Value *visitSymbolExpr(AffineSymbolExpr expr) {
// FIXME: this assumes a certain order of AffineApplyOp operands, the // FIXME: this assumes a certain order of AffineApplyOp operands, the
// cleaner interface would be to separate them at the op level. // cleaner interface would be to separate them at the op level.
assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() < assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() <
@ -114,7 +114,7 @@ private:
// Given an affine expression `expr` extracted from `op`, build the sequence of // Given an affine expression `expr` extracted from `op`, build the sequence of
// primitive instructions that correspond to the affine expression in the // primitive instructions that correspond to the affine expression in the
// `builder`. // `builder`.
static SSAValue *expandAffineExpr(FuncBuilder *builder, AffineExpr expr, static mlir::Value *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
AffineApplyOp *op) { AffineApplyOp *op) {
auto expander = AffineApplyExpander(builder, op); auto expander = AffineApplyExpander(builder, op);
return expander.visit(expr); return expander.visit(expr);
@ -127,7 +127,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
FuncBuilder builder(op->getOperation()); FuncBuilder builder(op->getOperation());
auto affineMap = op->getAffineMap(); auto affineMap = op->getAffineMap();
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) { for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
SSAValue *expanded = expandAffineExpr(&builder, numberedExpr.value(), op); Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
if (!expanded) if (!expanded)
return true; return true;
op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded); op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded);

View File

@ -31,7 +31,6 @@
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
using namespace mlir; using namespace mlir;
/// Return true if this operation dereferences one or more memref's. /// Return true if this operation dereferences one or more memref's.
@ -61,13 +60,12 @@ static bool isMemRefDereferencingOp(const Operation &op) {
// extra operands, note that 'indexRemap' would just be applied to the existing // extra operands, note that 'indexRemap' would just be applied to the existing
// indices (%i, %j). // indices (%i, %j).
// //
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily // TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily
// extended to add additional indices at any position. // extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef, bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
MLValue *newMemRef, ArrayRef<Value *> extraIndices,
ArrayRef<SSAValue *> extraIndices,
AffineMap indexRemap, AffineMap indexRemap,
ArrayRef<SSAValue *> extraOperands, ArrayRef<Value *> extraOperands,
const Statement *domStmtFilter) { const Statement *domStmtFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank(); unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode (void)newMemRefRank; // unused in opt mode
@ -128,16 +126,15 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
// operation. // operation.
assert(extraIndex->getDefiningStmt()->getNumResults() == 1 && assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
"single result op's expected to generate these indices"); "single result op's expected to generate these indices");
assert((cast<MLValue>(extraIndex)->isValidDim() || assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
cast<MLValue>(extraIndex)->isValidSymbol()) &&
"invalid memory op index"); "invalid memory op index");
state.operands.push_back(cast<MLValue>(extraIndex)); state.operands.push_back(extraIndex);
} }
// Construct new indices as a remap of the old ones if a remapping has been // Construct new indices as a remap of the old ones if a remapping has been
// provided. The indices of a memref come right after it, i.e., // provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1. // at position memRefOperandPos + 1.
SmallVector<SSAValue *, 4> remapOperands; SmallVector<Value *, 4> remapOperands;
remapOperands.reserve(oldMemRefRank + extraOperands.size()); remapOperands.reserve(oldMemRefRank + extraOperands.size());
remapOperands.insert(remapOperands.end(), extraOperands.begin(), remapOperands.insert(remapOperands.end(), extraOperands.begin(),
extraOperands.end()); extraOperands.end());
@ -149,11 +146,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
remapOperands); remapOperands);
// Remapped indices. // Remapped indices.
for (auto *index : remapOp->getOperation()->getResults()) for (auto *index : remapOp->getOperation()->getResults())
state.operands.push_back(cast<MLValue>(index)); state.operands.push_back(index);
} else { } else {
// No remapping specified. // No remapping specified.
for (auto *index : remapOperands) for (auto *index : remapOperands)
state.operands.push_back(cast<MLValue>(index)); state.operands.push_back(index);
} }
// Insert the remaining operands unmodified. // Insert the remaining operands unmodified.
@ -191,9 +188,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
// composed AffineApplyOp are returned in output parameter 'results'. // composed AffineApplyOp are returned in output parameter 'results'.
OperationStmt * OperationStmt *
mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc, mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<MLValue *> operands, ArrayRef<Value *> operands,
ArrayRef<OperationStmt *> affineApplyOps, ArrayRef<OperationStmt *> affineApplyOps,
SmallVectorImpl<SSAValue *> *results) { SmallVectorImpl<Value *> *results) {
// Create identity map with same number of dimensions as number of operands. // Create identity map with same number of dimensions as number of operands.
auto map = builder->getMultiDimIdentityMap(operands.size()); auto map = builder->getMultiDimIdentityMap(operands.size());
// Initialize AffineValueMap with identity map. // Initialize AffineValueMap with identity map.
@ -208,7 +205,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
// Compose affine maps from all ancestor AffineApplyOps. // Compose affine maps from all ancestor AffineApplyOps.
// Create new AffineApplyOp from 'valueMap'. // Create new AffineApplyOp from 'valueMap'.
unsigned numOperands = valueMap.getNumOperands(); unsigned numOperands = valueMap.getNumOperands();
SmallVector<SSAValue *, 4> outOperands(numOperands); SmallVector<Value *, 4> outOperands(numOperands);
for (unsigned i = 0; i < numOperands; ++i) { for (unsigned i = 0; i < numOperands; ++i) {
outOperands[i] = valueMap.getOperand(i); outOperands[i] = valueMap.getOperand(i);
} }
@ -252,7 +249,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
/// otherwise. /// otherwise.
OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) { OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// Collect all operands that are results of affine apply ops. // Collect all operands that are results of affine apply ops.
SmallVector<MLValue *, 4> subOperands; SmallVector<Value *, 4> subOperands;
subOperands.reserve(opStmt->getNumOperands()); subOperands.reserve(opStmt->getNumOperands());
for (auto *operand : opStmt->getOperands()) { for (auto *operand : opStmt->getOperands()) {
auto *defStmt = operand->getDefiningStmt(); auto *defStmt = operand->getDefiningStmt();
@ -285,7 +282,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
return nullptr; return nullptr;
FuncBuilder builder(opStmt); FuncBuilder builder(opStmt);
SmallVector<SSAValue *, 4> results; SmallVector<Value *, 4> results;
auto *affineApplyStmt = createComposedAffineApplyOp( auto *affineApplyStmt = createComposedAffineApplyOp(
&builder, opStmt->getLoc(), subOperands, affineApplyOps, &results); &builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
assert(results.size() == subOperands.size() && assert(results.size() == subOperands.size() &&
@ -295,7 +292,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// affine apply op above instead of existing ones (subOperands). So, they // affine apply op above instead of existing ones (subOperands). So, they
// differ from opStmt's operands only for those operands in 'subOperands', for // differ from opStmt's operands only for those operands in 'subOperands', for
// which they will be replaced by the corresponding one from 'results'. // which they will be replaced by the corresponding one from 'results'.
SmallVector<MLValue *, 4> newOperands(opStmt->getOperands()); SmallVector<Value *, 4> newOperands(opStmt->getOperands());
for (unsigned i = 0, e = newOperands.size(); i < e; i++) { for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
// Replace the subOperands from among the new operands. // Replace the subOperands from among the new operands.
unsigned j, f; unsigned j, f;
@ -304,7 +301,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
break; break;
} }
if (j < subOperands.size()) { if (j < subOperands.size()) {
newOperands[i] = cast<MLValue>(results[j]); newOperands[i] = results[j];
} }
} }
@ -326,7 +323,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// into any uses which are AffineApplyOps. // into any uses which are AffineApplyOps.
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e; for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
++resultIndex) { ++resultIndex) {
const MLValue *result = opStmt->getResult(resultIndex); const Value *result = opStmt->getResult(resultIndex);
for (auto it = result->use_begin(); it != result->use_end();) { for (auto it = result->use_begin(); it != result->use_end();) {
StmtOperand &use = *(it++); StmtOperand &use = *(it++);
auto *useStmt = use.getOwner(); auto *useStmt = use.getOwner();
@ -347,7 +344,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// Create new AffineApplyOp from 'valueMap'. // Create new AffineApplyOp from 'valueMap'.
unsigned numOperands = valueMap.getNumOperands(); unsigned numOperands = valueMap.getNumOperands();
SmallVector<SSAValue *, 4> operands(numOperands); SmallVector<Value *, 4> operands(numOperands);
for (unsigned i = 0; i < numOperands; ++i) { for (unsigned i = 0; i < numOperands; ++i) {
operands[i] = valueMap.getOperand(i); operands[i] = valueMap.getOperand(i);
} }

View File

@ -27,8 +27,6 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Pass.h" #include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h" #include "mlir/StandardOps/StandardOps.h"
@ -740,8 +738,8 @@ struct VectorizationState {
DenseSet<OperationStmt *> vectorizedSet; DenseSet<OperationStmt *> vectorizedSet;
// Map of old scalar OperationStmt to new vectorized OperationStmt. // Map of old scalar OperationStmt to new vectorized OperationStmt.
DenseMap<OperationStmt *, OperationStmt *> vectorizationMap; DenseMap<OperationStmt *, OperationStmt *> vectorizationMap;
// Map of old scalar MLValue to new vectorized MLValue. // Map of old scalar Value to new vectorized Value.
DenseMap<const MLValue *, MLValue *> replacementMap; DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount. // The strategy drives which loop to vectorize by which amount.
const VectorizationStrategy *strategy; const VectorizationStrategy *strategy;
// Use-def roots. These represent the starting points for the worklist in the // Use-def roots. These represent the starting points for the worklist in the
@ -761,7 +759,7 @@ struct VectorizationState {
void registerTerminator(OperationStmt *stmt); void registerTerminator(OperationStmt *stmt);
private: private:
void registerReplacement(const SSAValue *key, SSAValue *value); void registerReplacement(const Value *key, Value *value);
}; };
} // end namespace } // end namespace
@ -802,12 +800,9 @@ void VectorizationState::finishVectorizationPattern() {
} }
} }
void VectorizationState::registerReplacement(const SSAValue *key, void VectorizationState::registerReplacement(const Value *key, Value *value) {
SSAValue *value) { assert(replacementMap.count(key) == 0 && "replacement already registered");
assert(replacementMap.count(cast<MLValue>(key)) == 0 && replacementMap.insert(std::make_pair(key, value));
"replacement already registered");
replacementMap.insert(
std::make_pair(cast<MLValue>(key), cast<MLValue>(value)));
} }
////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. //// ////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
@ -825,7 +820,7 @@ void VectorizationState::registerReplacement(const SSAValue *key,
/// Such special cases force us to delay the vectorization of the stores /// Such special cases force us to delay the vectorization of the stores
/// until the last step. Here we merely register the store operation. /// until the last step. Here we merely register the store operation.
template <typename LoadOrStoreOpPointer> template <typename LoadOrStoreOpPointer>
static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp, static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
VectorizationState *state) { VectorizationState *state) {
auto memRefType = auto memRefType =
memoryOp->getMemRef()->getType().template cast<MemRefType>(); memoryOp->getMemRef()->getType().template cast<MemRefType>();
@ -850,8 +845,7 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
MLFuncBuilder b(opStmt); MLFuncBuilder b(opStmt);
auto transfer = b.create<VectorTransferReadOp>( auto transfer = b.create<VectorTransferReadOp>(
opStmt->getLoc(), vectorType, memoryOp->getMemRef(), opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()), map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
permutationMap);
state->registerReplacement(opStmt, state->registerReplacement(opStmt,
cast<OperationStmt>(transfer->getOperation())); cast<OperationStmt>(transfer->getOperation()));
} else { } else {
@ -970,7 +964,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// element type. /// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a /// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr. /// valid vector element type, returns nullptr.
static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant, static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
Type type) { Type type) {
if (!type || !type.isa<VectorType>() || if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) { !VectorType::isValidElementType(constant.getType())) {
@ -988,7 +982,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
{make_pair(Identifier::get("value", b.getContext()), attr)}); {make_pair(Identifier::get("value", b.getContext()), attr)});
auto *splat = cast<OperationStmt>(b.createOperation(state)); auto *splat = cast<OperationStmt>(b.createOperation(state));
return cast<MLValue>(splat->getResult(0)); return splat->getResult(0);
} }
/// Returns a uniqu'ed VectorType. /// Returns a uniqu'ed VectorType.
@ -996,7 +990,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
/// vectorizedSet, just returns the type of `v`. /// vectorizedSet, just returns the type of `v`.
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy` /// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
/// and of elemental type the type of `v`. /// and of elemental type the type of `v`.
static Type getVectorType(SSAValue *v, const VectorizationState &state) { static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) { if (!VectorType::isValidElementType(v->getType())) {
return Type(); return Type();
} }
@ -1028,7 +1022,7 @@ static Type getVectorType(SSAValue *v, const VectorizationState &state) {
/// vectorization is possible with the above logic. Returns nullptr otherwise. /// vectorization is possible with the above logic. Returns nullptr otherwise.
/// ///
/// TODO(ntv): handle more complex cases. /// TODO(ntv): handle more complex cases.
static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt, static Value *vectorizeOperand(Value *operand, Statement *stmt,
VectorizationState *state) { VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs())); LLVM_DEBUG(operand->print(dbgs()));
@ -1036,15 +1030,15 @@ static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
// 1. If this value has already been vectorized this round, we are done. // 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingStatement) > 0) { if (state->vectorizedSet.count(definingStatement) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand"); LLVM_DEBUG(dbgs() << " -> already vector operand");
return cast<MLValue>(operand); return operand;
} }
// 1.b. Delayed on-demand replacement of a use. // 1.b. Delayed on-demand replacement of a use.
// Note that we cannot just call replaceAllUsesWith because it may result // Note that we cannot just call replaceAllUsesWith because it may result
// in ops with mixed types, for ops whose operands have not all yet // in ops with mixed types, for ops whose operands have not all yet
// been vectorized. This would be invalid IR. // been vectorized. This would be invalid IR.
auto it = state->replacementMap.find(cast<MLValue>(operand)); auto it = state->replacementMap.find(operand);
if (it != state->replacementMap.end()) { if (it != state->replacementMap.end()) {
auto *res = cast<MLValue>(it->second); auto *res = it->second;
LLVM_DEBUG(dbgs() << "-> delayed replacement by: "); LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
LLVM_DEBUG(res->print(dbgs())); LLVM_DEBUG(res->print(dbgs()));
return res; return res;
@ -1089,7 +1083,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
auto *memRef = store->getMemRef(); auto *memRef = store->getMemRef();
auto *value = store->getValueToStore(); auto *value = store->getValueToStore();
auto *vectorValue = vectorizeOperand(value, opStmt, state); auto *vectorValue = vectorizeOperand(value, opStmt, state);
auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices()); auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
MLFuncBuilder b(opStmt); MLFuncBuilder b(opStmt);
auto permutationMap = auto permutationMap =
makePermutationMap(opStmt, state->strategy->loopToVectorDim); makePermutationMap(opStmt, state->strategy->loopToVectorDim);
@ -1104,14 +1098,14 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
return res; return res;
} }
auto types = map([state](SSAValue *v) { return getVectorType(v, *state); }, auto types = map([state](Value *v) { return getVectorType(v, *state); },
opStmt->getResults()); opStmt->getResults());
auto vectorizeOneOperand = [opStmt, state](SSAValue *op) -> SSAValue * { auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
return vectorizeOperand(op, opStmt, state); return vectorizeOperand(op, opStmt, state);
}; };
auto operands = map(vectorizeOneOperand, opStmt->getOperands()); auto operands = map(vectorizeOneOperand, opStmt->getOperands());
// Check whether a single operand is null. If so, vectorization failed. // Check whether a single operand is null. If so, vectorization failed.
bool success = llvm::all_of(operands, [](SSAValue *op) { return op; }); bool success = llvm::all_of(operands, [](Value *op) { return op; });
if (!success) { if (!success) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize"); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
return nullptr; return nullptr;
@ -1207,7 +1201,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
continue; continue;
} }
MLFuncBuilder builder(loop); // builder to insert in place of loop MLFuncBuilder builder(loop); // builder to insert in place of loop
DenseMap<const MLValue *, MLValue *> nomap; DenseMap<const Value *, Value *> nomap;
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap)); ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
auto fail = doVectorize(m, &state); auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match /// Sets up error handling for this root loop. This is how the root match
@ -1229,8 +1223,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
// Form the root operationsthat have been set in the replacementMap. // Form the root operationsthat have been set in the replacementMap.
// For now, these roots are the loads for which vector_transfer_read // For now, these roots are the loads for which vector_transfer_read
// operations have been inserted. // operations have been inserted.
auto getDefiningOperation = [](const MLValue *val) { auto getDefiningOperation = [](const Value *val) {
return const_cast<MLValue *>(val)->getDefiningOperation(); return const_cast<Value *>(val)->getDefiningOperation();
}; };
using ReferenceTy = decltype(*(state.replacementMap.begin())); using ReferenceTy = decltype(*(state.replacementMap.begin()));
auto getKey = [](ReferenceTy it) { return it.first; }; auto getKey = [](ReferenceTy it) { return it.first; };

View File

@ -288,10 +288,10 @@ void OpEmitter::emitAttrGetters() {
} }
void OpEmitter::emitNamedOperands() { void OpEmitter::emitNamedOperands() {
const auto operandMethods = R"( SSAValue *{0}() { const auto operandMethods = R"( Value *{0}() {
return this->getOperation()->getOperand({1}); return this->getOperation()->getOperand({1});
} }
const SSAValue *{0}() const { const Value *{0}() const {
return this->getOperation()->getOperand({1}); return this->getOperation()->getOperand({1});
} }
)"; )";
@ -329,7 +329,7 @@ void OpEmitter::emitBuilder() {
// Emit parameters for all operands // Emit parameters for all operands
for (const auto &pair : operands) for (const auto &pair : operands)
os << ", SSAValue* " << pair.first; os << ", Value* " << pair.first;
// Emit parameters for all attributes // Emit parameters for all attributes
// TODO(antiagainst): Support default initializer for attributes // TODO(antiagainst): Support default initializer for attributes
@ -369,7 +369,7 @@ void OpEmitter::emitBuilder() {
// Signature // Signature
os << " static void build(Builder* builder, OperationState* result, " os << " static void build(Builder* builder, OperationState* result, "
<< "ArrayRef<Type> resultTypes, ArrayRef<SSAValue*> args, " << "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
"ArrayRef<NamedAttribute> attributes) {\n"; "ArrayRef<NamedAttribute> attributes) {\n";
// Result types // Result types