Merge SSAValue, CFGValue, and MLValue together into a single Value class, which
is the new base of the SSA value hierarchy. This CL also standardizes all the nomenclature and comments to use 'Value' where appropriate. This also eliminates a large number of cast<MLValue>(x)'s, which is very soothing. This is step 11/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 227064624
This commit is contained in:
parent
776b035646
commit
3f190312f8
|
@ -37,9 +37,9 @@ class ForStmt;
|
|||
class MLIRContext;
|
||||
class FlatAffineConstraints;
|
||||
class IntegerSet;
|
||||
class MLValue;
|
||||
class OperationStmt;
|
||||
class Statement;
|
||||
class Value;
|
||||
|
||||
/// Simplify an affine expression through flattening and some amount of
|
||||
/// simple analysis. This has complexity linear in the number of nodes in
|
||||
|
@ -78,7 +78,7 @@ AffineExpr composeWithUnboundedMap(AffineExpr e, AffineMap g);
|
|||
/// 'affineApplyOps', which are reachable via a search starting from 'operands',
|
||||
/// and ending at operands which are not defined by AffineApplyOps.
|
||||
void getReachableAffineApplyOps(
|
||||
llvm::ArrayRef<MLValue *> operands,
|
||||
llvm::ArrayRef<Value *> operands,
|
||||
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
|
||||
|
||||
/// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
|
||||
|
@ -122,9 +122,9 @@ bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
|
|||
FlatAffineConstraints *domain);
|
||||
|
||||
struct MemRefAccess {
|
||||
const MLValue *memref;
|
||||
const Value *memref;
|
||||
const OperationStmt *opStmt;
|
||||
llvm::SmallVector<MLValue *, 4> indices;
|
||||
llvm::SmallVector<Value *, 4> indices;
|
||||
// Populates 'accessMap' with composition of AffineApplyOps reachable from
|
||||
// 'indices'.
|
||||
void getAccessMap(AffineValueMap *accessMap) const;
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -37,7 +36,7 @@ class AffineMap;
|
|||
class ForStmt;
|
||||
class IntegerSet;
|
||||
class MLIRContext;
|
||||
class MLValue;
|
||||
class Value;
|
||||
class HyperRectangularSet;
|
||||
|
||||
/// A mutable affine map. Its affine expressions are however unique.
|
||||
|
@ -132,7 +131,7 @@ public:
|
|||
AffineValueMap(const AffineApplyOp &op);
|
||||
AffineValueMap(const AffineBound &bound);
|
||||
AffineValueMap(AffineMap map);
|
||||
AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands);
|
||||
AffineValueMap(AffineMap map, ArrayRef<Value *> operands);
|
||||
|
||||
~AffineValueMap();
|
||||
|
||||
|
@ -155,13 +154,13 @@ public:
|
|||
// substitutions).
|
||||
|
||||
// Resets this AffineValueMap with 'map' and 'operands'.
|
||||
void reset(AffineMap map, ArrayRef<MLValue *> operands);
|
||||
void reset(AffineMap map, ArrayRef<Value *> operands);
|
||||
/// Return true if the idx^th result can be proved to be a multiple of
|
||||
/// 'factor', false otherwise.
|
||||
inline bool isMultipleOf(unsigned idx, int64_t factor) const;
|
||||
|
||||
/// Return true if the idx^th result depends on 'value', false otherwise.
|
||||
bool isFunctionOf(unsigned idx, MLValue *value) const;
|
||||
bool isFunctionOf(unsigned idx, Value *value) const;
|
||||
|
||||
/// Return true if the result at 'idx' is a constant, false
|
||||
/// otherwise.
|
||||
|
@ -175,8 +174,8 @@ public:
|
|||
inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
|
||||
inline unsigned getNumResults() const { return map.getNumResults(); }
|
||||
|
||||
SSAValue *getOperand(unsigned i) const;
|
||||
ArrayRef<MLValue *> getOperands() const;
|
||||
Value *getOperand(unsigned i) const;
|
||||
ArrayRef<Value *> getOperands() const;
|
||||
AffineMap getAffineMap() const;
|
||||
|
||||
private:
|
||||
|
@ -187,9 +186,9 @@ private:
|
|||
|
||||
// TODO: make these trailing objects?
|
||||
/// The SSA operands binding to the dim's and symbols of 'map'.
|
||||
SmallVector<MLValue *, 4> operands;
|
||||
SmallVector<Value *, 4> operands;
|
||||
/// The SSA results binding to the results of 'map'.
|
||||
SmallVector<MLValue *, 4> results;
|
||||
SmallVector<Value *, 4> results;
|
||||
};
|
||||
|
||||
/// An IntegerValueSet is an integer set plus its operands.
|
||||
|
@ -218,7 +217,7 @@ private:
|
|||
// 'AffineCondition'.
|
||||
MutableIntegerSet set;
|
||||
/// The SSA operands binding to the dim's and symbols of 'set'.
|
||||
SmallVector<MLValue *, 4> operands;
|
||||
SmallVector<Value *, 4> operands;
|
||||
};
|
||||
|
||||
/// A flat list of affine equalities and inequalities in the form.
|
||||
|
@ -250,7 +249,7 @@ public:
|
|||
unsigned numReservedEqualities,
|
||||
unsigned numReservedCols, unsigned numDims = 0,
|
||||
unsigned numSymbols = 0, unsigned numLocals = 0,
|
||||
ArrayRef<Optional<MLValue *>> idArgs = {})
|
||||
ArrayRef<Optional<Value *>> idArgs = {})
|
||||
: numReservedCols(numReservedCols), numDims(numDims),
|
||||
numSymbols(numSymbols) {
|
||||
assert(numReservedCols >= numDims + numSymbols + 1);
|
||||
|
@ -269,7 +268,7 @@ public:
|
|||
/// dimensions and symbols.
|
||||
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
|
||||
unsigned numLocals = 0,
|
||||
ArrayRef<Optional<MLValue *>> idArgs = {})
|
||||
ArrayRef<Optional<Value *>> idArgs = {})
|
||||
: numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
|
||||
numSymbols(numSymbols) {
|
||||
assert(numReservedCols >= numDims + numSymbols + 1);
|
||||
|
@ -309,10 +308,10 @@ public:
|
|||
// Clears any existing data and reserves memory for the specified constraints.
|
||||
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
|
||||
unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
|
||||
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
|
||||
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
|
||||
|
||||
void reset(unsigned numDims = 0, unsigned numSymbols = 0,
|
||||
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
|
||||
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
|
||||
|
||||
/// Appends constraints from 'other' into this. This is equivalent to an
|
||||
/// intersection with no simplification of any sort attempted.
|
||||
|
@ -393,7 +392,7 @@ public:
|
|||
// Returns AffineMap::Null on error (i.e. if coefficient is zero or does
|
||||
// not divide other coefficients in the equality constraint).
|
||||
// TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this
|
||||
// API when we can manage the mapping of MLValues and ids in the constraint
|
||||
// API when we can manage the mapping of Values and ids in the constraint
|
||||
// system.
|
||||
AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context,
|
||||
SmallVectorImpl<unsigned> *nonZeroDimIds,
|
||||
|
@ -413,10 +412,10 @@ public:
|
|||
void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
|
||||
|
||||
/// Adds constraints (lower and upper bounds) for the specified 'for'
|
||||
/// statement's MLValue using IR information stored in its bound maps. The
|
||||
/// right identifier is first looked up using forStmt's MLValue. Returns
|
||||
/// statement's Value using IR information stored in its bound maps. The
|
||||
/// right identifier is first looked up using forStmt's Value. Returns
|
||||
/// false for the yet unimplemented/unsupported cases, and true if the
|
||||
/// information is succesfully added. Asserts if the MLValue corresponding to
|
||||
/// information is succesfully added. Asserts if the Value corresponding to
|
||||
/// the 'for' statement isn't found in the constraint system. Any new
|
||||
/// identifiers that are found in the bound operands of the 'for' statement
|
||||
/// are added as trailing identifiers (either dimensional or symbolic
|
||||
|
@ -435,28 +434,28 @@ public:
|
|||
/// Sets the identifier at the specified position to a constant.
|
||||
void setIdToConstant(unsigned pos, int64_t val);
|
||||
|
||||
/// Sets the identifier corresponding to the specified MLValue id to a
|
||||
/// Sets the identifier corresponding to the specified Value id to a
|
||||
/// constant. Asserts if the 'id' is not found.
|
||||
void setIdToConstant(const MLValue &id, int64_t val);
|
||||
void setIdToConstant(const Value &id, int64_t val);
|
||||
|
||||
/// Looks up the identifier with the specified MLValue. Returns false if not
|
||||
/// Looks up the identifier with the specified Value. Returns false if not
|
||||
/// found, true if found. pos is set to the (column) position of the
|
||||
/// identifier.
|
||||
bool findId(const MLValue &id, unsigned *pos) const;
|
||||
bool findId(const Value &id, unsigned *pos) const;
|
||||
|
||||
// Add identifiers of the specified kind - specified positions are relative to
|
||||
// the kind of identifier. 'id' is the MLValue corresponding to the
|
||||
// the kind of identifier. 'id' is the Value corresponding to the
|
||||
// identifier that can optionally be provided.
|
||||
void addDimId(unsigned pos, MLValue *id = nullptr);
|
||||
void addSymbolId(unsigned pos, MLValue *id = nullptr);
|
||||
void addDimId(unsigned pos, Value *id = nullptr);
|
||||
void addSymbolId(unsigned pos, Value *id = nullptr);
|
||||
void addLocalId(unsigned pos);
|
||||
void addId(IdKind kind, unsigned pos, MLValue *id = nullptr);
|
||||
void addId(IdKind kind, unsigned pos, Value *id = nullptr);
|
||||
|
||||
/// Composes the affine value map with this FlatAffineConstrains, adding the
|
||||
/// results of the map as dimensions at the front [0, vMap->getNumResults())
|
||||
/// and with the dimensions set to the equalities specified by the value map.
|
||||
/// Returns false if the composition fails (when vMap is a semi-affine map).
|
||||
/// The vMap's operand MLValue's are used to look up the right positions in
|
||||
/// The vMap's operand Value's are used to look up the right positions in
|
||||
/// the FlatAffineConstraints with which to associate. The dimensional and
|
||||
/// symbolic operands of vMap should match 1:1 (in the same order) with those
|
||||
/// of this constraint system, but the latter could have additional trailing
|
||||
|
@ -471,8 +470,8 @@ public:
|
|||
void projectOut(unsigned pos, unsigned num);
|
||||
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
|
||||
|
||||
/// Projects out the identifier that is associate with MLValue *.
|
||||
void projectOut(MLValue *id);
|
||||
/// Projects out the identifier that is associate with Value *.
|
||||
void projectOut(Value *id);
|
||||
|
||||
void removeId(IdKind idKind, unsigned pos);
|
||||
void removeId(unsigned pos);
|
||||
|
@ -510,24 +509,24 @@ public:
|
|||
return numIds - numDims - numSymbols;
|
||||
}
|
||||
|
||||
inline ArrayRef<Optional<MLValue *>> getIds() const {
|
||||
inline ArrayRef<Optional<Value *>> getIds() const {
|
||||
return {ids.data(), ids.size()};
|
||||
}
|
||||
|
||||
/// Returns the MLValue's associated with the identifiers. Asserts if
|
||||
/// no MLValue was associated with an identifier.
|
||||
inline void getIdValues(SmallVectorImpl<MLValue *> *values) const {
|
||||
/// Returns the Value's associated with the identifiers. Asserts if
|
||||
/// no Value was associated with an identifier.
|
||||
inline void getIdValues(SmallVectorImpl<Value *> *values) const {
|
||||
values->clear();
|
||||
values->reserve(numIds);
|
||||
for (unsigned i = 0; i < numIds; i++) {
|
||||
assert(ids[i].hasValue() && "identifier's MLValue not set");
|
||||
assert(ids[i].hasValue() && "identifier's Value not set");
|
||||
values->push_back(ids[i].getValue());
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the MLValue associated with the pos^th identifier. Asserts if
|
||||
/// no MLValue identifier was associated.
|
||||
inline MLValue *getIdValue(unsigned pos) const {
|
||||
/// Returns the Value associated with the pos^th identifier. Asserts if
|
||||
/// no Value identifier was associated.
|
||||
inline Value *getIdValue(unsigned pos) const {
|
||||
assert(ids[pos].hasValue() && "identifier's ML Value not set");
|
||||
return ids[pos].getValue();
|
||||
}
|
||||
|
@ -630,11 +629,11 @@ private:
|
|||
/// analysis).
|
||||
unsigned numSymbols;
|
||||
|
||||
/// MLValues corresponding to the (column) identifiers of this constraint
|
||||
/// Values corresponding to the (column) identifiers of this constraint
|
||||
/// system appearing in the order the identifiers correspond to columns.
|
||||
/// Temporary ones or those that aren't associated to any MLValue are to be
|
||||
/// Temporary ones or those that aren't associated to any Value are to be
|
||||
/// set to None.
|
||||
SmallVector<Optional<MLValue *>, 8> ids;
|
||||
SmallVector<Optional<Value *>, 8> ids;
|
||||
};
|
||||
|
||||
} // end namespace mlir.
|
||||
|
|
|
@ -58,10 +58,10 @@ public:
|
|||
}
|
||||
|
||||
/// Return true if value A properly dominates instruction B.
|
||||
bool properlyDominates(const SSAValue *a, const Instruction *b);
|
||||
bool properlyDominates(const Value *a, const Instruction *b);
|
||||
|
||||
/// Return true if instruction A dominates instruction B.
|
||||
bool dominates(const SSAValue *a, const Instruction *b) {
|
||||
bool dominates(const Value *a, const Instruction *b) {
|
||||
return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b);
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ class AffineCondition;
|
|||
class AffineMap;
|
||||
class IntegerSet;
|
||||
class MLIRContext;
|
||||
class MLValue;
|
||||
class MutableIntegerSet;
|
||||
class FlatAffineConstraints;
|
||||
class HyperRectangleList;
|
||||
class Value;
|
||||
|
||||
/// A list of affine bounds.
|
||||
// Not using a MutableAffineMap here since numSymbols is the same as the
|
||||
|
@ -152,8 +152,8 @@ private:
|
|||
// expressions.
|
||||
std::vector<AffineBoundExprList> upperBounds;
|
||||
|
||||
Optional<SmallVector<MLValue *, 8>> dims = None;
|
||||
Optional<SmallVector<MLValue *, 4>> symbols = None;
|
||||
Optional<SmallVector<Value *, 8>> dims = None;
|
||||
Optional<SmallVector<Value *, 4>> symbols = None;
|
||||
|
||||
/// Number of real dimensions.
|
||||
unsigned numDims;
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#define MLIR_ANALYSIS_LOOP_ANALYSIS_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
|
@ -33,8 +32,8 @@ class AffineExpr;
|
|||
class AffineMap;
|
||||
class ForStmt;
|
||||
class MemRefType;
|
||||
class MLValue;
|
||||
class OperationStmt;
|
||||
class Value;
|
||||
|
||||
/// Returns the trip count of the loop as an affine expression if the latter is
|
||||
/// expressible as an affine expression, and nullptr otherwise. The trip count
|
||||
|
@ -66,7 +65,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
|
|||
///
|
||||
/// Returns false in cases with more than one AffineApplyOp, this is
|
||||
/// conservative.
|
||||
bool isAccessInvariant(const MLValue &iv, const MLValue &index);
|
||||
bool isAccessInvariant(const Value &iv, const Value &index);
|
||||
|
||||
/// Given an induction variable `iv` of type ForStmt and `indices` of type
|
||||
/// IndexType, returns the set of `indices` that are independent of `iv`.
|
||||
|
@ -77,9 +76,8 @@ bool isAccessInvariant(const MLValue &iv, const MLValue &index);
|
|||
///
|
||||
/// Returns false in cases with more than one AffineApplyOp, this is
|
||||
/// conservative.
|
||||
llvm::DenseSet<const MLValue *, llvm::DenseMapInfo<const MLValue *>>
|
||||
getInvariantAccesses(const MLValue &iv,
|
||||
llvm::ArrayRef<const MLValue *> indices);
|
||||
llvm::DenseSet<const Value *, llvm::DenseMapInfo<const Value *>>
|
||||
getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
|
||||
|
||||
/// Checks whether the loop is structurally vectorizable; i.e.:
|
||||
/// 1. the loop has proper dependence semantics (parallel, reduction, etc);
|
||||
|
|
|
@ -127,10 +127,10 @@ void getBackwardSlice(
|
|||
/// **includes** the original statement.
|
||||
///
|
||||
/// This allows building a slice (i.e. multi-root DAG where everything
|
||||
/// that is reachable from an SSAValue in forward and backward direction is
|
||||
/// that is reachable from an Value in forward and backward direction is
|
||||
/// contained in the slice).
|
||||
/// This is the abstraction we need to materialize all the instructions for
|
||||
/// supervectorization without worrying about orderings and SSAValue
|
||||
/// supervectorization without worrying about orderings and Value
|
||||
/// replacements.
|
||||
///
|
||||
/// Example starting from any node
|
||||
|
|
|
@ -34,10 +34,10 @@ namespace mlir {
|
|||
|
||||
class FlatAffineConstraints;
|
||||
class ForStmt;
|
||||
class MLValue;
|
||||
class MemRefAccess;
|
||||
class OperationStmt;
|
||||
class Statement;
|
||||
class Value;
|
||||
|
||||
/// Returns true if statement 'a' dominates statement b.
|
||||
bool dominates(const Statement &a, const Statement &b);
|
||||
|
@ -92,7 +92,7 @@ struct MemRefRegion {
|
|||
unsigned getRank() const;
|
||||
|
||||
/// Memref that this region corresponds to.
|
||||
MLValue *memref;
|
||||
Value *memref;
|
||||
|
||||
private:
|
||||
/// Read or write.
|
||||
|
|
|
@ -408,8 +408,8 @@ public:
|
|||
}
|
||||
|
||||
// Creates a for statement. When step is not specified, it is set to 1.
|
||||
ForStmt *createFor(Location location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
ForStmt *createFor(Location location, ArrayRef<Value *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<Value *> ubOperands,
|
||||
AffineMap ubMap, int64_t step = 1);
|
||||
|
||||
// Creates a for statement with known (constant) lower and upper bounds.
|
||||
|
@ -417,7 +417,7 @@ public:
|
|||
ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
|
||||
|
||||
/// Creates if statement.
|
||||
IfStmt *createIf(Location location, ArrayRef<MLValue *> operands,
|
||||
IfStmt *createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set);
|
||||
|
||||
private:
|
||||
|
|
|
@ -30,7 +30,6 @@
|
|||
|
||||
namespace mlir {
|
||||
class Builder;
|
||||
class MLValue;
|
||||
|
||||
class BuiltinDialect : public Dialect {
|
||||
public:
|
||||
|
@ -57,7 +56,7 @@ class AffineApplyOp
|
|||
public:
|
||||
/// Builds an affine apply op with the specified map and operands.
|
||||
static void build(Builder *builder, OperationState *result, AffineMap map,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
ArrayRef<Value *> operands);
|
||||
|
||||
/// Returns the affine map to be applied by this operation.
|
||||
AffineMap getAffineMap() const {
|
||||
|
@ -101,7 +100,7 @@ public:
|
|||
static StringRef getOperationName() { return "br"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result, BasicBlock *dest,
|
||||
ArrayRef<SSAValue *> operands = {});
|
||||
ArrayRef<Value *> operands = {});
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
|
@ -144,10 +143,9 @@ class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
|
|||
public:
|
||||
static StringRef getOperationName() { return "cond_br"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, BasicBlock *trueDest,
|
||||
ArrayRef<SSAValue *> trueOperands, BasicBlock *falseDest,
|
||||
ArrayRef<SSAValue *> falseOperands);
|
||||
static void build(Builder *builder, OperationState *result, Value *condition,
|
||||
BasicBlock *trueDest, ArrayRef<Value *> trueOperands,
|
||||
BasicBlock *falseDest, ArrayRef<Value *> falseOperands);
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
|
@ -155,8 +153,8 @@ public:
|
|||
bool verify() const;
|
||||
|
||||
// The condition operand is the first operand in the list.
|
||||
SSAValue *getCondition() { return getOperand(0); }
|
||||
const SSAValue *getCondition() const { return getOperand(0); }
|
||||
Value *getCondition() { return getOperand(0); }
|
||||
const Value *getCondition() const { return getOperand(0); }
|
||||
|
||||
/// Return the destination if the condition is true.
|
||||
BasicBlock *getTrueDest() const;
|
||||
|
@ -165,14 +163,14 @@ public:
|
|||
BasicBlock *getFalseDest() const;
|
||||
|
||||
// Accessors for operands to the 'true' destination.
|
||||
SSAValue *getTrueOperand(unsigned idx) {
|
||||
Value *getTrueOperand(unsigned idx) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
return getOperand(getTrueDestOperandIndex() + idx);
|
||||
}
|
||||
const SSAValue *getTrueOperand(unsigned idx) const {
|
||||
const Value *getTrueOperand(unsigned idx) const {
|
||||
return const_cast<CondBranchOp *>(this)->getTrueOperand(idx);
|
||||
}
|
||||
void setTrueOperand(unsigned idx, SSAValue *value) {
|
||||
void setTrueOperand(unsigned idx, Value *value) {
|
||||
assert(idx < getNumTrueOperands());
|
||||
setOperand(getTrueDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
@ -199,14 +197,14 @@ public:
|
|||
void eraseTrueOperand(unsigned index);
|
||||
|
||||
// Accessors for operands to the 'false' destination.
|
||||
SSAValue *getFalseOperand(unsigned idx) {
|
||||
Value *getFalseOperand(unsigned idx) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
return getOperand(getFalseDestOperandIndex() + idx);
|
||||
}
|
||||
const SSAValue *getFalseOperand(unsigned idx) const {
|
||||
const Value *getFalseOperand(unsigned idx) const {
|
||||
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
|
||||
}
|
||||
void setFalseOperand(unsigned idx, SSAValue *value) {
|
||||
void setFalseOperand(unsigned idx, Value *value) {
|
||||
assert(idx < getNumFalseOperands());
|
||||
setOperand(getFalseDestOperandIndex() + idx, value);
|
||||
}
|
||||
|
@ -361,7 +359,7 @@ public:
|
|||
static StringRef getOperationName() { return "return"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
ArrayRef<SSAValue *> results = {});
|
||||
ArrayRef<Value *> results = {});
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
|
@ -380,7 +378,7 @@ void printDimAndSymbolList(Operation::const_operand_iterator begin,
|
|||
|
||||
// Parses dimension and symbol list and returns true if parsing failed.
|
||||
bool parseDimAndSymbolList(OpAsmParser *parser,
|
||||
SmallVector<SSAValue *, 4> &operands,
|
||||
SmallVector<Value *, 4> &operands,
|
||||
unsigned &numDims);
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StmtBlock.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
|
|
@ -1,133 +0,0 @@
|
|||
//===- MLValue.h - MLValue base class and SSA type decls ------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines SSA manipulation implementations for ML functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_MLVALUE_H
|
||||
#define MLIR_IR_MLVALUE_H
|
||||
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
|
||||
namespace mlir {
|
||||
class ForStmt;
|
||||
class MLValue;
|
||||
using MLFunction = Function;
|
||||
class Statement;
|
||||
class StmtBlock;
|
||||
|
||||
/// This enum contains all of the SSA value kinds that are valid in an ML
|
||||
/// function. This should be kept as a proper subtype of SSAValueKind,
|
||||
/// including having all of the values of the enumerators align.
|
||||
enum class MLValueKind {
|
||||
BlockArgument = (int)SSAValueKind::BlockArgument,
|
||||
StmtResult = (int)SSAValueKind::StmtResult,
|
||||
ForStmt = (int)SSAValueKind::ForStmt,
|
||||
};
|
||||
|
||||
/// The operand of ML function statement contains an MLValue.
|
||||
using StmtOperand = IROperandImpl<MLValue, Statement>;
|
||||
|
||||
/// MLValue is the base class for SSA values in ML functions.
|
||||
class MLValue : public SSAValueImpl<StmtOperand, Statement, MLValueKind> {
|
||||
public:
|
||||
/// Returns true if the given MLValue can be used as a dimension id.
|
||||
bool isValidDim() const;
|
||||
|
||||
/// Returns true if the given MLValue can be used as a symbol.
|
||||
bool isValidSymbol() const;
|
||||
|
||||
static bool classof(const SSAValue *value) {
|
||||
switch (value->getKind()) {
|
||||
case SSAValueKind::BlockArgument:
|
||||
case SSAValueKind::StmtResult:
|
||||
case SSAValueKind::ForStmt:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
MLFunction *getFunction();
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
const MLFunction *getFunction() const {
|
||||
return const_cast<MLValue *>(this)->getFunction();
|
||||
}
|
||||
|
||||
protected:
|
||||
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
||||
/// Block arguments are ML Values.
|
||||
class BlockArgument : public MLValue {
|
||||
public:
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::BlockArgument;
|
||||
}
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
MLFunction *getFunction();
|
||||
const MLFunction *getFunction() const {
|
||||
return const_cast<BlockArgument *>(this)->getFunction();
|
||||
}
|
||||
|
||||
StmtBlock *getOwner() { return owner; }
|
||||
const StmtBlock *getOwner() const { return owner; }
|
||||
|
||||
private:
|
||||
friend class StmtBlock; // For access to private constructor.
|
||||
BlockArgument(Type type, StmtBlock *owner)
|
||||
: MLValue(MLValueKind::BlockArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
StmtBlock *const owner;
|
||||
};
|
||||
|
||||
/// This is a value defined by a result of an operation instruction.
|
||||
class StmtResult : public MLValue {
|
||||
public:
|
||||
StmtResult(Type type, OperationStmt *owner)
|
||||
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
|
||||
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::StmtResult;
|
||||
}
|
||||
|
||||
OperationStmt *getOwner() { return owner; }
|
||||
const OperationStmt *getOwner() const { return owner; }
|
||||
|
||||
/// Returns the number of this result.
|
||||
unsigned getResultNumber() const;
|
||||
|
||||
private:
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
OperationStmt *const owner;
|
||||
};
|
||||
|
||||
// TODO(clattner) clean all this up.
|
||||
using CFGValue = MLValue;
|
||||
using BBArgument = BlockArgument;
|
||||
using InstResult = StmtResult;
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
|
@ -27,8 +27,8 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
@ -107,7 +107,7 @@ template <typename OpClass> struct op_matcher {
|
|||
|
||||
/// Entry point for matching a pattern over an SSAValue.
|
||||
template <typename Pattern>
|
||||
inline bool matchPattern(SSAValue *value, const Pattern &pattern) {
|
||||
inline bool matchPattern(Value *value, const Pattern &pattern) {
|
||||
// TODO: handle other cases
|
||||
if (auto *op = value->getDefiningOperation())
|
||||
return const_cast<Pattern &>(pattern).match(op);
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#define MLIR_IR_OPDEFINITION_H
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
@ -78,11 +78,11 @@ public:
|
|||
}
|
||||
|
||||
/// If the OpType operation includes the OneResult trait, then OpPointer can
|
||||
/// be implicitly converted to an SSAValue*. This yields the value of the
|
||||
/// be implicitly converted to an Value*. This yields the value of the
|
||||
/// only result.
|
||||
template <typename SFINAE = OpType>
|
||||
operator typename std::enable_if<IsSingleResult<SFINAE>::value,
|
||||
SSAValue *>::type() {
|
||||
Value *>::type() {
|
||||
return value.getResult();
|
||||
}
|
||||
|
||||
|
@ -114,14 +114,14 @@ public:
|
|||
}
|
||||
|
||||
/// If the OpType operation includes the OneResult trait, then OpPointer can
|
||||
/// be implicitly converted to an const SSAValue*. This yields the value of
|
||||
/// be implicitly converted to an const Value*. This yields the value of
|
||||
/// the only result.
|
||||
template <typename SFINAE = OpType>
|
||||
operator typename std::enable_if<
|
||||
std::is_convertible<
|
||||
SFINAE *,
|
||||
OpTrait::OneResult<typename SFINAE::ConcreteOpType> *>::value,
|
||||
const SSAValue *>::type() const {
|
||||
const Value *>::type() const {
|
||||
return value.getResult();
|
||||
}
|
||||
|
||||
|
@ -346,15 +346,13 @@ private:
|
|||
template <typename ConcreteType>
|
||||
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
|
||||
public:
|
||||
const SSAValue *getOperand() const {
|
||||
const Value *getOperand() const {
|
||||
return this->getOperation()->getOperand(0);
|
||||
}
|
||||
|
||||
SSAValue *getOperand() { return this->getOperation()->getOperand(0); }
|
||||
Value *getOperand() { return this->getOperation()->getOperand(0); }
|
||||
|
||||
void setOperand(SSAValue *value) {
|
||||
this->getOperation()->setOperand(0, value);
|
||||
}
|
||||
void setOperand(Value *value) { this->getOperation()->setOperand(0, value); }
|
||||
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifyOneOperand(op);
|
||||
|
@ -371,15 +369,15 @@ public:
|
|||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
|
||||
public:
|
||||
const SSAValue *getOperand(unsigned i) const {
|
||||
const Value *getOperand(unsigned i) const {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
|
||||
SSAValue *getOperand(unsigned i) {
|
||||
Value *getOperand(unsigned i) {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
|
||||
void setOperand(unsigned i, SSAValue *value) {
|
||||
void setOperand(unsigned i, Value *value) {
|
||||
this->getOperation()->setOperand(i, value);
|
||||
}
|
||||
|
||||
|
@ -402,15 +400,15 @@ public:
|
|||
unsigned getNumOperands() const {
|
||||
return this->getOperation()->getNumOperands();
|
||||
}
|
||||
const SSAValue *getOperand(unsigned i) const {
|
||||
const Value *getOperand(unsigned i) const {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
|
||||
SSAValue *getOperand(unsigned i) {
|
||||
Value *getOperand(unsigned i) {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
|
||||
void setOperand(unsigned i, SSAValue *value) {
|
||||
void setOperand(unsigned i, Value *value) {
|
||||
this->getOperation()->setOperand(i, value);
|
||||
}
|
||||
|
||||
|
@ -453,15 +451,13 @@ public:
|
|||
return this->getOperation()->getNumOperands();
|
||||
}
|
||||
|
||||
const SSAValue *getOperand(unsigned i) const {
|
||||
const Value *getOperand(unsigned i) const {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
|
||||
SSAValue *getOperand(unsigned i) {
|
||||
return this->getOperation()->getOperand(i);
|
||||
}
|
||||
Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
|
||||
|
||||
void setOperand(unsigned i, SSAValue *value) {
|
||||
void setOperand(unsigned i, Value *value) {
|
||||
this->getOperation()->setOperand(i, value);
|
||||
}
|
||||
|
||||
|
@ -503,17 +499,15 @@ public:
|
|||
template <typename ConcreteType>
|
||||
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||
public:
|
||||
SSAValue *getResult() { return this->getOperation()->getResult(0); }
|
||||
const SSAValue *getResult() const {
|
||||
return this->getOperation()->getResult(0);
|
||||
}
|
||||
Value *getResult() { return this->getOperation()->getResult(0); }
|
||||
const Value *getResult() const { return this->getOperation()->getResult(0); }
|
||||
|
||||
Type getType() const { return getResult()->getType(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
void replaceAllUsesWith(SSAValue *newValue) {
|
||||
void replaceAllUsesWith(Value *newValue) {
|
||||
getResult()->replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
|
@ -548,13 +542,11 @@ public:
|
|||
public:
|
||||
static unsigned getNumResults() { return N; }
|
||||
|
||||
const SSAValue *getResult(unsigned i) const {
|
||||
const Value *getResult(unsigned i) const {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
SSAValue *getResult(unsigned i) {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
||||
|
||||
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
|
||||
|
@ -574,13 +566,11 @@ public:
|
|||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> {
|
||||
public:
|
||||
const SSAValue *getResult(unsigned i) const {
|
||||
const Value *getResult(unsigned i) const {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
SSAValue *getResult(unsigned i) {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
||||
|
||||
Type getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
|
||||
|
@ -599,13 +589,13 @@ public:
|
|||
return this->getOperation()->getNumResults();
|
||||
}
|
||||
|
||||
const SSAValue *getResult(unsigned i) const {
|
||||
const Value *getResult(unsigned i) const {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
||||
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
||||
|
||||
void setResult(unsigned i, SSAValue *value) {
|
||||
void setResult(unsigned i, Value *value) {
|
||||
this->getOperation()->setResult(i, value);
|
||||
}
|
||||
|
||||
|
@ -762,10 +752,10 @@ public:
|
|||
return this->getOperation()->setSuccessor(block, index);
|
||||
}
|
||||
|
||||
void addSuccessorOperand(unsigned index, SSAValue *value) {
|
||||
void addSuccessorOperand(unsigned index, Value *value) {
|
||||
return this->getOperation()->addSuccessorOperand(index, value);
|
||||
}
|
||||
void addSuccessorOperands(unsigned index, ArrayRef<SSAValue *> values) {
|
||||
void addSuccessorOperands(unsigned index, ArrayRef<Value *> values) {
|
||||
return this->getOperation()->addSuccessorOperand(index, values);
|
||||
}
|
||||
};
|
||||
|
@ -889,8 +879,8 @@ private:
|
|||
// These functions are out-of-line implementations of the methods in BinaryOp,
|
||||
// which avoids them being template instantiated/duplicated.
|
||||
namespace impl {
|
||||
void buildBinaryOp(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs);
|
||||
void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs);
|
||||
bool parseBinaryOp(OpAsmParser *parser, OperationState *result);
|
||||
void printBinaryOp(const Operation *op, OpAsmPrinter *p);
|
||||
} // namespace impl
|
||||
|
@ -906,8 +896,8 @@ class BinaryOp
|
|||
: public Op<ConcreteType, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
|
||||
OpTrait::SameOperandsAndResultType, Traits...> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs) {
|
||||
static void build(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs) {
|
||||
impl::buildBinaryOp(builder, result, lhs, rhs);
|
||||
}
|
||||
static bool parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
@ -926,7 +916,7 @@ protected:
|
|||
// These functions are out-of-line implementations of the methods in CastOp,
|
||||
// which avoids them being template instantiated/duplicated.
|
||||
namespace impl {
|
||||
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
|
||||
void buildCastOp(Builder *builder, OperationState *result, Value *source,
|
||||
Type destType);
|
||||
bool parseCastOp(OpAsmParser *parser, OperationState *result);
|
||||
void printCastOp(const Operation *op, OpAsmPrinter *p);
|
||||
|
@ -942,7 +932,7 @@ template <typename ConcreteType, template <typename T> class... Traits>
|
|||
class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
|
||||
OpTrait::HasNoSideEffect, Traits...> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *source,
|
||||
static void build(Builder *builder, OperationState *result, Value *source,
|
||||
Type destType) {
|
||||
impl::buildCastOp(builder, result, source, destType);
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ public:
|
|||
virtual raw_ostream &getStream() const = 0;
|
||||
|
||||
/// Print implementations for various things an operation contains.
|
||||
virtual void printOperand(const SSAValue *value) = 0;
|
||||
virtual void printOperand(const Value *value) = 0;
|
||||
|
||||
/// Print a comma separated list of operands.
|
||||
template <typename ContainerType>
|
||||
|
@ -95,7 +95,7 @@ private:
|
|||
};
|
||||
|
||||
// Make the implementations convenient to use.
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) {
|
||||
p.printOperand(&value);
|
||||
return p;
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) {
|
|||
// even if it isn't exactly one of them. For example, we want to print
|
||||
// FunctionType with the Type& version above, not have it match this.
|
||||
template <typename T, typename std::enable_if<
|
||||
!std::is_convertible<T &, SSAValue &>::value &&
|
||||
!std::is_convertible<T &, Value &>::value &&
|
||||
!std::is_convertible<T &, Type &>::value &&
|
||||
!std::is_convertible<T &, Attribute &>::value &&
|
||||
!std::is_convertible<T &, AffineMap &>::value,
|
||||
|
@ -264,9 +264,8 @@ public:
|
|||
virtual bool parseOperand(OperandType &result) = 0;
|
||||
|
||||
/// Parse a single operation successor and it's operand list.
|
||||
virtual bool
|
||||
parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<SSAValue *> &operands) = 0;
|
||||
virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<Value *> &operands) = 0;
|
||||
|
||||
/// These are the supported delimiters around operand lists, used by
|
||||
/// parseOperandList.
|
||||
|
@ -311,13 +310,13 @@ public:
|
|||
/// Resolve an operand to an SSA value, emitting an error and returning true
|
||||
/// on failure.
|
||||
virtual bool resolveOperand(const OperandType &operand, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) = 0;
|
||||
SmallVectorImpl<Value *> &result) = 0;
|
||||
|
||||
/// Resolve a list of operands to SSA values, emitting an error and returning
|
||||
/// true on failure, or appending the results to the list on success.
|
||||
/// This method should be used when all operands have the same type.
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) {
|
||||
SmallVectorImpl<Value *> &result) {
|
||||
for (auto elt : operands)
|
||||
if (resolveOperand(elt, type, result))
|
||||
return true;
|
||||
|
@ -329,7 +328,7 @@ public:
|
|||
/// to the list on success.
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands,
|
||||
ArrayRef<Type> types, llvm::SMLoc loc,
|
||||
SmallVectorImpl<SSAValue *> &result) {
|
||||
SmallVectorImpl<Value *> &result) {
|
||||
if (operands.size() != types.size())
|
||||
return emitError(loc, Twine(operands.size()) +
|
||||
" operands present, but expected " +
|
||||
|
|
|
@ -64,21 +64,20 @@ public:
|
|||
/// Return the number of operands this operation has.
|
||||
unsigned getNumOperands() const;
|
||||
|
||||
SSAValue *getOperand(unsigned idx);
|
||||
const SSAValue *getOperand(unsigned idx) const {
|
||||
Value *getOperand(unsigned idx);
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return const_cast<Operation *>(this)->getOperand(idx);
|
||||
}
|
||||
void setOperand(unsigned idx, SSAValue *value);
|
||||
void setOperand(unsigned idx, Value *value);
|
||||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = OperandIterator<Operation, SSAValue>;
|
||||
using operand_iterator = OperandIterator<Operation, Value>;
|
||||
operand_iterator operand_begin();
|
||||
operand_iterator operand_end();
|
||||
llvm::iterator_range<operand_iterator> getOperands();
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator =
|
||||
OperandIterator<const Operation, const SSAValue>;
|
||||
using const_operand_iterator = OperandIterator<const Operation, const Value>;
|
||||
const_operand_iterator operand_begin() const;
|
||||
const_operand_iterator operand_end() const;
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const;
|
||||
|
@ -87,26 +86,25 @@ public:
|
|||
unsigned getNumResults() const;
|
||||
|
||||
/// Return the indicated result.
|
||||
SSAValue *getResult(unsigned idx);
|
||||
const SSAValue *getResult(unsigned idx) const {
|
||||
Value *getResult(unsigned idx);
|
||||
const Value *getResult(unsigned idx) const {
|
||||
return const_cast<Operation *>(this)->getResult(idx);
|
||||
}
|
||||
|
||||
// Support non-const result iteration.
|
||||
using result_iterator = ResultIterator<Operation, SSAValue>;
|
||||
using result_iterator = ResultIterator<Operation, Value>;
|
||||
result_iterator result_begin();
|
||||
result_iterator result_end();
|
||||
llvm::iterator_range<result_iterator> getResults();
|
||||
|
||||
// Support const result iteration.
|
||||
using const_result_iterator = ResultIterator<const Operation, const SSAValue>;
|
||||
using const_result_iterator = ResultIterator<const Operation, const Value>;
|
||||
const_result_iterator result_begin() const;
|
||||
const_result_iterator result_end() const;
|
||||
llvm::iterator_range<const_result_iterator> getResults() const;
|
||||
|
||||
// Support for result type iteration.
|
||||
using result_type_iterator =
|
||||
ResultTypeIterator<const Operation, const SSAValue>;
|
||||
using result_type_iterator = ResultTypeIterator<const Operation, const Value>;
|
||||
result_type_iterator result_type_begin() const;
|
||||
result_type_iterator result_type_end() const;
|
||||
llvm::iterator_range<result_type_iterator> getResultTypes() const;
|
||||
|
|
|
@ -40,8 +40,8 @@ class OpAsmPrinter;
|
|||
class Pattern;
|
||||
class RewritePattern;
|
||||
class StmtBlock;
|
||||
class SSAValue;
|
||||
class Type;
|
||||
class Value;
|
||||
using BasicBlock = StmtBlock;
|
||||
|
||||
/// This is a vector that owns the patterns inside of it.
|
||||
|
@ -203,7 +203,7 @@ struct OperationState {
|
|||
MLIRContext *const context;
|
||||
Location location;
|
||||
OperationName name;
|
||||
SmallVector<SSAValue *, 4> operands;
|
||||
SmallVector<Value *, 4> operands;
|
||||
/// Types of the results of this operation.
|
||||
SmallVector<Type, 4> types;
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
|
@ -218,7 +218,7 @@ public:
|
|||
: context(context), location(location), name(name) {}
|
||||
|
||||
OperationState(MLIRContext *context, Location location, StringRef name,
|
||||
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
|
||||
ArrayRef<Value *> operands, ArrayRef<Type> types,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
ArrayRef<StmtBlock *> successors = {})
|
||||
: context(context), location(location), name(name, context),
|
||||
|
@ -227,7 +227,7 @@ public:
|
|||
attributes(attributes.begin(), attributes.end()),
|
||||
successors(successors.begin(), successors.end()) {}
|
||||
|
||||
void addOperands(ArrayRef<SSAValue *> newOperands) {
|
||||
void addOperands(ArrayRef<Value *> newOperands) {
|
||||
assert(successors.empty() &&
|
||||
"Non successor operands should be added first.");
|
||||
operands.append(newOperands.begin(), newOperands.end());
|
||||
|
@ -247,7 +247,7 @@ public:
|
|||
attributes.push_back({name, attr});
|
||||
}
|
||||
|
||||
void addSuccessor(StmtBlock *successor, ArrayRef<SSAValue *> succOperands) {
|
||||
void addSuccessor(StmtBlock *successor, ArrayRef<Value *> succOperands) {
|
||||
successors.push_back(successor);
|
||||
// Insert a sentinal operand to mark a barrier between successor operands.
|
||||
operands.push_back(nullptr);
|
||||
|
|
|
@ -222,8 +222,8 @@ public:
|
|||
/// clients can specify a list of other nodes that this replacement may make
|
||||
/// (perhaps transitively) dead. If any of those values are dead, this will
|
||||
/// remove them as well.
|
||||
void replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead = {});
|
||||
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead = {});
|
||||
|
||||
/// Replaces the result op with a new op that is created without verification.
|
||||
/// The result values of the two ops must be the same types.
|
||||
|
@ -237,8 +237,7 @@ public:
|
|||
/// The result values of the two ops must be the same types. This allows
|
||||
/// specifying a list of ops that may be removed if dead.
|
||||
template <typename OpTy, typename... Args>
|
||||
void replaceOpWithNewOp(Operation *op,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead,
|
||||
void replaceOpWithNewOp(Operation *op, ArrayRef<Value *> valuesToRemoveIfDead,
|
||||
Args... args) {
|
||||
auto newOp = create<OpTy>(op->getLoc(), args...);
|
||||
replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(),
|
||||
|
@ -254,7 +253,7 @@ public:
|
|||
/// rewriter should remove if they are dead at this point.
|
||||
///
|
||||
void updatedRootInPlace(Operation *op,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead = {});
|
||||
ArrayRef<Value *> valuesToRemoveIfDead = {});
|
||||
|
||||
protected:
|
||||
PatternRewriter(MLIRContext *context) : Builder(context) {}
|
||||
|
@ -284,9 +283,8 @@ protected:
|
|||
private:
|
||||
/// op and newOp are known to have the same number of results, replace the
|
||||
/// uses of op with uses of newOp
|
||||
void
|
||||
replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead);
|
||||
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,154 +0,0 @@
|
|||
//===- SSAValue.h - Base of the value hierarchy -----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines generic SSAValue type and manipulation utilities.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_SSAVALUE_H
|
||||
#define MLIR_IR_SSAVALUE_H
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/UseDefLists.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
class OperationStmt;
|
||||
class Operation;
|
||||
class Statement;
|
||||
using Instruction = Statement;
|
||||
using OperationInst = OperationStmt;
|
||||
|
||||
/// This enumerates all of the SSA value kinds in the MLIR system.
|
||||
enum class SSAValueKind {
|
||||
BlockArgument, // Block argument
|
||||
StmtResult, // statement result
|
||||
ForStmt, // for statement induction variable
|
||||
};
|
||||
|
||||
/// This is the common base class for all values in the MLIR system,
|
||||
/// representing a computable value that has a type and a set of users.
|
||||
///
|
||||
class SSAValue : public IRObjectWithUseList {
|
||||
public:
|
||||
~SSAValue() {}
|
||||
|
||||
SSAValueKind getKind() const { return typeAndKind.getInt(); }
|
||||
|
||||
Type getType() const { return typeAndKind.getPointer(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
void replaceAllUsesWith(SSAValue *newValue) {
|
||||
IRObjectWithUseList::replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
Function *getFunction();
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
const Function *getFunction() const {
|
||||
return const_cast<SSAValue *>(this)->getFunction();
|
||||
}
|
||||
|
||||
/// If this value is the result of an Instruction, return the instruction
|
||||
/// that defines it.
|
||||
OperationInst *getDefiningInst();
|
||||
const OperationInst *getDefiningInst() const {
|
||||
return const_cast<SSAValue *>(this)->getDefiningInst();
|
||||
}
|
||||
|
||||
/// If this value is the result of an OperationStmt, return the statement
|
||||
/// that defines it.
|
||||
OperationStmt *getDefiningStmt();
|
||||
const OperationStmt *getDefiningStmt() const {
|
||||
return const_cast<SSAValue *>(this)->getDefiningStmt();
|
||||
}
|
||||
|
||||
/// If this value is the result of an Operation, return the operation that
|
||||
/// defines it.
|
||||
Operation *getDefiningOperation();
|
||||
const Operation *getDefiningOperation() const {
|
||||
return const_cast<SSAValue *>(this)->getDefiningOperation();
|
||||
}
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
protected:
|
||||
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
|
||||
|
||||
private:
|
||||
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
|
||||
};
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
|
||||
value.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
/// This template unifies the implementation logic for CFGValue and MLValue
|
||||
/// while providing more type-specific APIs when walking use lists etc.
|
||||
///
|
||||
/// IROperandTy is the concrete instance of IROperand to use (including
|
||||
/// substituted template arguments).
|
||||
/// IROwnerTy is the type of the owner of an IROperandTy type.
|
||||
/// KindTy is the enum 'kind' discriminator that subclasses want to use.
|
||||
///
|
||||
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
|
||||
class SSAValueImpl : public SSAValue {
|
||||
public:
|
||||
// Provide more specific implementations of the base class functionality.
|
||||
KindTy getKind() const { return (KindTy)SSAValue::getKind(); }
|
||||
|
||||
using use_iterator = SSAValueUseIterator<IROperandTy, IROwnerTy>;
|
||||
using use_range = llvm::iterator_range<use_iterator>;
|
||||
|
||||
inline use_iterator use_begin() const;
|
||||
inline use_iterator use_end() const;
|
||||
|
||||
/// Returns a range of all uses, which is useful for iterating over all uses.
|
||||
inline use_range getUses() const;
|
||||
|
||||
protected:
|
||||
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
|
||||
};
|
||||
|
||||
// Utility functions for iterating through SSAValue uses.
|
||||
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
|
||||
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_begin() const
|
||||
-> use_iterator {
|
||||
return use_iterator((IROperandTy *)getFirstUse());
|
||||
}
|
||||
|
||||
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
|
||||
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_end() const
|
||||
-> use_iterator {
|
||||
return use_iterator(nullptr);
|
||||
}
|
||||
|
||||
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
|
||||
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::getUses() const
|
||||
-> llvm::iterator_range<use_iterator> {
|
||||
return {use_begin(), use_end()};
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
|
@ -22,8 +22,8 @@
|
|||
#ifndef MLIR_IR_STATEMENT_H
|
||||
#define MLIR_IR_STATEMENT_H
|
||||
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ilist.h"
|
||||
#include "llvm/ADT/ilist_node.h"
|
||||
|
@ -84,8 +84,8 @@ public:
|
|||
|
||||
// This is a verbose type used by the clone method below.
|
||||
using OperandMapTy =
|
||||
DenseMap<const MLValue *, MLValue *, llvm::DenseMapInfo<const MLValue *>,
|
||||
llvm::detail::DenseMapPair<const MLValue *, MLValue *>>;
|
||||
DenseMap<const Value *, Value *, llvm::DenseMapInfo<const Value *>,
|
||||
llvm::detail::DenseMapPair<const Value *, Value *>>;
|
||||
|
||||
/// Create a deep copy of this statement, remapping any operands that use
|
||||
/// values outside of the statement using the map that is provided (leaving
|
||||
|
@ -136,12 +136,12 @@ public:
|
|||
|
||||
unsigned getNumOperands() const;
|
||||
|
||||
MLValue *getOperand(unsigned idx);
|
||||
const MLValue *getOperand(unsigned idx) const;
|
||||
void setOperand(unsigned idx, MLValue *value);
|
||||
Value *getOperand(unsigned idx);
|
||||
const Value *getOperand(unsigned idx) const;
|
||||
void setOperand(unsigned idx, Value *value);
|
||||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = OperandIterator<Statement, MLValue>;
|
||||
using operand_iterator = OperandIterator<Statement, Value>;
|
||||
|
||||
operand_iterator operand_begin() { return operand_iterator(this, 0); }
|
||||
|
||||
|
@ -149,14 +149,13 @@ public:
|
|||
return operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
/// Returns an iterator on the underlying MLValue's (MLValue *).
|
||||
/// Returns an iterator on the underlying Value's (Value *).
|
||||
llvm::iterator_range<operand_iterator> getOperands() {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator =
|
||||
OperandIterator<const Statement, const MLValue>;
|
||||
using const_operand_iterator = OperandIterator<const Statement, const Value>;
|
||||
|
||||
const_operand_iterator operand_begin() const {
|
||||
return const_operand_iterator(this, 0);
|
||||
|
@ -166,7 +165,7 @@ public:
|
|||
return const_operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
/// Returns a const iterator on the underlying MLValue's (MLValue *).
|
||||
/// Returns a const iterator on the underlying Value's (Value *).
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ class OperationStmt final
|
|||
public:
|
||||
/// Create a new OperationStmt with the specific fields.
|
||||
static OperationStmt *
|
||||
create(Location location, OperationName name, ArrayRef<MLValue *> operands,
|
||||
create(Location location, OperationName name, ArrayRef<Value *> operands,
|
||||
ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
|
||||
ArrayRef<StmtBlock *> successors, MLIRContext *context);
|
||||
|
||||
|
@ -69,16 +69,16 @@ public:
|
|||
|
||||
unsigned getNumOperands() const { return numOperands; }
|
||||
|
||||
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const MLValue *getOperand(unsigned idx) const {
|
||||
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
void setOperand(unsigned idx, MLValue *value) {
|
||||
void setOperand(unsigned idx, Value *value) {
|
||||
return getStmtOperand(idx).set(value);
|
||||
}
|
||||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = OperandIterator<OperationStmt, MLValue>;
|
||||
using operand_iterator = OperandIterator<OperationStmt, Value>;
|
||||
|
||||
operand_iterator operand_begin() { return operand_iterator(this, 0); }
|
||||
|
||||
|
@ -86,14 +86,14 @@ public:
|
|||
return operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
/// Returns an iterator on the underlying MLValue's (MLValue *).
|
||||
/// Returns an iterator on the underlying Value's (Value *).
|
||||
llvm::iterator_range<operand_iterator> getOperands() {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator =
|
||||
OperandIterator<const OperationStmt, const MLValue>;
|
||||
OperandIterator<const OperationStmt, const Value>;
|
||||
|
||||
const_operand_iterator operand_begin() const {
|
||||
return const_operand_iterator(this, 0);
|
||||
|
@ -103,7 +103,7 @@ public:
|
|||
return const_operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
/// Returns a const iterator on the underlying MLValue's (MLValue *).
|
||||
/// Returns a const iterator on the underlying Value's (Value *).
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
@ -126,11 +126,11 @@ public:
|
|||
|
||||
unsigned getNumResults() const { return numResults; }
|
||||
|
||||
MLValue *getResult(unsigned idx) { return &getStmtResult(idx); }
|
||||
const MLValue *getResult(unsigned idx) const { return &getStmtResult(idx); }
|
||||
Value *getResult(unsigned idx) { return &getStmtResult(idx); }
|
||||
const Value *getResult(unsigned idx) const { return &getStmtResult(idx); }
|
||||
|
||||
// Support non-const result iteration.
|
||||
using result_iterator = ResultIterator<OperationStmt, MLValue>;
|
||||
using result_iterator = ResultIterator<OperationStmt, Value>;
|
||||
result_iterator result_begin() { return result_iterator(this, 0); }
|
||||
result_iterator result_end() {
|
||||
return result_iterator(this, getNumResults());
|
||||
|
@ -141,7 +141,7 @@ public:
|
|||
|
||||
// Support const result iteration.
|
||||
using const_result_iterator =
|
||||
ResultIterator<const OperationStmt, const MLValue>;
|
||||
ResultIterator<const OperationStmt, const Value>;
|
||||
const_result_iterator result_begin() const {
|
||||
return const_result_iterator(this, 0);
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ public:
|
|||
|
||||
// Support result type iteration.
|
||||
using result_type_iterator =
|
||||
ResultTypeIterator<const OperationStmt, const MLValue>;
|
||||
ResultTypeIterator<const OperationStmt, const Value>;
|
||||
result_type_iterator result_type_begin() const {
|
||||
return result_type_iterator(this, 0);
|
||||
}
|
||||
|
@ -290,15 +290,15 @@ private:
|
|||
};
|
||||
|
||||
/// For statement represents an affine loop nest.
|
||||
class ForStmt : public Statement, public MLValue {
|
||||
class ForStmt : public Statement, public Value {
|
||||
public:
|
||||
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
static ForStmt *create(Location location, ArrayRef<Value *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<Value *> ubOperands,
|
||||
AffineMap ubMap, int64_t step);
|
||||
|
||||
~ForStmt() {
|
||||
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
|
||||
// since child statements need to be destroyed before the MLValue that this
|
||||
// since child statements need to be destroyed before the Value that this
|
||||
// for stmt represents is destroyed. Affine maps are immortal objects and
|
||||
// don't need to be deleted.
|
||||
getBody()->clear();
|
||||
|
@ -308,8 +308,8 @@ public:
|
|||
using Statement::getFunction;
|
||||
|
||||
/// Operand iterators.
|
||||
using operand_iterator = OperandIterator<ForStmt, MLValue>;
|
||||
using const_operand_iterator = OperandIterator<const ForStmt, const MLValue>;
|
||||
using operand_iterator = OperandIterator<ForStmt, Value>;
|
||||
using const_operand_iterator = OperandIterator<const ForStmt, const Value>;
|
||||
|
||||
/// Operand iterator range.
|
||||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
|
@ -340,9 +340,9 @@ public:
|
|||
AffineMap getUpperBoundMap() const { return ubMap; }
|
||||
|
||||
/// Set lower bound.
|
||||
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap map);
|
||||
void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
|
||||
/// Set upper bound.
|
||||
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap map);
|
||||
void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
|
||||
|
||||
/// Set the lower bound map without changing operands.
|
||||
void setLowerBoundMap(AffineMap map);
|
||||
|
@ -385,11 +385,11 @@ public:
|
|||
|
||||
unsigned getNumOperands() const { return operands.size(); }
|
||||
|
||||
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const MLValue *getOperand(unsigned idx) const {
|
||||
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
void setOperand(unsigned idx, MLValue *value) {
|
||||
void setOperand(unsigned idx, Value *value) {
|
||||
getStmtOperand(idx).set(value);
|
||||
}
|
||||
|
||||
|
@ -439,10 +439,10 @@ public:
|
|||
}
|
||||
|
||||
// For statement represents implicitly represents induction variable by
|
||||
// inheriting from MLValue class. Whenever you need to refer to the loop
|
||||
// inheriting from Value class. Whenever you need to refer to the loop
|
||||
// induction variable, just use the for statement itself.
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::ForStmt;
|
||||
static bool classof(const Value *value) {
|
||||
return value->getKind() == Value::Kind::ForStmt;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -475,7 +475,7 @@ public:
|
|||
AffineMap getMap() const { return map; }
|
||||
|
||||
unsigned getNumOperands() const { return opEnd - opStart; }
|
||||
const MLValue *getOperand(unsigned idx) const {
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return stmt.getOperand(opStart + idx);
|
||||
}
|
||||
const StmtOperand &getStmtOperand(unsigned idx) const {
|
||||
|
@ -486,15 +486,15 @@ public:
|
|||
using operand_range = ForStmt::operand_range;
|
||||
|
||||
operand_iterator operand_begin() const {
|
||||
// These are iterators over MLValue *. Not casting away const'ness would
|
||||
// require the caller to use const MLValue *.
|
||||
// These are iterators over Value *. Not casting away const'ness would
|
||||
// require the caller to use const Value *.
|
||||
return operand_iterator(const_cast<ForStmt *>(&stmt), opStart);
|
||||
}
|
||||
operand_iterator operand_end() const {
|
||||
return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd);
|
||||
}
|
||||
|
||||
/// Returns an iterator on the underlying MLValue's (MLValue *).
|
||||
/// Returns an iterator on the underlying Value's (Value *).
|
||||
operand_range getOperands() const { return {operand_begin(), operand_end()}; }
|
||||
ArrayRef<StmtOperand> getStmtOperands() const {
|
||||
auto ops = stmt.getStmtOperands();
|
||||
|
@ -520,7 +520,7 @@ private:
|
|||
/// If statement restricts execution to a subset of the loop iteration space.
|
||||
class IfStmt : public Statement {
|
||||
public:
|
||||
static IfStmt *create(Location location, ArrayRef<MLValue *> operands,
|
||||
static IfStmt *create(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set);
|
||||
~IfStmt();
|
||||
|
||||
|
@ -556,8 +556,8 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Operand iterators.
|
||||
using operand_iterator = OperandIterator<IfStmt, MLValue>;
|
||||
using const_operand_iterator = OperandIterator<const IfStmt, const MLValue>;
|
||||
using operand_iterator = OperandIterator<IfStmt, Value>;
|
||||
using const_operand_iterator = OperandIterator<const IfStmt, const Value>;
|
||||
|
||||
/// Operand iterator range.
|
||||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
|
@ -565,11 +565,11 @@ public:
|
|||
|
||||
unsigned getNumOperands() const { return operands.size(); }
|
||||
|
||||
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const MLValue *getOperand(unsigned idx) const {
|
||||
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
const Value *getOperand(unsigned idx) const {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
void setOperand(unsigned idx, MLValue *value) {
|
||||
void setOperand(unsigned idx, Value *value) {
|
||||
getStmtOperand(idx).set(value);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
namespace mlir {
|
||||
class IfStmt;
|
||||
class MLValue;
|
||||
class StmtBlockList;
|
||||
using CFGFunction = Function;
|
||||
using MLFunction = Function;
|
||||
|
@ -412,7 +411,7 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
using BBUseIterator = SSAValueUseIterator<StmtBlockOperand, OperationStmt>;
|
||||
using BBUseIterator = ValueUseIterator<StmtBlockOperand, OperationStmt>;
|
||||
BBUseIterator bbUseIterator;
|
||||
};
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace mlir {
|
|||
|
||||
class IROperand;
|
||||
class IROperandOwner;
|
||||
template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
|
||||
template <typename OperandType, typename OwnerType> class ValueUseIterator;
|
||||
|
||||
class IRObjectWithUseList {
|
||||
public:
|
||||
|
@ -44,7 +44,7 @@ public:
|
|||
/// Returns true if this value has exactly one use.
|
||||
inline bool hasOneUse() const;
|
||||
|
||||
using use_iterator = SSAValueUseIterator<IROperand, IROperandOwner>;
|
||||
using use_iterator = ValueUseIterator<IROperand, IROperandOwner>;
|
||||
using use_range = llvm::iterator_range<use_iterator>;
|
||||
|
||||
inline use_iterator use_begin() const;
|
||||
|
@ -228,33 +228,33 @@ public:
|
|||
|
||||
/// An iterator over all uses of a ValueBase.
|
||||
template <typename OperandType, typename OwnerType>
|
||||
class SSAValueUseIterator
|
||||
class ValueUseIterator
|
||||
: public std::iterator<std::forward_iterator_tag, IROperand> {
|
||||
public:
|
||||
SSAValueUseIterator() = default;
|
||||
explicit SSAValueUseIterator(OperandType *current) : current(current) {}
|
||||
ValueUseIterator() = default;
|
||||
explicit ValueUseIterator(OperandType *current) : current(current) {}
|
||||
OperandType *operator->() const { return current; }
|
||||
OperandType &operator*() const { return *current; }
|
||||
|
||||
OwnerType *getUser() const { return current->getOwner(); }
|
||||
|
||||
SSAValueUseIterator &operator++() {
|
||||
ValueUseIterator &operator++() {
|
||||
assert(current && "incrementing past end()!");
|
||||
current = (OperandType *)current->getNextOperandUsingThisValue();
|
||||
return *this;
|
||||
}
|
||||
|
||||
SSAValueUseIterator operator++(int unused) {
|
||||
SSAValueUseIterator copy = *this;
|
||||
ValueUseIterator operator++(int unused) {
|
||||
ValueUseIterator copy = *this;
|
||||
++*this;
|
||||
return copy;
|
||||
}
|
||||
|
||||
friend bool operator==(SSAValueUseIterator lhs, SSAValueUseIterator rhs) {
|
||||
friend bool operator==(ValueUseIterator lhs, ValueUseIterator rhs) {
|
||||
return lhs.current == rhs.current;
|
||||
}
|
||||
|
||||
friend bool operator!=(SSAValueUseIterator lhs, SSAValueUseIterator rhs) {
|
||||
friend bool operator!=(ValueUseIterator lhs, ValueUseIterator rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines generic Value type and manipulation utilities.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_VALUE_H
|
||||
#define MLIR_IR_VALUE_H
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/UseDefLists.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
class Function;
|
||||
class OperationStmt;
|
||||
class Operation;
|
||||
class Statement;
|
||||
class StmtBlock;
|
||||
class Value;
|
||||
using Instruction = Statement;
|
||||
using OperationInst = OperationStmt;
|
||||
|
||||
/// The operand of ML function statement contains a Value.
|
||||
using StmtOperand = IROperandImpl<Value, Statement>;
|
||||
|
||||
/// This is the common base class for all values in the MLIR system,
|
||||
/// representing a computable value that has a type and a set of users.
|
||||
///
|
||||
class Value : public IRObjectWithUseList {
|
||||
public:
|
||||
/// This enumerates all of the SSA value kinds in the MLIR system.
|
||||
enum class Kind {
|
||||
BlockArgument, // block argument
|
||||
StmtResult, // statement result
|
||||
ForStmt, // for statement induction variable
|
||||
};
|
||||
|
||||
~Value() {}
|
||||
|
||||
Kind getKind() const { return typeAndKind.getInt(); }
|
||||
|
||||
Type getType() const { return typeAndKind.getPointer(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
void replaceAllUsesWith(Value *newValue) {
|
||||
IRObjectWithUseList::replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
/// TODO: move isValidDim/isValidSymbol to a utility library specific to the
|
||||
/// polyhedral operations.
|
||||
|
||||
/// Returns true if the given Value can be used as a dimension id.
|
||||
bool isValidDim() const;
|
||||
|
||||
/// Returns true if the given Value can be used as a symbol.
|
||||
bool isValidSymbol() const;
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
Function *getFunction();
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
const Function *getFunction() const {
|
||||
return const_cast<Value *>(this)->getFunction();
|
||||
}
|
||||
|
||||
/// If this value is the result of an Instruction, return the instruction
|
||||
/// that defines it.
|
||||
OperationInst *getDefiningInst();
|
||||
const OperationInst *getDefiningInst() const {
|
||||
return const_cast<Value *>(this)->getDefiningInst();
|
||||
}
|
||||
|
||||
/// If this value is the result of an OperationStmt, return the statement
|
||||
/// that defines it.
|
||||
OperationStmt *getDefiningStmt();
|
||||
const OperationStmt *getDefiningStmt() const {
|
||||
return const_cast<Value *>(this)->getDefiningStmt();
|
||||
}
|
||||
|
||||
/// If this value is the result of an Operation, return the operation that
|
||||
/// defines it.
|
||||
Operation *getDefiningOperation();
|
||||
const Operation *getDefiningOperation() const {
|
||||
return const_cast<Value *>(this)->getDefiningOperation();
|
||||
}
|
||||
|
||||
using use_iterator = ValueUseIterator<StmtOperand, Statement>;
|
||||
using use_range = llvm::iterator_range<use_iterator>;
|
||||
|
||||
inline use_iterator use_begin() const;
|
||||
inline use_iterator use_end() const;
|
||||
|
||||
/// Returns a range of all uses, which is useful for iterating over all uses.
|
||||
inline use_range getUses() const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
protected:
|
||||
Value(Kind kind, Type type) : typeAndKind(type, kind) {}
|
||||
|
||||
private:
|
||||
const llvm::PointerIntPair<Type, 3, Kind> typeAndKind;
|
||||
};
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, const Value &value) {
|
||||
value.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
// Utility functions for iterating through Value uses.
|
||||
inline auto Value::use_begin() const -> use_iterator {
|
||||
return use_iterator((StmtOperand *)getFirstUse());
|
||||
}
|
||||
|
||||
inline auto Value::use_end() const -> use_iterator {
|
||||
return use_iterator(nullptr);
|
||||
}
|
||||
|
||||
inline auto Value::getUses() const -> llvm::iterator_range<use_iterator> {
|
||||
return {use_begin(), use_end()};
|
||||
}
|
||||
|
||||
/// Block arguments are values.
|
||||
class BlockArgument : public Value {
|
||||
public:
|
||||
static bool classof(const Value *value) {
|
||||
return value->getKind() == Kind::BlockArgument;
|
||||
}
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
Function *getFunction();
|
||||
const Function *getFunction() const {
|
||||
return const_cast<BlockArgument *>(this)->getFunction();
|
||||
}
|
||||
|
||||
StmtBlock *getOwner() { return owner; }
|
||||
const StmtBlock *getOwner() const { return owner; }
|
||||
|
||||
private:
|
||||
friend class StmtBlock; // For access to private constructor.
|
||||
BlockArgument(Type type, StmtBlock *owner)
|
||||
: Value(Value::Kind::BlockArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
StmtBlock *const owner;
|
||||
};
|
||||
|
||||
/// This is a value defined by a result of an operation instruction.
|
||||
class StmtResult : public Value {
|
||||
public:
|
||||
StmtResult(Type type, OperationStmt *owner)
|
||||
: Value(Value::Kind::StmtResult, type), owner(owner) {}
|
||||
|
||||
static bool classof(const Value *value) {
|
||||
return value->getKind() == Kind::StmtResult;
|
||||
}
|
||||
|
||||
OperationStmt *getOwner() { return owner; }
|
||||
const OperationStmt *getOwner() const { return owner; }
|
||||
|
||||
/// Returns the number of this result.
|
||||
unsigned getResultNumber() const;
|
||||
|
||||
private:
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
OperationStmt *const owner;
|
||||
};
|
||||
|
||||
// TODO(clattner) clean all this up.
|
||||
using BBArgument = BlockArgument;
|
||||
using InstResult = StmtResult;
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
|
@ -157,14 +157,14 @@ class Op<string mnemonic, list<OpProperty> props = []> {
|
|||
//
|
||||
// static void build(Builder* builder, OperationState* result,
|
||||
// Type resultType0, Type resultType1, ...,
|
||||
// SSAValue* arg0, SSAValue* arg1, ...,
|
||||
// Value arg0, Value arg1, ...,
|
||||
// Attribute <attr0-name>, Attribute <attr1-name>, ...);
|
||||
//
|
||||
// * where the attributes follow the same declaration order as in the op.
|
||||
//
|
||||
// static void build(Builder* builder, OperationState* result,
|
||||
// ArrayRef<Type> resultTypes,
|
||||
// ArrayRef<SSAValue*> args,
|
||||
// ArrayRef<Value> args,
|
||||
// ArrayRef<NamedAttribute> attributes);
|
||||
code builder = ?;
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@
|
|||
namespace mlir {
|
||||
class AffineMap;
|
||||
class Builder;
|
||||
class MLValue;
|
||||
|
||||
class StandardOpsDialect : public Dialect {
|
||||
public:
|
||||
|
@ -48,8 +47,8 @@ class AddFOp
|
|||
: public BinaryOp<AddFOp, OpTrait::ResultsAreFloatLike,
|
||||
OpTrait::IsCommutative, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs);
|
||||
static void build(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs);
|
||||
|
||||
static StringRef getOperationName() { return "addf"; }
|
||||
|
||||
|
@ -116,7 +115,7 @@ public:
|
|||
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
|
||||
MemRefType memrefType, ArrayRef<Value *> operands = {});
|
||||
bool verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
@ -140,7 +139,7 @@ public:
|
|||
static StringRef getOperationName() { return "call"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
ArrayRef<Value *> operands);
|
||||
|
||||
Function *getCallee() const {
|
||||
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
|
@ -169,11 +168,11 @@ class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
|
|||
public:
|
||||
static StringRef getOperationName() { return "call_indirect"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
static void build(Builder *builder, OperationState *result, Value *callee,
|
||||
ArrayRef<Value *> operands);
|
||||
|
||||
const SSAValue *getCallee() const { return getOperand(0); }
|
||||
SSAValue *getCallee() { return getOperand(0); }
|
||||
const Value *getCallee() const { return getOperand(0); }
|
||||
Value *getCallee() { return getOperand(0); }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
|
@ -240,7 +239,7 @@ public:
|
|||
static CmpIPredicate getPredicateByName(StringRef name);
|
||||
|
||||
static void build(Builder *builder, OperationState *result, CmpIPredicate,
|
||||
SSAValue *lhs, SSAValue *rhs);
|
||||
Value *lhs, Value *rhs);
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
|
@ -263,14 +262,14 @@ private:
|
|||
class DeallocOp
|
||||
: public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
|
||||
public:
|
||||
SSAValue *getMemRef() { return getOperand(); }
|
||||
const SSAValue *getMemRef() const { return getOperand(); }
|
||||
void setMemRef(SSAValue *value) { setOperand(value); }
|
||||
Value *getMemRef() { return getOperand(); }
|
||||
const Value *getMemRef() const { return getOperand(); }
|
||||
void setMemRef(Value *value) { setOperand(value); }
|
||||
|
||||
static StringRef getOperationName() { return "dealloc"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *memref);
|
||||
static void build(Builder *builder, OperationState *result, Value *memref);
|
||||
bool verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
@ -292,7 +291,7 @@ class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
|
|||
OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *memrefOrTensor, unsigned index);
|
||||
Value *memrefOrTensor, unsigned index);
|
||||
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
@ -354,15 +353,15 @@ private:
|
|||
class DmaStartOp
|
||||
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
|
||||
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
|
||||
SSAValue *numElements, SSAValue *tagMemRef,
|
||||
ArrayRef<SSAValue *> tagIndices, SSAValue *stride = nullptr,
|
||||
SSAValue *elementsPerStride = nullptr);
|
||||
static void build(Builder *builder, OperationState *result, Value *srcMemRef,
|
||||
ArrayRef<Value *> srcIndices, Value *destMemRef,
|
||||
ArrayRef<Value *> destIndices, Value *numElements,
|
||||
Value *tagMemRef, ArrayRef<Value *> tagIndices,
|
||||
Value *stride = nullptr,
|
||||
Value *elementsPerStride = nullptr);
|
||||
|
||||
// Returns the source MemRefType for this DMA operation.
|
||||
const SSAValue *getSrcMemRef() const { return getOperand(0); }
|
||||
const Value *getSrcMemRef() const { return getOperand(0); }
|
||||
// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() const {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
|
||||
|
@ -375,7 +374,7 @@ public:
|
|||
}
|
||||
|
||||
// Returns the destination MemRefType for this DMA operations.
|
||||
const SSAValue *getDstMemRef() const {
|
||||
const Value *getDstMemRef() const {
|
||||
return getOperand(1 + getSrcMemRefRank());
|
||||
}
|
||||
// Returns the rank (number of indices) of the destination MemRefType.
|
||||
|
@ -398,12 +397,12 @@ public:
|
|||
}
|
||||
|
||||
// Returns the number of elements being transferred by this DMA operation.
|
||||
const SSAValue *getNumElements() const {
|
||||
const Value *getNumElements() const {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
|
||||
}
|
||||
|
||||
// Returns the Tag MemRef for this DMA operation.
|
||||
const SSAValue *getTagMemRef() const {
|
||||
const Value *getTagMemRef() const {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
|
||||
}
|
||||
// Returns the rank (number of indices) of the tag MemRefType.
|
||||
|
@ -453,21 +452,21 @@ public:
|
|||
1 + 1 + getTagMemRefRank();
|
||||
}
|
||||
|
||||
SSAValue *getStride() {
|
||||
Value *getStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1 - 1);
|
||||
}
|
||||
const SSAValue *getStride() const {
|
||||
const Value *getStride() const {
|
||||
return const_cast<DmaStartOp *>(this)->getStride();
|
||||
}
|
||||
|
||||
SSAValue *getNumElementsPerStride() {
|
||||
Value *getNumElementsPerStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1);
|
||||
}
|
||||
const SSAValue *getNumElementsPerStride() const {
|
||||
const Value *getNumElementsPerStride() const {
|
||||
return const_cast<DmaStartOp *>(this)->getNumElementsPerStride();
|
||||
}
|
||||
|
||||
|
@ -493,15 +492,14 @@ protected:
|
|||
class DmaWaitOp
|
||||
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices,
|
||||
SSAValue *numElements);
|
||||
static void build(Builder *builder, OperationState *result, Value *tagMemRef,
|
||||
ArrayRef<Value *> tagIndices, Value *numElements);
|
||||
|
||||
static StringRef getOperationName() { return "dma_wait"; }
|
||||
|
||||
// Returns the Tag MemRef associated with the DMA operation being waited on.
|
||||
const SSAValue *getTagMemRef() const { return getOperand(0); }
|
||||
SSAValue *getTagMemRef() { return getOperand(0); }
|
||||
const Value *getTagMemRef() const { return getOperand(0); }
|
||||
Value *getTagMemRef() { return getOperand(0); }
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
llvm::iterator_range<Operation::const_operand_iterator>
|
||||
|
@ -516,7 +514,7 @@ public:
|
|||
}
|
||||
|
||||
// Returns the number of elements transferred in the associated DMA operation.
|
||||
const SSAValue *getNumElements() const {
|
||||
const Value *getNumElements() const {
|
||||
return getOperand(1 + getTagMemRefRank());
|
||||
}
|
||||
|
||||
|
@ -545,11 +543,11 @@ class ExtractElementOp
|
|||
: public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult,
|
||||
OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *aggregate, ArrayRef<SSAValue *> indices = {});
|
||||
static void build(Builder *builder, OperationState *result, Value *aggregate,
|
||||
ArrayRef<Value *> indices = {});
|
||||
|
||||
SSAValue *getAggregate() { return getOperand(0); }
|
||||
const SSAValue *getAggregate() const { return getOperand(0); }
|
||||
Value *getAggregate() { return getOperand(0); }
|
||||
const Value *getAggregate() const { return getOperand(0); }
|
||||
|
||||
llvm::iterator_range<Operation::operand_iterator> getIndices() {
|
||||
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
|
||||
|
@ -583,12 +581,12 @@ class LoadOp
|
|||
: public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
||||
public:
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *memref,
|
||||
ArrayRef<SSAValue *> indices = {});
|
||||
static void build(Builder *builder, OperationState *result, Value *memref,
|
||||
ArrayRef<Value *> indices = {});
|
||||
|
||||
SSAValue *getMemRef() { return getOperand(0); }
|
||||
const SSAValue *getMemRef() const { return getOperand(0); }
|
||||
void setMemRef(SSAValue *value) { setOperand(0, value); }
|
||||
Value *getMemRef() { return getOperand(0); }
|
||||
const Value *getMemRef() const { return getOperand(0); }
|
||||
void setMemRef(Value *value) { setOperand(0, value); }
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
@ -705,19 +703,18 @@ class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
|
|||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "select"; }
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, SSAValue *trueValue,
|
||||
SSAValue *falseValue);
|
||||
static void build(Builder *builder, OperationState *result, Value *condition,
|
||||
Value *trueValue, Value *falseValue);
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
|
||||
SSAValue *getCondition() { return getOperand(0); }
|
||||
const SSAValue *getCondition() const { return getOperand(0); }
|
||||
SSAValue *getTrueValue() { return getOperand(1); }
|
||||
const SSAValue *getTrueValue() const { return getOperand(1); }
|
||||
SSAValue *getFalseValue() { return getOperand(2); }
|
||||
const SSAValue *getFalseValue() const { return getOperand(2); }
|
||||
Value *getCondition() { return getOperand(0); }
|
||||
const Value *getCondition() const { return getOperand(0); }
|
||||
Value *getTrueValue() { return getOperand(1); }
|
||||
const Value *getTrueValue() const { return getOperand(1); }
|
||||
Value *getFalseValue() { return getOperand(2); }
|
||||
const Value *getFalseValue() const { return getOperand(2); }
|
||||
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
@ -742,15 +739,15 @@ class StoreOp
|
|||
public:
|
||||
// Hooks to customize behavior of this op.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *valueToStore, SSAValue *memref,
|
||||
ArrayRef<SSAValue *> indices = {});
|
||||
Value *valueToStore, Value *memref,
|
||||
ArrayRef<Value *> indices = {});
|
||||
|
||||
SSAValue *getValueToStore() { return getOperand(0); }
|
||||
const SSAValue *getValueToStore() const { return getOperand(0); }
|
||||
Value *getValueToStore() { return getOperand(0); }
|
||||
const Value *getValueToStore() const { return getOperand(0); }
|
||||
|
||||
SSAValue *getMemRef() { return getOperand(1); }
|
||||
const SSAValue *getMemRef() const { return getOperand(1); }
|
||||
void setMemRef(SSAValue *value) { setOperand(1, value); }
|
||||
Value *getMemRef() { return getOperand(1); }
|
||||
const Value *getMemRef() const { return getOperand(1); }
|
||||
void setMemRef(Value *value) { setOperand(1, value); }
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
|
|
@ -98,26 +98,24 @@ public:
|
|||
static StringRef getOperationName() { return "vector_transfer_read"; }
|
||||
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
VectorType vectorType, SSAValue *srcMemRef,
|
||||
ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap,
|
||||
Optional<SSAValue *> paddingValue = None);
|
||||
VectorType vectorType, Value *srcMemRef,
|
||||
ArrayRef<Value *> srcIndices, AffineMap permutationMap,
|
||||
Optional<Value *> paddingValue = None);
|
||||
VectorType getResultType() const {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
}
|
||||
SSAValue *getVector() { return getResult(); }
|
||||
const SSAValue *getVector() const { return getResult(); }
|
||||
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
|
||||
const SSAValue *getMemRef() const {
|
||||
return getOperand(Offsets::MemRefOffset);
|
||||
}
|
||||
Value *getVector() { return getResult(); }
|
||||
const Value *getVector() const { return getResult(); }
|
||||
Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
|
||||
const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
|
||||
VectorType getVectorType() const { return getResultType(); }
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
llvm::iterator_range<Operation::operand_iterator> getIndices();
|
||||
llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
|
||||
Optional<SSAValue *> getPaddingValue();
|
||||
Optional<const SSAValue *> getPaddingValue() const;
|
||||
Optional<Value *> getPaddingValue();
|
||||
Optional<const Value *> getPaddingValue() const;
|
||||
AffineMap getPermutationMap() const;
|
||||
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
|
@ -169,20 +167,16 @@ class VectorTransferWriteOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "vector_transfer_write"; }
|
||||
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcVector, SSAValue *dstMemRef,
|
||||
ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap);
|
||||
SSAValue *getVector() { return getOperand(Offsets::VectorOffset); }
|
||||
const SSAValue *getVector() const {
|
||||
return getOperand(Offsets::VectorOffset);
|
||||
}
|
||||
static void build(Builder *builder, OperationState *result, Value *srcVector,
|
||||
Value *dstMemRef, ArrayRef<Value *> dstIndices,
|
||||
AffineMap permutationMap);
|
||||
Value *getVector() { return getOperand(Offsets::VectorOffset); }
|
||||
const Value *getVector() const { return getOperand(Offsets::VectorOffset); }
|
||||
VectorType getVectorType() const {
|
||||
return getVector()->getType().cast<VectorType>();
|
||||
}
|
||||
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
|
||||
const SSAValue *getMemRef() const {
|
||||
return getOperand(Offsets::MemRefOffset);
|
||||
}
|
||||
Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
|
||||
const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
|
||||
MemRefType getMemRefType() const {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
}
|
||||
|
@ -212,8 +206,8 @@ class VectorTypeCastOp
|
|||
: public Op<VectorTypeCastOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "vector_type_cast"; }
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcVector, Type dstType);
|
||||
static void build(Builder *builder, OperationState *result, Value *srcVector,
|
||||
Type dstType);
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
|
|
|
@ -35,10 +35,9 @@ namespace mlir {
|
|||
class ForStmt;
|
||||
class FuncBuilder;
|
||||
class Location;
|
||||
class MLValue;
|
||||
class Module;
|
||||
class OperationStmt;
|
||||
class SSAValue;
|
||||
|
||||
class Function;
|
||||
using CFGFunction = Function;
|
||||
|
||||
|
@ -52,12 +51,12 @@ using CFGFunction = Function;
|
|||
/// Returns true on success and false if the replacement is not possible
|
||||
/// (whenever a memref is used as an operand in a non-deferencing scenario). See
|
||||
/// comments at function definition for an example.
|
||||
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
|
||||
// TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily
|
||||
// extended to add additional indices at any position.
|
||||
bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
|
||||
ArrayRef<SSAValue *> extraIndices = {},
|
||||
bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
||||
ArrayRef<Value *> extraIndices = {},
|
||||
AffineMap indexRemap = AffineMap::Null(),
|
||||
ArrayRef<SSAValue *> extraOperands = {},
|
||||
ArrayRef<Value *> extraOperands = {},
|
||||
const Statement *domStmtFilter = nullptr);
|
||||
|
||||
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
|
||||
|
@ -69,9 +68,9 @@ bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
|
|||
/// parameter 'results'. Returns the affine apply op created.
|
||||
OperationStmt *
|
||||
createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<OperationStmt *> affineApplyOps,
|
||||
SmallVectorImpl<SSAValue *> *results);
|
||||
SmallVectorImpl<Value *> *results);
|
||||
|
||||
/// Given an operation statement, inserts a new single affine apply operation,
|
||||
/// that is exclusively used by this operation statement, and that provides all
|
||||
|
@ -104,7 +103,7 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt);
|
|||
/// Forward substitutes results from 'AffineApplyOp' into any users which
|
||||
/// are also AffineApplyOps.
|
||||
// NOTE: This method may modify users of results of this operation.
|
||||
// TODO(mlir-team): extend this for SSAValue / CFGFunctions.
|
||||
// TODO(mlir-team): extend this for Value/ CFGFunctions.
|
||||
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
|
||||
|
||||
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.
|
||||
|
|
|
@ -489,11 +489,11 @@ bool mlir::getFlattenedAffineExprs(
|
|||
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
|
||||
// the AffineApplyOp into any user AffineApplyOps.
|
||||
void mlir::getReachableAffineApplyOps(
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Value *> operands,
|
||||
SmallVectorImpl<OperationStmt *> &affineApplyOps) {
|
||||
struct State {
|
||||
// The ssa value for this node in the DFS traversal.
|
||||
MLValue *value;
|
||||
Value *value;
|
||||
// The operand index of 'value' to explore next during DFS traversal.
|
||||
unsigned operandIndex;
|
||||
};
|
||||
|
@ -557,8 +557,8 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
|
|||
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
|
||||
bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
|
||||
FlatAffineConstraints *domain) {
|
||||
SmallVector<MLValue *, 4> indices(forStmts.begin(), forStmts.end());
|
||||
// Reset while associated MLValues in 'indices' to the domain.
|
||||
SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
|
||||
// Reset while associated Values in 'indices' to the domain.
|
||||
domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
|
||||
for (auto *forStmt : forStmts) {
|
||||
// Add constraints from forStmt's bounds.
|
||||
|
@ -583,10 +583,10 @@ static bool getStmtIndexSet(const Statement *stmt,
|
|||
return getIndexSet(loops, indexSet);
|
||||
}
|
||||
|
||||
// ValuePositionMap manages the mapping from MLValues which represent dimension
|
||||
// ValuePositionMap manages the mapping from Values which represent dimension
|
||||
// and symbol identifiers from 'src' and 'dst' access functions to positions
|
||||
// in new space where some MLValues are kept separate (using addSrc/DstValue)
|
||||
// and some MLValues are merged (addSymbolValue).
|
||||
// in new space where some Values are kept separate (using addSrc/DstValue)
|
||||
// and some Values are merged (addSymbolValue).
|
||||
// Position lookups return the absolute position in the new space which
|
||||
// has the following format:
|
||||
//
|
||||
|
@ -595,7 +595,7 @@ static bool getStmtIndexSet(const Statement *stmt,
|
|||
// Note: access function non-IV dimension identifiers (that have 'dimension'
|
||||
// positions in the access function position space) are assigned as symbols
|
||||
// in the output position space. Convienience access functions which lookup
|
||||
// an MLValue in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
|
||||
// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
|
||||
// the common case of resolving positions for all access function operands.
|
||||
//
|
||||
// TODO(andydavis) Generalize this: could take a template parameter for
|
||||
|
@ -603,25 +603,25 @@ static bool getStmtIndexSet(const Statement *stmt,
|
|||
// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
|
||||
class ValuePositionMap {
|
||||
public:
|
||||
void addSrcValue(const MLValue *value) {
|
||||
void addSrcValue(const Value *value) {
|
||||
if (addValueAt(value, &srcDimPosMap, numSrcDims))
|
||||
++numSrcDims;
|
||||
}
|
||||
void addDstValue(const MLValue *value) {
|
||||
void addDstValue(const Value *value) {
|
||||
if (addValueAt(value, &dstDimPosMap, numDstDims))
|
||||
++numDstDims;
|
||||
}
|
||||
void addSymbolValue(const MLValue *value) {
|
||||
void addSymbolValue(const Value *value) {
|
||||
if (addValueAt(value, &symbolPosMap, numSymbols))
|
||||
++numSymbols;
|
||||
}
|
||||
unsigned getSrcDimOrSymPos(const MLValue *value) const {
|
||||
unsigned getSrcDimOrSymPos(const Value *value) const {
|
||||
return getDimOrSymPos(value, srcDimPosMap, 0);
|
||||
}
|
||||
unsigned getDstDimOrSymPos(const MLValue *value) const {
|
||||
unsigned getDstDimOrSymPos(const Value *value) const {
|
||||
return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
|
||||
}
|
||||
unsigned getSymPos(const MLValue *value) const {
|
||||
unsigned getSymPos(const Value *value) const {
|
||||
auto it = symbolPosMap.find(value);
|
||||
assert(it != symbolPosMap.end());
|
||||
return numSrcDims + numDstDims + it->second;
|
||||
|
@ -633,8 +633,7 @@ public:
|
|||
unsigned getNumSymbols() const { return numSymbols; }
|
||||
|
||||
private:
|
||||
bool addValueAt(const MLValue *value,
|
||||
DenseMap<const MLValue *, unsigned> *posMap,
|
||||
bool addValueAt(const Value *value, DenseMap<const Value *, unsigned> *posMap,
|
||||
unsigned position) {
|
||||
auto it = posMap->find(value);
|
||||
if (it == posMap->end()) {
|
||||
|
@ -643,8 +642,8 @@ private:
|
|||
}
|
||||
return false;
|
||||
}
|
||||
unsigned getDimOrSymPos(const MLValue *value,
|
||||
const DenseMap<const MLValue *, unsigned> &dimPosMap,
|
||||
unsigned getDimOrSymPos(const Value *value,
|
||||
const DenseMap<const Value *, unsigned> &dimPosMap,
|
||||
unsigned dimPosOffset) const {
|
||||
auto it = dimPosMap.find(value);
|
||||
if (it != dimPosMap.end()) {
|
||||
|
@ -658,25 +657,25 @@ private:
|
|||
unsigned numSrcDims = 0;
|
||||
unsigned numDstDims = 0;
|
||||
unsigned numSymbols = 0;
|
||||
DenseMap<const MLValue *, unsigned> srcDimPosMap;
|
||||
DenseMap<const MLValue *, unsigned> dstDimPosMap;
|
||||
DenseMap<const MLValue *, unsigned> symbolPosMap;
|
||||
DenseMap<const Value *, unsigned> srcDimPosMap;
|
||||
DenseMap<const Value *, unsigned> dstDimPosMap;
|
||||
DenseMap<const Value *, unsigned> symbolPosMap;
|
||||
};
|
||||
|
||||
// Builds a map from MLValue to identifier position in a new merged identifier
|
||||
// Builds a map from Value to identifier position in a new merged identifier
|
||||
// list, which is the result of merging dim/symbol lists from src/dst
|
||||
// iteration domains. The format of the new merged list is as follows:
|
||||
//
|
||||
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
|
||||
//
|
||||
// This method populates 'valuePosMap' with mappings from operand MLValues in
|
||||
// This method populates 'valuePosMap' with mappings from operand Values in
|
||||
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
|
||||
// to the position of these values in the merged list.
|
||||
static void buildDimAndSymbolPositionMaps(
|
||||
const FlatAffineConstraints &srcDomain,
|
||||
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
|
||||
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
|
||||
auto updateValuePosMap = [&](ArrayRef<MLValue *> values, bool isSrc) {
|
||||
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
|
||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||
auto *value = values[i];
|
||||
if (!isa<ForStmt>(values[i]))
|
||||
|
@ -688,7 +687,7 @@ static void buildDimAndSymbolPositionMaps(
|
|||
}
|
||||
};
|
||||
|
||||
SmallVector<MLValue *, 4> srcValues, destValues;
|
||||
SmallVector<Value *, 4> srcValues, destValues;
|
||||
srcDomain.getIdValues(&srcValues);
|
||||
dstDomain.getIdValues(&destValues);
|
||||
|
||||
|
@ -702,17 +701,10 @@ static void buildDimAndSymbolPositionMaps(
|
|||
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
|
||||
}
|
||||
|
||||
static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
|
||||
const MLValue *value) {
|
||||
auto it = posMap.find(value);
|
||||
assert(it != posMap.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
|
||||
// 'dependenceDomain'.
|
||||
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
|
||||
// srcDomain/dstDomain MLValue maps.
|
||||
// srcDomain/dstDomain Value maps.
|
||||
static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
|
||||
const FlatAffineConstraints &dstDomain,
|
||||
const ValuePositionMap &valuePosMap,
|
||||
|
@ -790,10 +782,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
|
|||
unsigned numResults = srcMap.getNumResults();
|
||||
|
||||
unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
|
||||
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
|
||||
ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
|
||||
|
||||
unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
|
||||
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
|
||||
ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
|
||||
|
||||
std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
|
||||
std::vector<SmallVector<int64_t, 8>> destFlatExprs;
|
||||
|
@ -848,7 +840,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
|
|||
}
|
||||
|
||||
// Add equality constraints for any operands that are defined by constant ops.
|
||||
auto addEqForConstOperands = [&](ArrayRef<const MLValue *> operands) {
|
||||
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
if (isa<ForStmt>(operands[i]))
|
||||
continue;
|
||||
|
@ -1095,7 +1087,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
|
|||
// upper/lower loop bounds for each ForStmt in the loop nest associated
|
||||
// with each access.
|
||||
// *) Build dimension and symbol position maps for each access, which map
|
||||
// MLValues from access functions and iteration domains to their position
|
||||
// Values from access functions and iteration domains to their position
|
||||
// in the merged constraint system built by this method.
|
||||
//
|
||||
// This method builds a constraint system with the following column format:
|
||||
|
@ -1202,7 +1194,7 @@ bool mlir::checkMemrefAccessDependence(
|
|||
return false;
|
||||
}
|
||||
// Build dim and symbol position maps for each access from access operand
|
||||
// MLValue to position in merged contstraint system.
|
||||
// Value to position in merged contstraint system.
|
||||
ValuePositionMap valuePosMap;
|
||||
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
|
||||
dstAccessMap, &valuePosMap);
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
@ -238,23 +237,23 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
|
|||
AffineValueMap::AffineValueMap(const AffineApplyOp &op)
|
||||
: map(op.getAffineMap()) {
|
||||
for (auto *operand : op.getOperands())
|
||||
operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand)));
|
||||
operands.push_back(const_cast<Value *>(operand));
|
||||
for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
|
||||
results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i))));
|
||||
results.push_back(const_cast<Value *>(op.getResult(i)));
|
||||
}
|
||||
|
||||
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
|
||||
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands)
|
||||
: map(map) {
|
||||
for (MLValue *operand : operands) {
|
||||
for (Value *operand : operands) {
|
||||
this->operands.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) {
|
||||
void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands) {
|
||||
this->operands.clear();
|
||||
this->results.clear();
|
||||
this->map.reset(map);
|
||||
for (MLValue *operand : operands) {
|
||||
for (Value *operand : operands) {
|
||||
this->operands.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
@ -275,7 +274,7 @@ void AffineValueMap::forwardSubstituteSingle(const AffineApplyOp &inputOp,
|
|||
|
||||
// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
|
||||
// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
|
||||
static bool findIndex(MLValue *valueToMatch, ArrayRef<MLValue *> valuesToSearch,
|
||||
static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
|
||||
unsigned indexStart, unsigned *indexOfMatch) {
|
||||
unsigned size = valuesToSearch.size();
|
||||
for (unsigned i = indexStart; i < size; ++i) {
|
||||
|
@ -324,8 +323,7 @@ void AffineValueMap::forwardSubstitute(
|
|||
for (unsigned j = 0; j < inputNumResults; ++j) {
|
||||
if (!inputResultsToSubstitute[j])
|
||||
continue;
|
||||
if (operands[i] ==
|
||||
cast<MLValue>(const_cast<SSAValue *>(inputOp.getResult(j)))) {
|
||||
if (operands[i] == const_cast<Value *>(inputOp.getResult(j))) {
|
||||
currOperandToInputResult[i] = j;
|
||||
inputResultsUsed.insert(j);
|
||||
}
|
||||
|
@ -365,7 +363,7 @@ void AffineValueMap::forwardSubstitute(
|
|||
}
|
||||
|
||||
// Build new output operands list and map update.
|
||||
SmallVector<MLValue *, 4> outputOperands;
|
||||
SmallVector<Value *, 4> outputOperands;
|
||||
unsigned outputOperandPosition = 0;
|
||||
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
|
||||
|
||||
|
@ -385,8 +383,7 @@ void AffineValueMap::forwardSubstitute(
|
|||
if (inputPositionsUsed.count(i) == 0)
|
||||
continue;
|
||||
// Check if input operand has a dup in current operand list.
|
||||
auto *inputOperand =
|
||||
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
|
||||
auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
|
||||
unsigned outputIndex;
|
||||
if (findIndex(inputOperand, outputOperands, /*indexStart=*/0,
|
||||
&outputIndex)) {
|
||||
|
@ -418,8 +415,7 @@ void AffineValueMap::forwardSubstitute(
|
|||
continue;
|
||||
unsigned inputSymbolPosition = i - inputNumDims;
|
||||
// Check if input operand has a dup in current operand list.
|
||||
auto *inputOperand =
|
||||
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
|
||||
auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
|
||||
// Find output operand index of 'inputOperand' dup.
|
||||
unsigned outputIndex;
|
||||
// Start at index 'outputNumDims' so that only symbol operands are searched.
|
||||
|
@ -451,7 +447,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
|
|||
|
||||
/// This method uses the invariant that operands are always positionally aligned
|
||||
/// with the AffineDimExpr in the underlying AffineMap.
|
||||
bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
|
||||
bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
|
||||
unsigned index;
|
||||
findIndex(value, operands, /*indexStart=*/0, &index);
|
||||
auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
|
||||
|
@ -460,12 +456,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
|
|||
return expr.isFunctionOfDim(index);
|
||||
}
|
||||
|
||||
SSAValue *AffineValueMap::getOperand(unsigned i) const {
|
||||
return static_cast<SSAValue *>(operands[i]);
|
||||
Value *AffineValueMap::getOperand(unsigned i) const {
|
||||
return static_cast<Value *>(operands[i]);
|
||||
}
|
||||
|
||||
ArrayRef<MLValue *> AffineValueMap::getOperands() const {
|
||||
return ArrayRef<MLValue *>(operands);
|
||||
ArrayRef<Value *> AffineValueMap::getOperands() const {
|
||||
return ArrayRef<Value *>(operands);
|
||||
}
|
||||
|
||||
AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
|
||||
|
@ -546,7 +542,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
|
|||
unsigned newNumReservedCols,
|
||||
unsigned newNumDims, unsigned newNumSymbols,
|
||||
unsigned newNumLocals,
|
||||
ArrayRef<MLValue *> idArgs) {
|
||||
ArrayRef<Value *> idArgs) {
|
||||
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
|
||||
"minimum 1 column");
|
||||
numReservedCols = newNumReservedCols;
|
||||
|
@ -570,7 +566,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
|
|||
|
||||
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
|
||||
unsigned newNumLocals,
|
||||
ArrayRef<MLValue *> idArgs) {
|
||||
ArrayRef<Value *> idArgs) {
|
||||
reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
|
||||
newNumSymbols, newNumLocals, idArgs);
|
||||
}
|
||||
|
@ -597,17 +593,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) {
|
|||
addId(IdKind::Local, pos);
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) {
|
||||
void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
|
||||
addId(IdKind::Dimension, pos, id);
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) {
|
||||
void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
|
||||
addId(IdKind::Symbol, pos, id);
|
||||
}
|
||||
|
||||
/// Adds a dimensional identifier. The added column is initialized to
|
||||
/// zero.
|
||||
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) {
|
||||
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
|
||||
if (kind == IdKind::Dimension) {
|
||||
assert(pos <= getNumDimIds());
|
||||
} else if (kind == IdKind::Symbol) {
|
||||
|
@ -755,7 +751,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
|
|||
// Dims and symbols.
|
||||
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
|
||||
unsigned loc;
|
||||
bool ret = findId(*cast<MLValue>(vMap->getOperand(i)), &loc);
|
||||
bool ret = findId(*vMap->getOperand(i), &loc);
|
||||
assert(ret && "value map's id can't be found");
|
||||
(void)ret;
|
||||
// We need to negate 'eq[r]' since the newly added dimension is going to
|
||||
|
@ -1231,7 +1227,7 @@ void FlatAffineConstraints::addUpperBound(ArrayRef<int64_t> expr,
|
|||
}
|
||||
}
|
||||
|
||||
bool FlatAffineConstraints::findId(const MLValue &id, unsigned *pos) const {
|
||||
bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
|
||||
unsigned i = 0;
|
||||
for (const auto &mayBeId : ids) {
|
||||
if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
|
||||
|
@ -1253,8 +1249,8 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
|
|||
bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
|
||||
unsigned pos;
|
||||
// Pre-condition for this method.
|
||||
if (!findId(*cast<MLValue>(&forStmt), &pos)) {
|
||||
assert(0 && "MLValue not found");
|
||||
if (!findId(forStmt, &pos)) {
|
||||
assert(0 && "Value not found");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1270,7 +1266,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
|
|||
unsigned loc;
|
||||
if (!findId(*operand, &loc)) {
|
||||
if (operand->isValidSymbol()) {
|
||||
addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand));
|
||||
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
|
||||
loc = getNumDimIds() + getNumSymbolIds() - 1;
|
||||
// Check if the symbol is a constant.
|
||||
if (auto *opStmt = operand->getDefiningStmt()) {
|
||||
|
@ -1279,7 +1275,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
addDimId(getNumDimIds(), const_cast<MLValue *>(operand));
|
||||
addDimId(getNumDimIds(), const_cast<Value *>(operand));
|
||||
loc = getNumDimIds() - 1;
|
||||
}
|
||||
}
|
||||
|
@ -1352,7 +1348,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
|
|||
|
||||
/// Sets the specified identifer to a constant value; asserts if the id is not
|
||||
/// found.
|
||||
void FlatAffineConstraints::setIdToConstant(const MLValue &id, int64_t val) {
|
||||
void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
|
||||
unsigned pos;
|
||||
if (!findId(id, &pos))
|
||||
// This is a pre-condition for this method.
|
||||
|
@ -1572,7 +1568,7 @@ void FlatAffineConstraints::print(raw_ostream &os) const {
|
|||
if (ids[i] == None)
|
||||
os << "None ";
|
||||
else
|
||||
os << "MLValue ";
|
||||
os << "Value ";
|
||||
}
|
||||
os << " const)\n";
|
||||
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
|
||||
|
@ -1779,7 +1775,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
|
|||
unsigned newNumDims = dimsSymbols.first;
|
||||
unsigned newNumSymbols = dimsSymbols.second;
|
||||
|
||||
SmallVector<Optional<MLValue *>, 8> newIds;
|
||||
SmallVector<Optional<Value *>, 8> newIds;
|
||||
newIds.reserve(numIds - 1);
|
||||
newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos);
|
||||
newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end());
|
||||
|
@ -1942,7 +1938,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
|
|||
normalizeConstraintsByGCD();
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::projectOut(MLValue *id) {
|
||||
void FlatAffineConstraints::projectOut(Value *id) {
|
||||
unsigned pos;
|
||||
bool ret = findId(*id, &pos);
|
||||
assert(ret);
|
||||
|
|
|
@ -70,7 +70,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
|
|||
}
|
||||
|
||||
/// Return true if value A properly dominates instruction B.
|
||||
bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) {
|
||||
bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
|
||||
if (auto *aInst = a->getDefiningInst())
|
||||
return properlyDominates(aInst, b);
|
||||
|
||||
|
|
|
@ -124,14 +124,14 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
|
|||
return tripCountExpr.getLargestKnownDivisor();
|
||||
}
|
||||
|
||||
bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
|
||||
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
|
||||
assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
|
||||
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
|
||||
SmallVector<OperationStmt *, 4> affineApplyOps;
|
||||
getReachableAffineApplyOps({const_cast<MLValue *>(&index)}, affineApplyOps);
|
||||
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
|
||||
|
||||
if (affineApplyOps.empty()) {
|
||||
// Pointer equality test because of MLValue pointer semantics.
|
||||
// Pointer equality test because of Value pointer semantics.
|
||||
return &index != &iv;
|
||||
}
|
||||
|
||||
|
@ -155,13 +155,13 @@ bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
|
|||
}
|
||||
assert(idx < std::numeric_limits<unsigned>::max());
|
||||
return !AffineValueMap(*composeOp)
|
||||
.isFunctionOf(idx, &const_cast<MLValue &>(iv));
|
||||
.isFunctionOf(idx, &const_cast<Value &>(iv));
|
||||
}
|
||||
|
||||
llvm::DenseSet<const MLValue *>
|
||||
mlir::getInvariantAccesses(const MLValue &iv,
|
||||
llvm::ArrayRef<const MLValue *> indices) {
|
||||
llvm::DenseSet<const MLValue *> res;
|
||||
llvm::DenseSet<const Value *>
|
||||
mlir::getInvariantAccesses(const Value &iv,
|
||||
llvm::ArrayRef<const Value *> indices) {
|
||||
llvm::DenseSet<const Value *> res;
|
||||
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
|
||||
auto *val = indices[idx];
|
||||
if (isAccessInvariant(iv, *val)) {
|
||||
|
@ -191,7 +191,7 @@ mlir::getInvariantAccesses(const MLValue &iv,
|
|||
///
|
||||
// TODO(ntv): check strides.
|
||||
template <typename LoadOrStoreOp>
|
||||
static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
|
||||
static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp,
|
||||
unsigned fastestVaryingDim) {
|
||||
static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value ||
|
||||
std::is_same<LoadOrStoreOp, StoreOp>::value,
|
||||
|
@ -220,7 +220,7 @@ static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
|
|||
if (fastestVaryingDim == (numIndices - 1) - d++) {
|
||||
continue;
|
||||
}
|
||||
if (!isAccessInvariant(iv, cast<MLValue>(*index))) {
|
||||
if (!isAccessInvariant(iv, *index)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -316,7 +316,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
|
|||
// outside).
|
||||
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
|
||||
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
|
||||
const MLValue *result = opStmt->getResult(i);
|
||||
const Value *result = opStmt->getResult(i);
|
||||
for (const StmtOperand &use : result->getUses()) {
|
||||
// If an ancestor statement doesn't lie in the block of forStmt, there
|
||||
// is no shift to check.
|
||||
|
|
|
@ -70,7 +70,7 @@ static void addMemRefAccessIndices(
|
|||
MemRefType memrefType, MemRefAccess *access) {
|
||||
access->indices.reserve(memrefType.getRank());
|
||||
for (auto *index : opIndices) {
|
||||
access->indices.push_back(cast<MLValue>(const_cast<SSAValue *>(index)));
|
||||
access->indices.push_back(const_cast<mlir::Value *>(index));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,13 +79,13 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt,
|
|||
MemRefAccess *access) {
|
||||
access->opStmt = loadOrStoreOpStmt;
|
||||
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
|
||||
access->memref = cast<MLValue>(loadOp->getMemRef());
|
||||
access->memref = loadOp->getMemRef();
|
||||
addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
|
||||
access);
|
||||
} else {
|
||||
assert(loadOrStoreOpStmt->isa<StoreOp>());
|
||||
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
|
||||
access->memref = cast<MLValue>(storeOp->getMemRef());
|
||||
access->memref = storeOp->getMemRef();
|
||||
addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
|
||||
access);
|
||||
}
|
||||
|
|
|
@ -150,21 +150,21 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
|
|||
OpPointer<LoadOp> loadOp;
|
||||
OpPointer<StoreOp> storeOp;
|
||||
unsigned rank;
|
||||
SmallVector<MLValue *, 4> indices;
|
||||
SmallVector<Value *, 4> indices;
|
||||
|
||||
if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
|
||||
rank = loadOp->getMemRefType().getRank();
|
||||
for (auto *index : loadOp->getIndices()) {
|
||||
indices.push_back(cast<MLValue>(index));
|
||||
indices.push_back(index);
|
||||
}
|
||||
region->memref = cast<MLValue>(loadOp->getMemRef());
|
||||
region->memref = loadOp->getMemRef();
|
||||
region->setWrite(false);
|
||||
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
|
||||
rank = storeOp->getMemRefType().getRank();
|
||||
for (auto *index : storeOp->getIndices()) {
|
||||
indices.push_back(cast<MLValue>(index));
|
||||
indices.push_back(index);
|
||||
}
|
||||
region->memref = cast<MLValue>(storeOp->getMemRef());
|
||||
region->memref = storeOp->getMemRef();
|
||||
region->setWrite(true);
|
||||
} else {
|
||||
return false;
|
||||
|
@ -201,7 +201,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
|
|||
return false;
|
||||
} else {
|
||||
// Has to be a valid symbol.
|
||||
auto *symbol = cast<MLValue>(accessValueMap.getOperand(i));
|
||||
auto *symbol = accessValueMap.getOperand(i);
|
||||
assert(symbol->isValidSymbol());
|
||||
// Check if the symbol is a constant.
|
||||
if (auto *opStmt = symbol->getDefiningStmt()) {
|
||||
|
@ -405,7 +405,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
|
|||
|
||||
// Solve for src IVs in terms of dst IVs, symbols and constants.
|
||||
SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
|
||||
std::vector<SmallVector<MLValue *, 2>> srcIvOperands(srcLoopNestSize);
|
||||
std::vector<SmallVector<Value *, 2>> srcIvOperands(srcLoopNestSize);
|
||||
for (unsigned i = 0; i < srcLoopNestSize; ++i) {
|
||||
// Skip IVs which are greater than requested loop depth.
|
||||
if (i >= srcLoopDepth) {
|
||||
|
@ -442,7 +442,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
|
|||
srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
|
||||
}
|
||||
// TODO(andydavis) Add symbols from the access function. Ideally, we
|
||||
// should be able to query the constaint system for the MLValue associated
|
||||
// should be able to query the constaint system for the Value associated
|
||||
// with a symbol identifiers in 'nonZeroSymbolIds'.
|
||||
}
|
||||
|
||||
|
@ -454,7 +454,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
|
|||
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
|
||||
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
|
||||
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
|
||||
|
||||
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
|
||||
|
|
|
@ -108,7 +108,7 @@ static AffineMap makePermutationMap(
|
|||
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
|
||||
using functional::makePtrDynCaster;
|
||||
using functional::map;
|
||||
auto unwrappedIndices = map(makePtrDynCaster<SSAValue, MLValue>(), indices);
|
||||
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
|
||||
SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
|
||||
getAffineConstantExpr(0, context));
|
||||
for (auto kvp : enclosingLoopToVectorDim) {
|
||||
|
|
|
@ -277,7 +277,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
|
|||
/// Walk all of the code in this MLFunc and verify that the operands of any
|
||||
/// operations are properly dominated by their definitions.
|
||||
bool MLFuncVerifier::verifyDominance() {
|
||||
using HashTable = llvm::ScopedHashTable<const SSAValue *, bool>;
|
||||
using HashTable = llvm::ScopedHashTable<const Value *, bool>;
|
||||
HashTable liveValues;
|
||||
HashTable::ScopeTy topScope(liveValues);
|
||||
|
||||
|
|
|
@ -38,7 +38,6 @@
|
|||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void Identifier::print(raw_ostream &os) const { os << str(); }
|
||||
|
@ -967,7 +966,7 @@ public:
|
|||
void printFunctionAttributes(const Function *func) {
|
||||
return ModulePrinter::printFunctionAttributes(func);
|
||||
}
|
||||
void printOperand(const SSAValue *value) { printValueID(value); }
|
||||
void printOperand(const Value *value) { printValueID(value); }
|
||||
|
||||
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<const char *> elidedAttrs = {}) {
|
||||
|
@ -977,7 +976,7 @@ public:
|
|||
enum { nameSentinel = ~0U };
|
||||
|
||||
protected:
|
||||
void numberValueID(const SSAValue *value) {
|
||||
void numberValueID(const Value *value) {
|
||||
assert(!valueIDs.count(value) && "Value numbered multiple times");
|
||||
|
||||
SmallString<32> specialNameBuffer;
|
||||
|
@ -1004,7 +1003,7 @@ protected:
|
|||
|
||||
if (specialNameBuffer.empty()) {
|
||||
switch (value->getKind()) {
|
||||
case SSAValueKind::BlockArgument:
|
||||
case Value::Kind::BlockArgument:
|
||||
// If this is an argument to the function, give it an 'arg' name.
|
||||
if (auto *block = cast<BlockArgument>(value)->getOwner())
|
||||
if (auto *fn = block->getFunction())
|
||||
|
@ -1015,12 +1014,12 @@ protected:
|
|||
// Otherwise number it normally.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case SSAValueKind::StmtResult:
|
||||
case Value::Kind::StmtResult:
|
||||
// This is an uninteresting result, give it a boring number and be
|
||||
// done with it.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case SSAValueKind::ForStmt:
|
||||
case Value::Kind::ForStmt:
|
||||
specialName << 'i' << nextLoopID++;
|
||||
break;
|
||||
}
|
||||
|
@ -1052,7 +1051,7 @@ protected:
|
|||
}
|
||||
}
|
||||
|
||||
void printValueID(const SSAValue *value, bool printResultNo = true) const {
|
||||
void printValueID(const Value *value, bool printResultNo = true) const {
|
||||
int resultNo = -1;
|
||||
auto lookupValue = value;
|
||||
|
||||
|
@ -1093,8 +1092,8 @@ protected:
|
|||
private:
|
||||
/// This is the value ID for each SSA value in the current function. If this
|
||||
/// returns ~0, then the valueID has an entry in valueNames.
|
||||
DenseMap<const SSAValue *, unsigned> valueIDs;
|
||||
DenseMap<const SSAValue *, StringRef> valueNames;
|
||||
DenseMap<const Value *, unsigned> valueIDs;
|
||||
DenseMap<const Value *, StringRef> valueNames;
|
||||
|
||||
/// This keeps track of all of the non-numeric names that are in flight,
|
||||
/// allowing us to check for duplicates.
|
||||
|
@ -1135,7 +1134,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
|
|||
os << "\"(";
|
||||
|
||||
interleaveComma(op->getOperands(),
|
||||
[&](const SSAValue *value) { printValueID(value); });
|
||||
[&](const Value *value) { printValueID(value); });
|
||||
|
||||
os << ')';
|
||||
auto attrs = op->getAttrs();
|
||||
|
@ -1144,16 +1143,15 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
|
|||
// Print the type signature of the operation.
|
||||
os << " : (";
|
||||
interleaveComma(op->getOperands(),
|
||||
[&](const SSAValue *value) { printType(value->getType()); });
|
||||
[&](const Value *value) { printType(value->getType()); });
|
||||
os << ") -> ";
|
||||
|
||||
if (op->getNumResults() == 1) {
|
||||
printType(op->getResult(0)->getType());
|
||||
} else {
|
||||
os << '(';
|
||||
interleaveComma(op->getResults(), [&](const SSAValue *result) {
|
||||
printType(result->getType());
|
||||
});
|
||||
interleaveComma(op->getResults(),
|
||||
[&](const Value *result) { printType(result->getType()); });
|
||||
os << ')';
|
||||
}
|
||||
}
|
||||
|
@ -1297,11 +1295,10 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) {
|
|||
|
||||
os << '(';
|
||||
interleaveComma(range,
|
||||
[this](const SSAValue *operand) { printValueID(operand); });
|
||||
[this](const Value *operand) { printValueID(operand); });
|
||||
os << " : ";
|
||||
interleaveComma(range, [this](const SSAValue *operand) {
|
||||
printType(operand->getType());
|
||||
});
|
||||
interleaveComma(
|
||||
range, [this](const Value *operand) { printType(operand->getType()); });
|
||||
os << ')';
|
||||
}
|
||||
|
||||
|
@ -1576,20 +1573,20 @@ void IntegerSet::print(raw_ostream &os) const {
|
|||
ModulePrinter(os, state).printIntegerSet(*this);
|
||||
}
|
||||
|
||||
void SSAValue::print(raw_ostream &os) const {
|
||||
void Value::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case SSAValueKind::BlockArgument:
|
||||
case Value::Kind::BlockArgument:
|
||||
// TODO: Improve this.
|
||||
os << "<block argument>\n";
|
||||
return;
|
||||
case SSAValueKind::StmtResult:
|
||||
case Value::Kind::StmtResult:
|
||||
return getDefiningStmt()->print(os);
|
||||
case SSAValueKind::ForStmt:
|
||||
case Value::Kind::ForStmt:
|
||||
return cast<ForStmt>(this)->print(os);
|
||||
}
|
||||
}
|
||||
|
||||
void SSAValue::dump() const { print(llvm::errs()); }
|
||||
void Value::dump() const { print(llvm::errs()); }
|
||||
|
||||
void Instruction::print(raw_ostream &os) const {
|
||||
auto *function = getFunction();
|
||||
|
|
|
@ -281,7 +281,7 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
|
|||
// If we are supposed to insert before a specific block, do so, otherwise add
|
||||
// the block to the end of the function.
|
||||
if (insertBefore)
|
||||
function->getBlocks().insert(CFGFunction::iterator(insertBefore), b);
|
||||
function->getBlocks().insert(Function::iterator(insertBefore), b);
|
||||
else
|
||||
function->push_back(b);
|
||||
|
||||
|
@ -291,16 +291,9 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
|
|||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
|
||||
SmallVector<CFGValue *, 8> operands;
|
||||
operands.reserve(state.operands.size());
|
||||
// Allow null operands as they act as sentinal barriers between successor
|
||||
// operand lists.
|
||||
for (auto elt : state.operands)
|
||||
operands.push_back(cast_or_null<CFGValue>(elt));
|
||||
|
||||
auto *op =
|
||||
OperationInst::create(state.location, state.name, operands, state.types,
|
||||
state.attributes, state.successors, context);
|
||||
auto *op = OperationInst::create(state.location, state.name, state.operands,
|
||||
state.types, state.attributes,
|
||||
state.successors, context);
|
||||
block->getStatements().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
@ -311,23 +304,17 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
|
|||
|
||||
/// Create an operation given the fields represented as an OperationState.
|
||||
OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
|
||||
SmallVector<MLValue *, 8> operands;
|
||||
operands.reserve(state.operands.size());
|
||||
for (auto elt : state.operands)
|
||||
operands.push_back(cast<MLValue>(elt));
|
||||
|
||||
auto *op =
|
||||
OperationStmt::create(state.location, state.name, operands, state.types,
|
||||
state.attributes, state.successors, context);
|
||||
auto *op = OperationStmt::create(state.location, state.name, state.operands,
|
||||
state.types, state.attributes,
|
||||
state.successors, context);
|
||||
block->getStatements().insert(insertPoint, op);
|
||||
return op;
|
||||
}
|
||||
|
||||
ForStmt *MLFuncBuilder::createFor(Location location,
|
||||
ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap lbMap,
|
||||
ArrayRef<MLValue *> ubOperands,
|
||||
AffineMap ubMap, int64_t step) {
|
||||
ArrayRef<Value *> lbOperands, AffineMap lbMap,
|
||||
ArrayRef<Value *> ubOperands, AffineMap ubMap,
|
||||
int64_t step) {
|
||||
auto *stmt =
|
||||
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
|
@ -341,7 +328,7 @@ ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
|
|||
return createFor(location, {}, lbMap, {}, ubMap, step);
|
||||
}
|
||||
|
||||
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<MLValue *> operands,
|
||||
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
auto *stmt = IfStmt::create(location, operands, set);
|
||||
block->getStatements().insert(insertPoint, stmt);
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -54,7 +54,7 @@ void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
|
|||
// dimension operands parsed.
|
||||
// Returns 'false' on success and 'true' on error.
|
||||
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
|
||||
SmallVector<SSAValue *, 4> &operands,
|
||||
SmallVector<Value *, 4> &operands,
|
||||
unsigned &numDims) {
|
||||
SmallVector<OpAsmParser::OperandType, 8> opInfos;
|
||||
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
|
||||
|
@ -76,7 +76,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AffineApplyOp::build(Builder *builder, OperationState *result,
|
||||
AffineMap map, ArrayRef<SSAValue *> operands) {
|
||||
AffineMap map, ArrayRef<Value *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->types.append(map.getNumResults(), builder->getIndexType());
|
||||
result->addAttribute("map", builder->getAffineMapAttr(map));
|
||||
|
@ -133,24 +133,22 @@ bool AffineApplyOp::verify() const {
|
|||
}
|
||||
|
||||
// The result of the affine apply operation can be used as a dimension id if it
|
||||
// is a CFG value or if it is an MLValue, and all the operands are valid
|
||||
// is a CFG value or if it is an Value, and all the operands are valid
|
||||
// dimension ids.
|
||||
bool AffineApplyOp::isValidDim() const {
|
||||
for (auto *op : getOperands()) {
|
||||
if (auto *v = dyn_cast<MLValue>(op))
|
||||
if (!v->isValidDim())
|
||||
return false;
|
||||
if (!op->isValidDim())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// The result of the affine apply operation can be used as a symbol if it is
|
||||
// a CFG value or if it is an MLValue, and all the operands are symbols.
|
||||
// a CFG value or if it is an Value, and all the operands are symbols.
|
||||
bool AffineApplyOp::isValidSymbol() const {
|
||||
for (auto *op : getOperands()) {
|
||||
if (auto *v = dyn_cast<MLValue>(op))
|
||||
if (!v->isValidSymbol())
|
||||
return false;
|
||||
if (!op->isValidSymbol())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -170,13 +168,13 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
ArrayRef<Value *> operands) {
|
||||
result->addSuccessor(dest, operands);
|
||||
}
|
||||
|
||||
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
BasicBlock *dest;
|
||||
SmallVector<SSAValue *, 4> destOperands;
|
||||
SmallVector<Value *, 4> destOperands;
|
||||
if (parser->parseSuccessorAndUseList(dest, destOperands))
|
||||
return true;
|
||||
result->addSuccessor(dest, destOperands);
|
||||
|
@ -212,17 +210,16 @@ void BranchOp::eraseOperand(unsigned index) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CondBranchOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, BasicBlock *trueDest,
|
||||
ArrayRef<SSAValue *> trueOperands,
|
||||
BasicBlock *falseDest,
|
||||
ArrayRef<SSAValue *> falseOperands) {
|
||||
Value *condition, BasicBlock *trueDest,
|
||||
ArrayRef<Value *> trueOperands, BasicBlock *falseDest,
|
||||
ArrayRef<Value *> falseOperands) {
|
||||
result->addOperands(condition);
|
||||
result->addSuccessor(trueDest, trueOperands);
|
||||
result->addSuccessor(falseDest, falseOperands);
|
||||
}
|
||||
|
||||
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<SSAValue *, 4> destOperands;
|
||||
SmallVector<Value *, 4> destOperands;
|
||||
BasicBlock *dest;
|
||||
OpAsmParser::OperandType condInfo;
|
||||
|
||||
|
@ -446,7 +443,7 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ReturnOp::build(Builder *builder, OperationState *result,
|
||||
ArrayRef<SSAValue *> results) {
|
||||
ArrayRef<Value *> results) {
|
||||
result->addOperands(results);
|
||||
}
|
||||
|
||||
|
@ -465,9 +462,10 @@ void ReturnOp::print(OpAsmPrinter *p) const {
|
|||
*p << ' ';
|
||||
p->printOperands(operand_begin(), operand_end());
|
||||
*p << " : ";
|
||||
interleave(operand_begin(), operand_end(),
|
||||
[&](const SSAValue *e) { p->printType(e->getType()); },
|
||||
[&]() { *p << ", "; });
|
||||
interleave(
|
||||
operand_begin(), operand_end(),
|
||||
[&](const Value *e) { p->printType(e->getType()); },
|
||||
[&]() { *p << ", "; });
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Form the OperationName for an op with the specified string. This either is
|
||||
|
@ -96,13 +95,13 @@ unsigned Operation::getNumOperands() const {
|
|||
return llvm::cast<OperationStmt>(this)->getNumOperands();
|
||||
}
|
||||
|
||||
SSAValue *Operation::getOperand(unsigned idx) {
|
||||
Value *Operation::getOperand(unsigned idx) {
|
||||
return llvm::cast<OperationStmt>(this)->getOperand(idx);
|
||||
}
|
||||
|
||||
void Operation::setOperand(unsigned idx, SSAValue *value) {
|
||||
void Operation::setOperand(unsigned idx, Value *value) {
|
||||
auto *stmt = llvm::cast<OperationStmt>(this);
|
||||
stmt->setOperand(idx, llvm::cast<MLValue>(value));
|
||||
stmt->setOperand(idx, value);
|
||||
}
|
||||
|
||||
/// Return the number of results this operation has.
|
||||
|
@ -111,7 +110,7 @@ unsigned Operation::getNumResults() const {
|
|||
}
|
||||
|
||||
/// Return the indicated result.
|
||||
SSAValue *Operation::getResult(unsigned idx) {
|
||||
Value *Operation::getResult(unsigned idx) {
|
||||
return llvm::cast<OperationStmt>(this)->getResult(idx);
|
||||
}
|
||||
|
||||
|
@ -585,8 +584,8 @@ bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
|
|||
// These functions are out-of-line implementations of the methods in BinaryOp,
|
||||
// which avoids them being template instantiated/duplicated.
|
||||
|
||||
void impl::buildBinaryOp(Builder *builder, OperationState *result,
|
||||
SSAValue *lhs, SSAValue *rhs) {
|
||||
void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs) {
|
||||
assert(lhs->getType() == rhs->getType());
|
||||
result->addOperands({lhs, rhs});
|
||||
result->types.push_back(lhs->getType());
|
||||
|
@ -613,8 +612,8 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
|||
// CastOp implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void impl::buildCastOp(Builder *builder, OperationState *result,
|
||||
SSAValue *source, Type destType) {
|
||||
void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
|
||||
Type destType) {
|
||||
result->addOperands(source);
|
||||
result->addTypes(destType);
|
||||
}
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
using namespace mlir;
|
||||
|
||||
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
|
||||
|
@ -77,8 +77,8 @@ PatternRewriter::~PatternRewriter() {
|
|||
/// clients can specify a list of other nodes that this replacement may make
|
||||
/// (perhaps transitively) dead. If any of those ops are dead, this will
|
||||
/// remove them as well.
|
||||
void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead) {
|
||||
void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead) {
|
||||
// Notify the rewriter subclass that we're about to replace this root.
|
||||
notifyRootReplaced(op);
|
||||
|
||||
|
@ -97,15 +97,14 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
|
|||
/// op and newOp are known to have the same number of results, replace the
|
||||
/// uses of op with uses of newOp
|
||||
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
|
||||
Operation *op, Operation *newOp,
|
||||
ArrayRef<SSAValue *> valuesToRemoveIfDead) {
|
||||
Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
|
||||
assert(op->getNumResults() == newOp->getNumResults() &&
|
||||
"replacement op doesn't match results of original op");
|
||||
if (op->getNumResults() == 1)
|
||||
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
|
||||
|
||||
SmallVector<SSAValue *, 8> newResults(newOp->getResults().begin(),
|
||||
newOp->getResults().end());
|
||||
SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
|
||||
newOp->getResults().end());
|
||||
return replaceOp(op, newResults, valuesToRemoveIfDead);
|
||||
}
|
||||
|
||||
|
@ -118,7 +117,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(
|
|||
/// should remove if they are dead at this point.
|
||||
///
|
||||
void PatternRewriter::updatedRootInPlace(
|
||||
Operation *op, ArrayRef<SSAValue *> valuesToRemoveIfDead) {
|
||||
Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
|
||||
// Notify the rewriter subclass that we're about to replace this root.
|
||||
notifyRootUpdated(op);
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- SSAValue.cpp - MLIR SSAValue Classes ------------===//
|
||||
//===- SSAValue.cpp - MLIR ValueClasses ------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -15,15 +15,15 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Statements.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// If this value is the result of an Instruction, return the instruction
|
||||
/// that defines it.
|
||||
OperationInst *SSAValue::getDefiningInst() {
|
||||
OperationInst *Value::getDefiningInst() {
|
||||
if (auto *result = dyn_cast<InstResult>(this))
|
||||
return result->getOwner();
|
||||
return nullptr;
|
||||
|
@ -31,13 +31,13 @@ OperationInst *SSAValue::getDefiningInst() {
|
|||
|
||||
/// If this value is the result of an OperationStmt, return the statement
|
||||
/// that defines it.
|
||||
OperationStmt *SSAValue::getDefiningStmt() {
|
||||
OperationStmt *Value::getDefiningStmt() {
|
||||
if (auto *result = dyn_cast<StmtResult>(this))
|
||||
return result->getOwner();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Operation *SSAValue::getDefiningOperation() {
|
||||
Operation *Value::getDefiningOperation() {
|
||||
if (auto *inst = getDefiningInst())
|
||||
return inst;
|
||||
if (auto *stmt = getDefiningStmt())
|
||||
|
@ -45,14 +45,14 @@ Operation *SSAValue::getDefiningOperation() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Return the function that this SSAValue is defined in.
|
||||
Function *SSAValue::getFunction() {
|
||||
/// Return the function that this Valueis defined in.
|
||||
Function *Value::getFunction() {
|
||||
switch (getKind()) {
|
||||
case SSAValueKind::BlockArgument:
|
||||
case Value::Kind::BlockArgument:
|
||||
return cast<BlockArgument>(this)->getFunction();
|
||||
case SSAValueKind::StmtResult:
|
||||
case Value::Kind::StmtResult:
|
||||
return getDefiningStmt()->getFunction();
|
||||
case SSAValueKind::ForStmt:
|
||||
case Value::Kind::ForStmt:
|
||||
return cast<ForStmt>(this)->getFunction();
|
||||
}
|
||||
}
|
||||
|
@ -89,15 +89,6 @@ MLIRContext *IROperandOwner::getContext() const {
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MLValue implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this MLValue is defined in.
|
||||
MLFunction *MLValue::getFunction() {
|
||||
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BlockArgument implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -85,18 +85,16 @@ MLFunction *Statement::getFunction() const {
|
|||
return block ? block->getFunction() : nullptr;
|
||||
}
|
||||
|
||||
MLValue *Statement::getOperand(unsigned idx) {
|
||||
Value *Statement::getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
|
||||
|
||||
const Value *Statement::getOperand(unsigned idx) const {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
|
||||
const MLValue *Statement::getOperand(unsigned idx) const {
|
||||
return getStmtOperand(idx).get();
|
||||
}
|
||||
|
||||
// MLValue can be used as a dimension id if it is valid as a symbol, or
|
||||
// Value can be used as a dimension id if it is valid as a symbol, or
|
||||
// it is an induction variable, or it is a result of affine apply operation
|
||||
// with dimension id arguments.
|
||||
bool MLValue::isValidDim() const {
|
||||
bool Value::isValidDim() const {
|
||||
if (auto *stmt = getDefiningStmt()) {
|
||||
// Top level statement or constant operation is ok.
|
||||
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
|
||||
|
@ -111,10 +109,10 @@ bool MLValue::isValidDim() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
// MLValue can be used as a symbol if it is a constant, or it is defined at
|
||||
// Value can be used as a symbol if it is a constant, or it is defined at
|
||||
// the top level, or it is a result of affine apply operation with symbol
|
||||
// arguments.
|
||||
bool MLValue::isValidSymbol() const {
|
||||
bool Value::isValidSymbol() const {
|
||||
if (auto *stmt = getDefiningStmt()) {
|
||||
// Top level statement or constant operation is ok.
|
||||
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
|
||||
|
@ -129,7 +127,7 @@ bool MLValue::isValidSymbol() const {
|
|||
return isa<BlockArgument>(this);
|
||||
}
|
||||
|
||||
void Statement::setOperand(unsigned idx, MLValue *value) {
|
||||
void Statement::setOperand(unsigned idx, Value *value) {
|
||||
getStmtOperand(idx).set(value);
|
||||
}
|
||||
|
||||
|
@ -271,7 +269,7 @@ void Statement::dropAllReferences() {
|
|||
|
||||
/// Create a new OperationStmt with the specific fields.
|
||||
OperationStmt *OperationStmt::create(Location location, OperationName name,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<Type> resultTypes,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
ArrayRef<StmtBlock *> successors,
|
||||
|
@ -420,8 +418,8 @@ void OperationInst::eraseOperand(unsigned index) {
|
|||
// ForStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
|
||||
ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<Value *> ubOperands,
|
||||
AffineMap ubMap, int64_t step) {
|
||||
assert(lbOperands.size() == lbMap.getNumInputs() &&
|
||||
"lower bound operand count does not match the affine map");
|
||||
|
@ -444,9 +442,9 @@ ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
|
|||
|
||||
ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
|
||||
AffineMap ubMap, int64_t step)
|
||||
: Statement(Kind::For, location),
|
||||
MLValue(MLValueKind::ForStmt,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
: Statement(Statement::Kind::For, location),
|
||||
Value(Value::Kind::ForStmt,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
|
||||
// The body of a for stmt always has one block.
|
||||
|
@ -462,11 +460,11 @@ const AffineBound ForStmt::getUpperBound() const {
|
|||
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
|
||||
}
|
||||
|
||||
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
|
||||
void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
|
||||
assert(lbOperands.size() == map.getNumInputs());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
|
||||
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
|
||||
SmallVector<Value *, 4> ubOperands(getUpperBoundOperands());
|
||||
|
||||
operands.clear();
|
||||
operands.reserve(lbOperands.size() + ubMap.getNumInputs());
|
||||
|
@ -479,11 +477,11 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
|
|||
this->lbMap = map;
|
||||
}
|
||||
|
||||
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
|
||||
void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
|
||||
assert(ubOperands.size() == map.getNumInputs());
|
||||
assert(map.getNumResults() >= 1 && "bound map has at least one result");
|
||||
|
||||
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
|
||||
SmallVector<Value *, 4> lbOperands(getLowerBoundOperands());
|
||||
|
||||
operands.clear();
|
||||
operands.reserve(lbOperands.size() + ubOperands.size());
|
||||
|
@ -553,7 +551,7 @@ bool ForStmt::matchingBoundOperandList() const {
|
|||
|
||||
unsigned numOperands = lbMap.getNumInputs();
|
||||
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
|
||||
// Compare MLValue *'s.
|
||||
// Compare Value *'s.
|
||||
if (getOperand(i) != getOperand(numOperands + i))
|
||||
return false;
|
||||
}
|
||||
|
@ -581,7 +579,7 @@ IfStmt::~IfStmt() {
|
|||
// allocated through MLIRContext's bump pointer allocator.
|
||||
}
|
||||
|
||||
IfStmt *IfStmt::create(Location location, ArrayRef<MLValue *> operands,
|
||||
IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
|
||||
IntegerSet set) {
|
||||
unsigned numOperands = operands.size();
|
||||
assert(numOperands == set.getNumOperands() &&
|
||||
|
@ -617,16 +615,16 @@ MLIRContext *IfStmt::getContext() const {
|
|||
/// them alone if no entry is present). Replaces references to cloned
|
||||
/// sub-statements to the corresponding statement that is copied, and adds
|
||||
/// those mappings to the map.
|
||||
Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
||||
Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
|
||||
MLIRContext *context) const {
|
||||
// If the specified value is in operandMap, return the remapped value.
|
||||
// Otherwise return the value itself.
|
||||
auto remapOperand = [&](const MLValue *value) -> MLValue * {
|
||||
auto remapOperand = [&](const Value *value) -> Value * {
|
||||
auto it = operandMap.find(value);
|
||||
return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
|
||||
return it != operandMap.end() ? it->second : const_cast<Value *>(value);
|
||||
};
|
||||
|
||||
SmallVector<MLValue *, 8> operands;
|
||||
SmallVector<Value *, 8> operands;
|
||||
SmallVector<StmtBlock *, 2> successors;
|
||||
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
|
||||
operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
|
||||
|
@ -683,10 +681,9 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
|||
auto ubMap = forStmt->getUpperBoundMap();
|
||||
|
||||
auto *newFor = ForStmt::create(
|
||||
getLoc(),
|
||||
ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
|
||||
ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
|
||||
forStmt->getStep());
|
||||
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
|
||||
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
|
||||
ubMap, forStmt->getStep());
|
||||
|
||||
// Remember the induction variable mapping.
|
||||
operandMap[forStmt] = newFor;
|
||||
|
@ -716,6 +713,6 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
|
|||
}
|
||||
|
||||
Statement *Statement::clone(MLIRContext *context) const {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
return clone(operandMap, context);
|
||||
}
|
||||
|
|
|
@ -42,7 +42,6 @@
|
|||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
using llvm::MemoryBuffer;
|
||||
using llvm::SMLoc;
|
||||
|
@ -1890,10 +1889,10 @@ public:
|
|||
|
||||
/// Given a reference to an SSA value and its type, return a reference. This
|
||||
/// returns null on failure.
|
||||
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
|
||||
Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
|
||||
|
||||
/// Register a definition of a value with the symbol table.
|
||||
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
|
||||
ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
|
||||
|
||||
// SSA parsing productions.
|
||||
ParseResult parseSSAUse(SSAUseInfo &result);
|
||||
|
@ -1903,9 +1902,9 @@ public:
|
|||
ResultType parseSSADefOrUseAndType(
|
||||
const std::function<ResultType(SSAUseInfo, Type)> &action);
|
||||
|
||||
SSAValue *parseSSAUseAndType() {
|
||||
return parseSSADefOrUseAndType<SSAValue *>(
|
||||
[&](SSAUseInfo useInfo, Type type) -> SSAValue * {
|
||||
Value *parseSSAUseAndType() {
|
||||
return parseSSADefOrUseAndType<Value *>(
|
||||
[&](SSAUseInfo useInfo, Type type) -> Value * {
|
||||
return resolveSSAUse(useInfo, type);
|
||||
});
|
||||
}
|
||||
|
@ -1920,9 +1919,8 @@ public:
|
|||
Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
|
||||
|
||||
/// Parse a single operation successor and it's operand list.
|
||||
virtual bool
|
||||
parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<SSAValue *> &operands) = 0;
|
||||
virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<Value *> &operands) = 0;
|
||||
|
||||
protected:
|
||||
FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
|
||||
|
@ -1934,24 +1932,23 @@ private:
|
|||
Kind kind;
|
||||
/// This keeps track of all of the SSA values we are tracking, indexed by
|
||||
/// their name. This has one entry per result number.
|
||||
llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
|
||||
llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
|
||||
|
||||
/// These are all of the placeholders we've made along with the location of
|
||||
/// their first reference, to allow checking for use of undefined values.
|
||||
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
|
||||
DenseMap<Value *, SMLoc> forwardReferencePlaceholders;
|
||||
|
||||
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
|
||||
Value *createForwardReferencePlaceholder(SMLoc loc, Type type);
|
||||
|
||||
/// Return true if this is a forward reference.
|
||||
bool isForwardReferencePlaceholder(SSAValue *value) {
|
||||
bool isForwardReferencePlaceholder(Value *value) {
|
||||
return forwardReferencePlaceholders.count(value);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Create and remember a new placeholder for a forward reference.
|
||||
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
||||
Type type) {
|
||||
Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) {
|
||||
// Forward references are always created as instructions, even in ML
|
||||
// functions, because we just need something with a def/use chain.
|
||||
//
|
||||
|
@ -1969,7 +1966,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
|
|||
|
||||
/// Given an unbound reference to an SSA value and its type, return the value
|
||||
/// it specifies. This returns null on failure.
|
||||
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
|
||||
Value *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
|
||||
auto &entries = values[useInfo.name];
|
||||
|
||||
// If we have already seen a value of this name, return it.
|
||||
|
@ -2010,7 +2007,7 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
|
|||
}
|
||||
|
||||
/// Register a definition of a value with the symbol table.
|
||||
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
|
||||
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) {
|
||||
auto &entries = values[useInfo.name];
|
||||
|
||||
// Make sure there is a slot for this value.
|
||||
|
@ -2046,7 +2043,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
|
|||
// Check for any forward references that are left. If we find any, error
|
||||
// out.
|
||||
if (!forwardReferencePlaceholders.empty()) {
|
||||
SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
|
||||
SmallVector<std::pair<const char *, Value *>, 4> errors;
|
||||
// Iteration over the map isn't deterministic, so sort by source location.
|
||||
for (auto entry : forwardReferencePlaceholders)
|
||||
errors.push_back({entry.second.getPointer(), entry.first});
|
||||
|
@ -2399,9 +2396,8 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
bool
|
||||
parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<SSAValue *> &operands) override {
|
||||
bool parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<Value *> &operands) override {
|
||||
// Defer successor parsing to the function parsers.
|
||||
return parser.parseSuccessorAndUseList(dest, operands);
|
||||
}
|
||||
|
@ -2493,7 +2489,7 @@ public:
|
|||
llvm::SMLoc getNameLoc() const override { return nameLoc; }
|
||||
|
||||
bool resolveOperand(const OperandType &operand, Type type,
|
||||
SmallVectorImpl<SSAValue *> &result) override {
|
||||
SmallVectorImpl<Value *> &result) override {
|
||||
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
|
||||
operand.location};
|
||||
if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
|
||||
|
@ -2573,7 +2569,7 @@ public:
|
|||
ParseResult parseFunctionBody();
|
||||
|
||||
bool parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<SSAValue *> &operands);
|
||||
SmallVectorImpl<Value *> &operands);
|
||||
|
||||
private:
|
||||
CFGFunction *function;
|
||||
|
@ -2636,7 +2632,7 @@ private:
|
|||
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
|
||||
///
|
||||
bool CFGFunctionParser::parseSuccessorAndUseList(
|
||||
BasicBlock *&dest, SmallVectorImpl<SSAValue *> &operands) {
|
||||
BasicBlock *&dest, SmallVectorImpl<Value *> &operands) {
|
||||
// Verify branch is identifier and get the matching block.
|
||||
if (!getToken().is(Token::bare_identifier))
|
||||
return emitError("expected basic block name");
|
||||
|
@ -2790,10 +2786,10 @@ private:
|
|||
|
||||
ParseResult parseForStmt();
|
||||
ParseResult parseIntConstant(int64_t &val);
|
||||
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
||||
ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
|
||||
unsigned numDims, unsigned numOperands,
|
||||
const char *affineStructName);
|
||||
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
|
||||
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
|
||||
bool isLower);
|
||||
ParseResult parseIfStmt();
|
||||
ParseResult parseElseClause(StmtBlock *elseClause);
|
||||
|
@ -2801,7 +2797,7 @@ private:
|
|||
ParseResult parseStmtBlock(StmtBlock *block);
|
||||
|
||||
bool parseSuccessorAndUseList(BasicBlock *&dest,
|
||||
SmallVectorImpl<SSAValue *> &operands) {
|
||||
SmallVectorImpl<Value *> &operands) {
|
||||
assert(false && "MLFunctions do not have terminators with successors.");
|
||||
return true;
|
||||
}
|
||||
|
@ -2838,7 +2834,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
return ParseFailure;
|
||||
|
||||
// Parse lower bound.
|
||||
SmallVector<MLValue *, 4> lbOperands;
|
||||
SmallVector<Value *, 4> lbOperands;
|
||||
AffineMap lbMap;
|
||||
if (parseBound(lbOperands, lbMap, /*isLower*/ true))
|
||||
return ParseFailure;
|
||||
|
@ -2847,7 +2843,7 @@ ParseResult MLFunctionParser::parseForStmt() {
|
|||
return ParseFailure;
|
||||
|
||||
// Parse upper bound.
|
||||
SmallVector<MLValue *, 4> ubOperands;
|
||||
SmallVector<Value *, 4> ubOperands;
|
||||
AffineMap ubMap;
|
||||
if (parseBound(ubOperands, ubMap, /*isLower*/ false))
|
||||
return ParseFailure;
|
||||
|
@ -2913,7 +2909,7 @@ ParseResult MLFunctionParser::parseIntConstant(int64_t &val) {
|
|||
/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
|
||||
///
|
||||
ParseResult
|
||||
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
||||
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
|
||||
unsigned numDims, unsigned numOperands,
|
||||
const char *affineStructName) {
|
||||
if (parseToken(Token::l_paren, "expected '('"))
|
||||
|
@ -2942,18 +2938,17 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
|||
// Resolve SSA uses.
|
||||
Type indexType = builder.getIndexType();
|
||||
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
|
||||
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
|
||||
Value *sval = resolveSSAUse(opInfo[i], indexType);
|
||||
if (!sval)
|
||||
return ParseFailure;
|
||||
|
||||
auto *v = cast<MLValue>(sval);
|
||||
if (i < numDims && !v->isValidDim())
|
||||
if (i < numDims && !sval->isValidDim())
|
||||
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
|
||||
"' cannot be used as a dimension id");
|
||||
if (i >= numDims && !v->isValidSymbol())
|
||||
if (i >= numDims && !sval->isValidSymbol())
|
||||
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
|
||||
"' cannot be used as a symbol");
|
||||
operands.push_back(v);
|
||||
operands.push_back(sval);
|
||||
}
|
||||
|
||||
return ParseSuccess;
|
||||
|
@ -2965,7 +2960,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
|
|||
/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list
|
||||
/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal
|
||||
///
|
||||
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
|
||||
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
|
||||
AffineMap &map, bool isLower) {
|
||||
// 'min' / 'max' prefixes are syntactic sugar. Ignore them.
|
||||
if (isLower)
|
||||
|
@ -3003,7 +2998,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
|
|||
// TODO: improve error message when SSA value is not an affine integer.
|
||||
// Currently it is 'use of value ... expects different type than prior uses'
|
||||
if (auto *value = resolveSSAUse(opInfo, builder.getIndexType()))
|
||||
operands.push_back(cast<MLValue>(value));
|
||||
operands.push_back(value);
|
||||
else
|
||||
return ParseFailure;
|
||||
|
||||
|
@ -3113,7 +3108,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
|
|||
if (!set)
|
||||
return ParseFailure;
|
||||
|
||||
SmallVector<MLValue *, 4> operands;
|
||||
SmallVector<Value *, 4> operands;
|
||||
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
|
||||
"integer set"))
|
||||
return ParseFailure;
|
||||
|
|
|
@ -23,8 +23,8 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
@ -78,8 +78,8 @@ struct MemRefCastFolder : public RewritePattern {
|
|||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs) {
|
||||
void AddFOp::build(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs) {
|
||||
assert(lhs->getType() == rhs->getType());
|
||||
result->addOperands({lhs, rhs});
|
||||
result->types.push_back(lhs->getType());
|
||||
|
@ -146,7 +146,7 @@ void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AllocOp::build(Builder *builder, OperationState *result,
|
||||
MemRefType memrefType, ArrayRef<SSAValue *> operands) {
|
||||
MemRefType memrefType, ArrayRef<Value *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->types.push_back(memrefType);
|
||||
}
|
||||
|
@ -247,8 +247,8 @@ struct SimplifyAllocConst : public RewritePattern {
|
|||
// and keep track of the resultant memref type to build.
|
||||
SmallVector<int, 4> newShapeConstants;
|
||||
newShapeConstants.reserve(memrefType.getRank());
|
||||
SmallVector<SSAValue *, 4> newOperands;
|
||||
SmallVector<SSAValue *, 4> droppedOperands;
|
||||
SmallVector<Value *, 4> newOperands;
|
||||
SmallVector<Value *, 4> droppedOperands;
|
||||
|
||||
unsigned dynamicDimPos = 0;
|
||||
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
||||
|
@ -301,7 +301,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
ArrayRef<Value *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType().getResults());
|
||||
|
@ -370,7 +370,7 @@ bool CallOp::verify() const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *callee, ArrayRef<SSAValue *> operands) {
|
||||
Value *callee, ArrayRef<Value *> operands) {
|
||||
auto fnType = callee->getType().cast<FunctionType>();
|
||||
result->operands.push_back(callee);
|
||||
result->addOperands(operands);
|
||||
|
@ -507,7 +507,7 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
|
|||
}
|
||||
|
||||
void CmpIOp::build(Builder *build, OperationState *result,
|
||||
CmpIPredicate predicate, SSAValue *lhs, SSAValue *rhs) {
|
||||
CmpIPredicate predicate, Value *lhs, Value *rhs) {
|
||||
result->addOperands({lhs, rhs});
|
||||
result->types.push_back(getI1SameShape(build, lhs->getType()));
|
||||
result->addAttribute(getPredicateAttrName(),
|
||||
|
@ -580,8 +580,7 @@ bool CmpIOp::verify() const {
|
|||
// DeallocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DeallocOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *memref) {
|
||||
void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) {
|
||||
result->addOperands(memref);
|
||||
}
|
||||
|
||||
|
@ -615,7 +614,7 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DimOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *memrefOrTensor, unsigned index) {
|
||||
Value *memrefOrTensor, unsigned index) {
|
||||
result->addOperands(memrefOrTensor);
|
||||
auto type = builder->getIndexType();
|
||||
result->addAttribute("index", builder->getIntegerAttr(type, index));
|
||||
|
@ -689,11 +688,11 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
|||
// ---------------------------------------------------------------------------
|
||||
|
||||
void DmaStartOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
|
||||
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
|
||||
SSAValue *numElements, SSAValue *tagMemRef,
|
||||
ArrayRef<SSAValue *> tagIndices, SSAValue *stride,
|
||||
SSAValue *elementsPerStride) {
|
||||
Value *srcMemRef, ArrayRef<Value *> srcIndices,
|
||||
Value *destMemRef, ArrayRef<Value *> destIndices,
|
||||
Value *numElements, Value *tagMemRef,
|
||||
ArrayRef<Value *> tagIndices, Value *stride,
|
||||
Value *elementsPerStride) {
|
||||
result->addOperands(srcMemRef);
|
||||
result->addOperands(srcIndices);
|
||||
result->addOperands(destMemRef);
|
||||
|
@ -836,8 +835,8 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
// ---------------------------------------------------------------------------
|
||||
|
||||
void DmaWaitOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices,
|
||||
SSAValue *numElements) {
|
||||
Value *tagMemRef, ArrayRef<Value *> tagIndices,
|
||||
Value *numElements) {
|
||||
result->addOperands(tagMemRef);
|
||||
result->addOperands(tagIndices);
|
||||
result->addOperands(numElements);
|
||||
|
@ -896,8 +895,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ExtractElementOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *aggregate,
|
||||
ArrayRef<SSAValue *> indices) {
|
||||
Value *aggregate, ArrayRef<Value *> indices) {
|
||||
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
|
||||
result->addOperands(aggregate);
|
||||
result->addOperands(indices);
|
||||
|
@ -955,8 +953,8 @@ bool ExtractElementOp::verify() const {
|
|||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
|
||||
ArrayRef<SSAValue *> indices) {
|
||||
void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
|
||||
ArrayRef<Value *> indices) {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
result->addOperands(memref);
|
||||
result->addOperands(indices);
|
||||
|
@ -1130,9 +1128,8 @@ void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SelectOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *condition, SSAValue *trueValue,
|
||||
SSAValue *falseValue) {
|
||||
void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
|
||||
Value *trueValue, Value *falseValue) {
|
||||
result->addOperands({condition, trueValue, falseValue});
|
||||
result->addTypes(trueValue->getType());
|
||||
}
|
||||
|
@ -1201,8 +1198,8 @@ Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void StoreOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *valueToStore, SSAValue *memref,
|
||||
ArrayRef<SSAValue *> indices) {
|
||||
Value *valueToStore, Value *memref,
|
||||
ArrayRef<Value *> indices) {
|
||||
result->addOperands(valueToStore);
|
||||
result->addOperands(memref);
|
||||
result->addOperands(indices);
|
||||
|
|
|
@ -72,10 +72,10 @@ static bool verifyPermutationMap(AffineMap permutationMap,
|
|||
}
|
||||
|
||||
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
|
||||
VectorType vectorType, SSAValue *srcMemRef,
|
||||
ArrayRef<SSAValue *> srcIndices,
|
||||
VectorType vectorType, Value *srcMemRef,
|
||||
ArrayRef<Value *> srcIndices,
|
||||
AffineMap permutationMap,
|
||||
Optional<SSAValue *> paddingValue) {
|
||||
Optional<Value *> paddingValue) {
|
||||
result->addOperands(srcMemRef);
|
||||
result->addOperands(srcIndices);
|
||||
if (paddingValue) {
|
||||
|
@ -100,21 +100,20 @@ VectorTransferReadOp::getIndices() const {
|
|||
return {begin, end};
|
||||
}
|
||||
|
||||
Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
|
||||
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
|
||||
auto memRefRank = getMemRefType().getRank();
|
||||
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
|
||||
return None;
|
||||
}
|
||||
return Optional<SSAValue *>(
|
||||
getOperand(Offsets::FirstIndexOffset + memRefRank));
|
||||
return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank));
|
||||
}
|
||||
|
||||
Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
|
||||
Optional<const Value *> VectorTransferReadOp::getPaddingValue() const {
|
||||
auto memRefRank = getMemRefType().getRank();
|
||||
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
|
||||
return None;
|
||||
}
|
||||
return Optional<const SSAValue *>(
|
||||
return Optional<const Value *>(
|
||||
getOperand(Offsets::FirstIndexOffset + memRefRank));
|
||||
}
|
||||
|
||||
|
@ -136,7 +135,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) const {
|
|||
// Construct the FunctionType and print it.
|
||||
llvm::SmallVector<Type, 8> inputs{getMemRefType()};
|
||||
// Must have at least one actual index, see verify.
|
||||
const SSAValue *firstIndex = *(getIndices().begin());
|
||||
const Value *firstIndex = *(getIndices().begin());
|
||||
Type indexType = firstIndex->getType();
|
||||
inputs.append(getMemRefType().getRank(), indexType);
|
||||
if (optionalPaddingValue) {
|
||||
|
@ -295,8 +294,8 @@ bool VectorTransferReadOp::verify() const {
|
|||
// VectorTransferWriteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcVector, SSAValue *dstMemRef,
|
||||
ArrayRef<SSAValue *> dstIndices,
|
||||
Value *srcVector, Value *dstMemRef,
|
||||
ArrayRef<Value *> dstIndices,
|
||||
AffineMap permutationMap) {
|
||||
result->addOperands({srcVector, dstMemRef});
|
||||
result->addOperands(dstIndices);
|
||||
|
@ -457,7 +456,7 @@ bool VectorTransferWriteOp::verify() const {
|
|||
// VectorTypeCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *srcVector, Type dstType) {
|
||||
Value *srcVector, Type dstType) {
|
||||
result->addOperands(srcVector);
|
||||
result->addTypes(dstType);
|
||||
}
|
||||
|
|
|
@ -111,7 +111,7 @@ private:
|
|||
/// descriptor and get the pointer to the element indexed by the linearized
|
||||
/// subscript. Return nullptr on errors.
|
||||
llvm::Value *emitMemRefElementAccess(
|
||||
const SSAValue *memRef, const Operation &op,
|
||||
const Value *memRef, const Operation &op,
|
||||
llvm::iterator_range<Operation::const_operand_iterator> opIndices);
|
||||
|
||||
/// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create
|
||||
|
@ -136,12 +136,12 @@ private:
|
|||
|
||||
/// Create a single LLVM value of struct type that includes the list of
|
||||
/// given MLIR values. The `values` list must contain at least 2 elements.
|
||||
llvm::Value *packValues(ArrayRef<const SSAValue *> values);
|
||||
llvm::Value *packValues(ArrayRef<const Value *> values);
|
||||
/// Extract a list of `num` LLVM values from a `value` of struct type.
|
||||
SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num);
|
||||
|
||||
llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
|
||||
llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
|
||||
llvm::DenseMap<const Value *, llvm::Value *> valueMapping;
|
||||
llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
|
||||
llvm::LLVMContext &llvmContext;
|
||||
llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
|
||||
|
@ -316,7 +316,7 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
|
|||
}
|
||||
|
||||
llvm::Value *ModuleLowerer::emitMemRefElementAccess(
|
||||
const SSAValue *memRef, const Operation &op,
|
||||
const Value *memRef, const Operation &op,
|
||||
llvm::iterator_range<Operation::const_operand_iterator> opIndices) {
|
||||
auto type = memRef->getType().dyn_cast<MemRefType>();
|
||||
assert(type && "expected memRef value to have a MemRef type");
|
||||
|
@ -340,7 +340,7 @@ llvm::Value *ModuleLowerer::emitMemRefElementAccess(
|
|||
// Obtain the list of access subscripts as values and linearize it given the
|
||||
// list of sizes.
|
||||
auto indices = functional::map(
|
||||
[this](const SSAValue *value) { return valueMapping.lookup(value); },
|
||||
[this](const Value *value) { return valueMapping.lookup(value); },
|
||||
opIndices);
|
||||
auto subscript = linearizeSubscripts(indices, sizes);
|
||||
|
||||
|
@ -460,11 +460,11 @@ llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) {
|
|||
}
|
||||
|
||||
// Create an undef struct value and insert individual values into it.
|
||||
llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
|
||||
llvm::Value *ModuleLowerer::packValues(ArrayRef<const Value *> values) {
|
||||
assert(values.size() > 1 && "cannot pack less than 2 values");
|
||||
|
||||
auto types =
|
||||
functional::map([](const SSAValue *v) { return v->getType(); }, values);
|
||||
functional::map([](const Value *v) { return v->getType(); }, values);
|
||||
llvm::Type *packedType = getPackedResultType(types);
|
||||
|
||||
llvm::Value *packed = llvm::UndefValue::get(packedType);
|
||||
|
@ -641,7 +641,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
|
|||
return false;
|
||||
}
|
||||
if (auto dimOp = inst.dyn_cast<DimOp>()) {
|
||||
const SSAValue *container = dimOp->getOperand();
|
||||
const Value *container = dimOp->getOperand();
|
||||
MemRefType type = container->getType().dyn_cast<MemRefType>();
|
||||
if (!type)
|
||||
return dimOp->emitError("only memref types are supported");
|
||||
|
@ -672,7 +672,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
|
|||
|
||||
if (auto callOp = inst.dyn_cast<CallOp>()) {
|
||||
auto operands = functional::map(
|
||||
[this](const SSAValue *value) { return valueMapping.lookup(value); },
|
||||
[this](const Value *value) { return valueMapping.lookup(value); },
|
||||
callOp->getOperands());
|
||||
auto numResults = callOp->getNumResults();
|
||||
llvm::Value *result =
|
||||
|
@ -779,10 +779,9 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
|
|||
|
||||
// Get the SSA value passed to the current block from the terminator instruction
|
||||
// of its predecessor.
|
||||
static const SSAValue *getPHISourceValue(const BasicBlock *current,
|
||||
const BasicBlock *pred,
|
||||
unsigned numArguments,
|
||||
unsigned index) {
|
||||
static const Value *getPHISourceValue(const BasicBlock *current,
|
||||
const BasicBlock *pred,
|
||||
unsigned numArguments, unsigned index) {
|
||||
auto &terminator = *pred->getTerminator();
|
||||
if (terminator.isa<BranchOp>()) {
|
||||
return terminator.getOperand(index);
|
||||
|
|
|
@ -30,13 +30,12 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
|||
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
|
||||
|
||||
// All constants in the function post folding.
|
||||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
SmallVector<Value *, 8> existingConstants;
|
||||
// Operation statements that were folded and that need to be erased.
|
||||
std::vector<OperationStmt *> opStmtsToErase;
|
||||
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
|
||||
using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
|
||||
|
||||
bool foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants,
|
||||
ConstantFactoryType constantFactory);
|
||||
void visitOperationStmt(OperationStmt *stmt);
|
||||
void visitForStmt(ForStmt *stmt);
|
||||
|
@ -54,9 +53,8 @@ char ConstantFold::passID = 0;
|
|||
///
|
||||
/// This returns false if the operation was successfully folded.
|
||||
bool ConstantFold::foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
SmallVectorImpl<Value *> &existingConstants,
|
||||
ConstantFactoryType constantFactory) {
|
||||
|
||||
// If this operation is already a constant, just remember it for cleanup
|
||||
// later, and don't try to fold it.
|
||||
if (auto constant = op->dyn_cast<ConstantOp>()) {
|
||||
|
@ -114,7 +112,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
if (!inst)
|
||||
continue;
|
||||
|
||||
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> Value * {
|
||||
builder.setInsertionPoint(inst);
|
||||
return builder.create<ConstantOp>(inst->getLoc(), value, type);
|
||||
};
|
||||
|
@ -142,7 +140,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
|
||||
// Override the walker's operation statement visit for constant folding.
|
||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type type) -> Value * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||
};
|
||||
|
|
|
@ -50,28 +50,28 @@ public:
|
|||
void visitOperationStmt(OperationStmt *opStmt);
|
||||
|
||||
private:
|
||||
CFGValue *getConstantIndexValue(int64_t value);
|
||||
Value *getConstantIndexValue(int64_t value);
|
||||
void visitStmtBlock(StmtBlock *stmtBlock);
|
||||
CFGValue *buildMinMaxReductionSeq(
|
||||
Value *buildMinMaxReductionSeq(
|
||||
Location loc, CmpIPredicate predicate,
|
||||
llvm::iterator_range<Operation::result_iterator> values);
|
||||
|
||||
CFGFunction *cfgFunc;
|
||||
CFGFuncBuilder builder;
|
||||
|
||||
// Mapping between original MLValues and lowered CFGValues.
|
||||
llvm::DenseMap<const MLValue *, CFGValue *> valueRemapping;
|
||||
// Mapping between original Values and lowered Values.
|
||||
llvm::DenseMap<const Value *, Value *> valueRemapping;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Return a vector of OperationStmt's arguments as SSAValues. For each
|
||||
// statement operands, represented as MLValue, lookup its CFGValue conterpart in
|
||||
// Return a vector of OperationStmt's arguments as Values. For each
|
||||
// statement operands, represented as Value, lookup its Value conterpart in
|
||||
// the valueRemapping table.
|
||||
static llvm::SmallVector<SSAValue *, 4>
|
||||
static llvm::SmallVector<mlir::Value *, 4>
|
||||
operandsAs(Statement *opStmt,
|
||||
const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) {
|
||||
llvm::SmallVector<SSAValue *, 4> operands;
|
||||
for (const MLValue *operand : opStmt->getOperands()) {
|
||||
const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
|
||||
llvm::SmallVector<Value *, 4> operands;
|
||||
for (const Value *operand : opStmt->getOperands()) {
|
||||
assert(valueRemapping.count(operand) != 0 && "operand is not defined");
|
||||
operands.push_back(valueRemapping.lookup(operand));
|
||||
}
|
||||
|
@ -81,8 +81,8 @@ operandsAs(Statement *opStmt,
|
|||
// Convert an operation statement into an operation instruction.
|
||||
//
|
||||
// The operation description (name, number and types of operands or results)
|
||||
// remains the same but the values must be updated to be CFGValues. Update the
|
||||
// mapping MLValue->CFGValue as the conversion is performed. The operation
|
||||
// remains the same but the values must be updated to be Values. Update the
|
||||
// mapping Value->Value as the conversion is performed. The operation
|
||||
// instruction is appended to current block (end of SESE region).
|
||||
void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
|
||||
// Set up basic operation state (context, name, operands).
|
||||
|
@ -90,11 +90,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
|
|||
opStmt->getName());
|
||||
state.addOperands(operandsAs(opStmt, valueRemapping));
|
||||
|
||||
// Set up operation return types. The corresponding SSAValues will become
|
||||
// Set up operation return types. The corresponding Values will become
|
||||
// available after the operation is created.
|
||||
state.addTypes(
|
||||
functional::map([](SSAValue *result) { return result->getType(); },
|
||||
opStmt->getResults()));
|
||||
state.addTypes(functional::map(
|
||||
[](Value *result) { return result->getType(); }, opStmt->getResults()));
|
||||
|
||||
// Copy attributes.
|
||||
for (auto attr : opStmt->getAttrs()) {
|
||||
|
@ -112,10 +111,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
|
|||
}
|
||||
}
|
||||
|
||||
// Create a CFGValue for the given integer constant of index type.
|
||||
CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) {
|
||||
// Create a Value for the given integer constant of index type.
|
||||
Value *FunctionConverter::getConstantIndexValue(int64_t value) {
|
||||
auto op = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value);
|
||||
return cast<CFGValue>(op->getResult());
|
||||
return op->getResult();
|
||||
}
|
||||
|
||||
// Visit all statements in the given statement block.
|
||||
|
@ -135,18 +134,18 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) {
|
|||
// Multiple values are scanned in a linear sequence. This creates a data
|
||||
// dependences that wouldn't exist in a tree reduction, but is easier to
|
||||
// recognize as a reduction by the subsequent passes.
|
||||
CFGValue *FunctionConverter::buildMinMaxReductionSeq(
|
||||
Value *FunctionConverter::buildMinMaxReductionSeq(
|
||||
Location loc, CmpIPredicate predicate,
|
||||
llvm::iterator_range<Operation::result_iterator> values) {
|
||||
assert(!llvm::empty(values) && "empty min/max chain");
|
||||
|
||||
auto valueIt = values.begin();
|
||||
CFGValue *value = cast<CFGValue>(*valueIt++);
|
||||
Value *value = *valueIt++;
|
||||
for (; valueIt != values.end(); ++valueIt) {
|
||||
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
|
||||
auto selectOp =
|
||||
builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt);
|
||||
value = cast<CFGValue>(selectOp->getResult());
|
||||
value = selectOp->getResult();
|
||||
}
|
||||
|
||||
return value;
|
||||
|
@ -231,9 +230,9 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
// The loop condition block has an argument for loop induction variable.
|
||||
// Create it upfront and make the loop induction variable -> basic block
|
||||
// argument remapping available to the following instructions. ForStatement
|
||||
// is-a MLValue corresponding to the loop induction variable.
|
||||
// is-a Value corresponding to the loop induction variable.
|
||||
builder.setInsertionPoint(loopConditionBlock);
|
||||
CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType());
|
||||
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
|
||||
valueRemapping.insert(std::make_pair(forStmt, iv));
|
||||
|
||||
// Recursively construct loop body region.
|
||||
|
@ -251,7 +250,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
|
||||
auto stepOp =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
|
||||
CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult(0));
|
||||
Value *nextIvValue = stepOp->getResult(0);
|
||||
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
|
||||
nextIvValue);
|
||||
|
||||
|
@ -260,20 +259,19 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
|
||||
builder.setInsertionPoint(loopInitBlock);
|
||||
// Compute loop bounds using affine_apply after remapping its operands.
|
||||
auto remapOperands = [this](const SSAValue *value) -> SSAValue * {
|
||||
const MLValue *mlValue = dyn_cast<MLValue>(value);
|
||||
return valueRemapping.lookup(mlValue);
|
||||
auto remapOperands = [this](const Value *value) -> Value * {
|
||||
return valueRemapping.lookup(value);
|
||||
};
|
||||
auto operands =
|
||||
functional::map(remapOperands, forStmt->getLowerBoundOperands());
|
||||
auto lbAffineApply = builder.create<AffineApplyOp>(
|
||||
forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
|
||||
CFGValue *lowerBound = buildMinMaxReductionSeq(
|
||||
Value *lowerBound = buildMinMaxReductionSeq(
|
||||
forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
|
||||
operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
|
||||
auto ubAffineApply = builder.create<AffineApplyOp>(
|
||||
forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
|
||||
CFGValue *upperBound = buildMinMaxReductionSeq(
|
||||
Value *upperBound = buildMinMaxReductionSeq(
|
||||
forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
|
||||
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
|
||||
lowerBound);
|
||||
|
@ -281,10 +279,10 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
|
|||
builder.setInsertionPoint(loopConditionBlock);
|
||||
auto comparisonOp = builder.create<CmpIOp>(
|
||||
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
|
||||
auto comparisonResult = cast<CFGValue>(comparisonOp->getResult());
|
||||
auto comparisonResult = comparisonOp->getResult();
|
||||
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
|
||||
loopBodyFirstBlock, ArrayRef<SSAValue *>(),
|
||||
postLoopBlock, ArrayRef<SSAValue *>());
|
||||
loopBodyFirstBlock, ArrayRef<Value *>(),
|
||||
postLoopBlock, ArrayRef<Value *>());
|
||||
|
||||
// Finally, make sure building can continue by setting the post-loop block
|
||||
// (end of loop SESE region) as the insertion point.
|
||||
|
@ -401,7 +399,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
// If the test succeeds, jump to the next block testing testing the next
|
||||
// conjunct of the condition in the similar way. When all conjuncts have been
|
||||
// handled, jump to the 'then' block instead.
|
||||
SSAValue *zeroConstant = getConstantIndexValue(0);
|
||||
Value *zeroConstant = getConstantIndexValue(0);
|
||||
ifConditionExtraBlocks.push_back(thenBlock);
|
||||
for (auto tuple :
|
||||
llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(),
|
||||
|
@ -416,16 +414,16 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
integerSet.getNumSymbols(), constraintExpr, {});
|
||||
auto affineApplyOp = builder.create<AffineApplyOp>(
|
||||
ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
|
||||
SSAValue *affResult = affineApplyOp->getResult(0);
|
||||
Value *affResult = affineApplyOp->getResult(0);
|
||||
|
||||
// Compare the result of the apply and branch.
|
||||
auto comparisonOp = builder.create<CmpIOp>(
|
||||
ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
|
||||
affResult, zeroConstant);
|
||||
builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
|
||||
nextBlock, /*trueArgs*/ ArrayRef<SSAValue *>(),
|
||||
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
|
||||
elseBlock,
|
||||
/*falseArgs*/ ArrayRef<SSAValue *>());
|
||||
/*falseArgs*/ ArrayRef<Value *>());
|
||||
builder.setInsertionPoint(nextBlock);
|
||||
}
|
||||
ifConditionExtraBlocks.pop_back();
|
||||
|
@ -468,10 +466,10 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
|
|||
// of the current region. The SESE invariant allows us to easily handle nested
|
||||
// structures of arbitrary complexity.
|
||||
//
|
||||
// During the conversion, we maintain a mapping between the MLValues present in
|
||||
// the original function and their CFGValue images in the function under
|
||||
// construction. When an MLValue is used, it gets replaced with the
|
||||
// corresponding CFGValue that has been defined previously. The value flow
|
||||
// During the conversion, we maintain a mapping between the Values present in
|
||||
// the original function and their Value images in the function under
|
||||
// construction. When an Value is used, it gets replaced with the
|
||||
// corresponding Value that has been defined previously. The value flow
|
||||
// starts with function arguments converted to basic block arguments.
|
||||
CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
|
||||
auto outerBlock = builder.createBlock();
|
||||
|
@ -482,8 +480,8 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
|
|||
outerBlock->addArguments(mlFunc->getType().getInputs());
|
||||
assert(mlFunc->getNumArguments() == outerBlock->getNumArguments());
|
||||
for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) {
|
||||
const MLValue *mlArgument = mlFunc->getArgument(i);
|
||||
CFGValue *cfgArgument = outerBlock->getArgument(i);
|
||||
const Value *mlArgument = mlFunc->getArgument(i);
|
||||
Value *cfgArgument = outerBlock->getArgument(i);
|
||||
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
|
||||
}
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
|
|||
|
||||
// Map from original memref's to the DMA buffers that their accesses are
|
||||
// replaced with.
|
||||
DenseMap<SSAValue *, SSAValue *> fastBufferMap;
|
||||
DenseMap<Value *, Value *> fastBufferMap;
|
||||
|
||||
// Slow memory space associated with DMAs.
|
||||
const unsigned slowMemorySpace;
|
||||
|
@ -195,11 +195,11 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
|
||||
// Indices to use for the DmaStart op.
|
||||
// Indices for the original memref being DMAed from/to.
|
||||
SmallVector<SSAValue *, 4> memIndices;
|
||||
SmallVector<Value *, 4> memIndices;
|
||||
// Indices for the faster buffer being DMAed into/from.
|
||||
SmallVector<SSAValue *, 4> bufIndices;
|
||||
SmallVector<Value *, 4> bufIndices;
|
||||
|
||||
SSAValue *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
|
||||
Value *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
|
||||
|
||||
unsigned rank = memRefType.getRank();
|
||||
SmallVector<int, 4> fastBufferShape;
|
||||
|
@ -226,10 +226,10 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
// DMA generation is being done.
|
||||
const FlatAffineConstraints *cst = region.getConstraints();
|
||||
auto ids = cst->getIds();
|
||||
SmallVector<SSAValue *, 8> outerIVs;
|
||||
SmallVector<Value *, 8> outerIVs;
|
||||
for (unsigned i = rank, e = ids.size(); i < e; i++) {
|
||||
auto id = cst->getIds()[i];
|
||||
assert(id.hasValue() && "MLValue id expected");
|
||||
assert(id.hasValue() && "Value id expected");
|
||||
outerIVs.push_back(id.getValue());
|
||||
}
|
||||
|
||||
|
@ -253,15 +253,15 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
// Set DMA start location for this dimension in the lower memory space
|
||||
// memref.
|
||||
if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
|
||||
memIndices.push_back(cast<MLValue>(
|
||||
top.create<ConstantIndexOp>(loc, caf.getValue())->getResult()));
|
||||
memIndices.push_back(
|
||||
top.create<ConstantIndexOp>(loc, caf.getValue())->getResult());
|
||||
} else {
|
||||
// The coordinate for the start location is just the lower bound along the
|
||||
// corresponding dimension on the memory region (stored in 'offset').
|
||||
auto map = top.getAffineMap(
|
||||
cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {});
|
||||
memIndices.push_back(cast<MLValue>(
|
||||
b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0)));
|
||||
memIndices.push_back(
|
||||
b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0));
|
||||
}
|
||||
// The fast buffer is DMAed into at location zero; addressing is relative.
|
||||
bufIndices.push_back(zeroIndex);
|
||||
|
@ -272,7 +272,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
}
|
||||
|
||||
// The faster memory space buffer.
|
||||
SSAValue *fastMemRef;
|
||||
Value *fastMemRef;
|
||||
|
||||
// Check if a buffer was already created.
|
||||
// TODO(bondhugula): union across all memory op's per buffer. For now assuming
|
||||
|
@ -321,8 +321,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
return false;
|
||||
}
|
||||
|
||||
SSAValue *stride = nullptr;
|
||||
SSAValue *numEltPerStride = nullptr;
|
||||
Value *stride = nullptr;
|
||||
Value *numEltPerStride = nullptr;
|
||||
if (!strideInfos.empty()) {
|
||||
stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
|
||||
numEltPerStride =
|
||||
|
@ -362,7 +362,7 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, ForStmt *forStmt,
|
|||
}
|
||||
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
|
||||
// *Only* those uses within the body of 'forStmt' are replaced.
|
||||
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
|
||||
replaceAllMemRefUsesWith(memref, fastMemRef,
|
||||
/*extraIndices=*/{}, indexRemap,
|
||||
/*extraOperands=*/outerIVs,
|
||||
/*domStmtFilter=*/&*forStmt->getBody()->begin());
|
||||
|
|
|
@ -83,22 +83,22 @@ FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
|
|||
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
|
||||
MemRefAccess *access) {
|
||||
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
|
||||
access->memref = cast<MLValue>(loadOp->getMemRef());
|
||||
access->memref = loadOp->getMemRef();
|
||||
access->opStmt = loadOrStoreOpStmt;
|
||||
auto loadMemrefType = loadOp->getMemRefType();
|
||||
access->indices.reserve(loadMemrefType.getRank());
|
||||
for (auto *index : loadOp->getIndices()) {
|
||||
access->indices.push_back(cast<MLValue>(index));
|
||||
access->indices.push_back(index);
|
||||
}
|
||||
} else {
|
||||
assert(loadOrStoreOpStmt->isa<StoreOp>());
|
||||
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
|
||||
access->opStmt = loadOrStoreOpStmt;
|
||||
access->memref = cast<MLValue>(storeOp->getMemRef());
|
||||
access->memref = storeOp->getMemRef();
|
||||
auto storeMemrefType = storeOp->getMemRefType();
|
||||
access->indices.reserve(storeMemrefType.getRank());
|
||||
for (auto *index : storeOp->getIndices()) {
|
||||
access->indices.push_back(cast<MLValue>(index));
|
||||
access->indices.push_back(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -178,20 +178,20 @@ public:
|
|||
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
|
||||
|
||||
// Returns the load op count for 'memref'.
|
||||
unsigned getLoadOpCount(MLValue *memref) {
|
||||
unsigned getLoadOpCount(Value *memref) {
|
||||
unsigned loadOpCount = 0;
|
||||
for (auto *loadOpStmt : loads) {
|
||||
if (memref == cast<MLValue>(loadOpStmt->cast<LoadOp>()->getMemRef()))
|
||||
if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
|
||||
++loadOpCount;
|
||||
}
|
||||
return loadOpCount;
|
||||
}
|
||||
|
||||
// Returns the store op count for 'memref'.
|
||||
unsigned getStoreOpCount(MLValue *memref) {
|
||||
unsigned getStoreOpCount(Value *memref) {
|
||||
unsigned storeOpCount = 0;
|
||||
for (auto *storeOpStmt : stores) {
|
||||
if (memref == cast<MLValue>(storeOpStmt->cast<StoreOp>()->getMemRef()))
|
||||
if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
|
||||
++storeOpCount;
|
||||
}
|
||||
return storeOpCount;
|
||||
|
@ -203,7 +203,7 @@ public:
|
|||
// The id of the node at the other end of the edge.
|
||||
unsigned id;
|
||||
// The memref on which this edge represents a dependence.
|
||||
MLValue *memref;
|
||||
Value *memref;
|
||||
};
|
||||
|
||||
// Map from node id to Node.
|
||||
|
@ -227,13 +227,13 @@ public:
|
|||
}
|
||||
|
||||
// Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
|
||||
void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) {
|
||||
void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
|
||||
outEdges[srcId].push_back({dstId, memref});
|
||||
inEdges[dstId].push_back({srcId, memref});
|
||||
}
|
||||
|
||||
// Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
|
||||
void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) {
|
||||
void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
|
||||
assert(inEdges.count(dstId) > 0);
|
||||
assert(outEdges.count(srcId) > 0);
|
||||
// Remove 'srcId' from 'inEdges[dstId]'.
|
||||
|
@ -253,7 +253,7 @@ public:
|
|||
}
|
||||
|
||||
// Returns the input edge count for node 'id' and 'memref'.
|
||||
unsigned getInEdgeCount(unsigned id, MLValue *memref) {
|
||||
unsigned getInEdgeCount(unsigned id, Value *memref) {
|
||||
unsigned inEdgeCount = 0;
|
||||
if (inEdges.count(id) > 0)
|
||||
for (auto &inEdge : inEdges[id])
|
||||
|
@ -263,7 +263,7 @@ public:
|
|||
}
|
||||
|
||||
// Returns the output edge count for node 'id' and 'memref'.
|
||||
unsigned getOutEdgeCount(unsigned id, MLValue *memref) {
|
||||
unsigned getOutEdgeCount(unsigned id, Value *memref) {
|
||||
unsigned outEdgeCount = 0;
|
||||
if (outEdges.count(id) > 0)
|
||||
for (auto &outEdge : outEdges[id])
|
||||
|
@ -347,7 +347,7 @@ public:
|
|||
// dependence graph at a different depth.
|
||||
bool MemRefDependenceGraph::init(MLFunction *f) {
|
||||
unsigned id = 0;
|
||||
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses;
|
||||
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
|
||||
for (auto &stmt : *f->getBody()) {
|
||||
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
|
||||
// Create graph node 'id' to represent top-level 'forStmt' and record
|
||||
|
@ -360,12 +360,12 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
|
|||
Node node(id++, &stmt);
|
||||
for (auto *opStmt : collector.loadOpStmts) {
|
||||
node.loads.push_back(opStmt);
|
||||
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef());
|
||||
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
}
|
||||
for (auto *opStmt : collector.storeOpStmts) {
|
||||
node.stores.push_back(opStmt);
|
||||
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef());
|
||||
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
}
|
||||
nodes.insert({node.id, node});
|
||||
|
@ -375,7 +375,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
|
|||
// Create graph node for top-level load op.
|
||||
Node node(id++, &stmt);
|
||||
node.loads.push_back(opStmt);
|
||||
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef());
|
||||
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
}
|
||||
|
@ -383,7 +383,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
|
|||
// Create graph node for top-level store op.
|
||||
Node node(id++, &stmt);
|
||||
node.stores.push_back(opStmt);
|
||||
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef());
|
||||
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
|
||||
memrefAccesses[memref].insert(node.id);
|
||||
nodes.insert({node.id, node});
|
||||
}
|
||||
|
@ -477,8 +477,7 @@ public:
|
|||
SmallVector<OperationStmt *, 4> loads = dstNode->loads;
|
||||
while (!loads.empty()) {
|
||||
auto *dstLoadOpStmt = loads.pop_back_val();
|
||||
auto *memref =
|
||||
cast<MLValue>(dstLoadOpStmt->cast<LoadOp>()->getMemRef());
|
||||
auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
|
||||
// Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
|
||||
if (dstNode->getLoadOpCount(memref) != 1)
|
||||
continue;
|
||||
|
|
|
@ -85,10 +85,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
|||
for (unsigned i = 0; i < width; i++) {
|
||||
auto lbOperands = origLoops[i]->getLowerBoundOperands();
|
||||
auto ubOperands = origLoops[i]->getUpperBoundOperands();
|
||||
SmallVector<MLValue *, 4> newLbOperands(lbOperands.begin(),
|
||||
lbOperands.end());
|
||||
SmallVector<MLValue *, 4> newUbOperands(ubOperands.begin(),
|
||||
ubOperands.end());
|
||||
SmallVector<Value *, 4> newLbOperands(lbOperands.begin(), lbOperands.end());
|
||||
SmallVector<Value *, 4> newUbOperands(ubOperands.begin(), ubOperands.end());
|
||||
newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap());
|
||||
newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap());
|
||||
newLoops[i]->setStep(tileSizes[i]);
|
||||
|
@ -112,8 +110,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
|||
// Construct the upper bound map; the operands are the original operands
|
||||
// with 'i' (tile-space loop) appended to it. The new upper bound map is
|
||||
// the original one with an additional expression i + tileSize appended.
|
||||
SmallVector<MLValue *, 4> ubOperands(
|
||||
origLoops[i]->getUpperBoundOperands());
|
||||
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
|
||||
ubOperands.push_back(newLoops[i]);
|
||||
|
||||
auto origUbMap = origLoops[i]->getUpperBoundMap();
|
||||
|
@ -191,8 +188,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
|||
// Move the loop body of the original nest to the new one.
|
||||
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
|
||||
|
||||
SmallVector<MLValue *, 6> origLoopIVs(band.begin(), band.end());
|
||||
SmallVector<Optional<MLValue *>, 6> ids(band.begin(), band.end());
|
||||
SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
|
||||
SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
|
||||
FlatAffineConstraints cst;
|
||||
getIndexSet(band, &cst);
|
||||
|
||||
|
|
|
@ -191,7 +191,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
// unrollJamFactor.
|
||||
if (mayBeConstantTripCount.hasValue() &&
|
||||
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
// Insert the cleanup loop right after 'forStmt'.
|
||||
MLFuncBuilder builder(forStmt->getBlock(),
|
||||
std::next(StmtBlock::iterator(forStmt)));
|
||||
|
@ -219,7 +219,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
|
||||
// Unroll and jam (appends unrollJamFactor-1 additional copies).
|
||||
for (unsigned i = 1; i < unrollJamFactor; i++) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMapping;
|
||||
DenseMap<const Value *, Value *> operandMapping;
|
||||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
|
@ -230,7 +230,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
|
|||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
|
||||
->getResult(0);
|
||||
operandMapping[forStmt] = cast<MLValue>(ivUnroll);
|
||||
operandMapping[forStmt] = ivUnroll;
|
||||
}
|
||||
// Clone the sub-block being unroll-jammed.
|
||||
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
|
||||
|
|
|
@ -29,17 +29,14 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/SuperVectorOps/SuperVectorOps.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/MLPatternLoweringPass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -62,26 +59,26 @@ using namespace mlir;
|
|||
|
||||
#define DEBUG_TYPE "lower-vector-transfers"
|
||||
|
||||
/// Creates the SSAValue for the sum of `a` and `b` without building a
|
||||
/// Creates the Value for the sum of `a` and `b` without building a
|
||||
/// full-fledged AffineMap for all indices.
|
||||
///
|
||||
/// Prerequisites:
|
||||
/// `a` and `b` must be of IndexType.
|
||||
static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) {
|
||||
static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) {
|
||||
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
|
||||
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
|
||||
auto *context = b->getContext();
|
||||
auto d0 = getAffineDimExpr(0, context);
|
||||
auto d1 = getAffineDimExpr(1, context);
|
||||
auto map = AffineMap::get(2, 0, {d0 + d1}, {});
|
||||
return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w})
|
||||
return b->create<AffineApplyOp>(loc, map, ArrayRef<mlir::Value *>{v, w})
|
||||
->getResult(0);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowerVectorTransfersState : public MLFuncGlobalLoweringState {
|
||||
// Top of the function constant zero index.
|
||||
SSAValue *zero;
|
||||
Value *zero;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -131,7 +128,8 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// case of GPUs.
|
||||
if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
|
||||
b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
|
||||
vecView->getResult(), ArrayRef<SSAValue *>{state->zero});
|
||||
vecView->getResult(),
|
||||
ArrayRef<mlir::Value *>{state->zero});
|
||||
}
|
||||
|
||||
// 3. Emit the loop-nest.
|
||||
|
@ -140,7 +138,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// TODO(ntv): Handle broadcast / slice properly.
|
||||
auto permutationMap = transfer->getPermutationMap();
|
||||
SetVector<ForStmt *> loops;
|
||||
SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices());
|
||||
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
|
||||
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
|
||||
auto composed = composeWithUnboundedMap(
|
||||
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
|
||||
|
@ -168,17 +166,16 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// b. write scalar to local.
|
||||
auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
|
||||
transfer->getMemRef(), accessIndices);
|
||||
b.create<StoreOp>(
|
||||
transfer->getLoc(), scalarLoad->getResult(),
|
||||
tmpScalarAlloc->getResult(),
|
||||
functional::map([](SSAValue *val) { return val; }, loops));
|
||||
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
|
||||
tmpScalarAlloc->getResult(),
|
||||
functional::map([](Value *val) { return val; }, loops));
|
||||
} else {
|
||||
// VectorTransferWriteOp.
|
||||
// a. read scalar from local;
|
||||
// b. write scalar to remote.
|
||||
auto scalarLoad = b.create<LoadOp>(
|
||||
transfer->getLoc(), tmpScalarAlloc->getResult(),
|
||||
functional::map([](SSAValue *val) { return val; }, loops));
|
||||
functional::map([](Value *val) { return val; }, loops));
|
||||
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
|
||||
transfer->getMemRef(), accessIndices);
|
||||
}
|
||||
|
@ -186,11 +183,11 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// 5. Read the vector from local storage in case of a vector_transfer_read.
|
||||
// TODO(ntv): This vector_load operation should be further lowered in the
|
||||
// case of GPUs.
|
||||
llvm::SmallVector<SSAValue *, 1> newResults = {};
|
||||
llvm::SmallVector<Value *, 1> newResults = {};
|
||||
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
|
||||
b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
|
||||
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(),
|
||||
ArrayRef<SSAValue *>{state->zero})
|
||||
ArrayRef<Value *>{state->zero})
|
||||
->getResult();
|
||||
newResults.push_back(vector);
|
||||
}
|
||||
|
|
|
@ -32,9 +32,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
|
@ -192,7 +190,7 @@ struct MaterializationState {
|
|||
VectorType superVectorType;
|
||||
VectorType hwVectorType;
|
||||
SmallVector<unsigned, 8> hwVectorInstance;
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap;
|
||||
DenseMap<const Value *, Value *> *substitutionsMap;
|
||||
};
|
||||
|
||||
struct MaterializeVectorsPass : public FunctionPass {
|
||||
|
@ -250,9 +248,9 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
|
|||
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap);
|
||||
DenseMap<const Value *, Value *> *substitutionsMap);
|
||||
|
||||
/// Not all SSAValue belong to a program slice scoped within the immediately
|
||||
/// Not all Values belong to a program slice scoped within the immediately
|
||||
/// enclosing loop.
|
||||
/// One simple example is constants defined outside the innermost loop scope.
|
||||
/// For such cases the substitutionsMap has no entry and we allow an additional
|
||||
|
@ -261,17 +259,16 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
|||
/// indices and will need to be extended in the future.
|
||||
///
|
||||
/// If substitution fails, returns nullptr.
|
||||
static MLValue *
|
||||
substitute(SSAValue *v, VectorType hwVectorType,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
auto it = substitutionsMap->find(cast<MLValue>(v));
|
||||
static Value *substitute(Value *v, VectorType hwVectorType,
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
auto it = substitutionsMap->find(v);
|
||||
if (it == substitutionsMap->end()) {
|
||||
auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
|
||||
if (opStmt->isa<ConstantOp>()) {
|
||||
MLFuncBuilder b(opStmt);
|
||||
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
|
||||
auto res = substitutionsMap->insert(
|
||||
std::make_pair(cast<MLValue>(v), cast<MLValue>(inst->getResult(0))));
|
||||
auto res =
|
||||
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
|
||||
assert(res.second && "Insertion failed");
|
||||
return res.first->second;
|
||||
}
|
||||
|
@ -336,10 +333,10 @@ substitute(SSAValue *v, VectorType hwVectorType,
|
|||
/// TODO(ntv): support a concrete AffineMap and compose with it.
|
||||
/// TODO(ntv): these implementation details should be captured in a
|
||||
/// vectorization trait at the op level directly.
|
||||
static SmallVector<SSAValue *, 8>
|
||||
static SmallVector<mlir::Value *, 8>
|
||||
reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
|
||||
ArrayRef<unsigned> hwVectorInstance,
|
||||
ArrayRef<SSAValue *> memrefIndices) {
|
||||
ArrayRef<Value *> memrefIndices) {
|
||||
auto vectorShape = hwVectorType.getShape();
|
||||
assert(hwVectorInstance.size() >= vectorShape.size());
|
||||
|
||||
|
@ -380,7 +377,7 @@ reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
|
|||
// TODO(ntv): support a concrete map and composition.
|
||||
auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(),
|
||||
affineMap, memrefIndices);
|
||||
return SmallVector<SSAValue *, 8>{app->getResults()};
|
||||
return SmallVector<mlir::Value *, 8>{app->getResults()};
|
||||
}
|
||||
|
||||
/// Returns attributes with the following substitutions applied:
|
||||
|
@ -402,21 +399,21 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
|
|||
|
||||
/// Creates an instantiated version of `opStmt`.
|
||||
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
|
||||
/// affine reindexing. Just substitute their SSAValue* operands and be done. For
|
||||
/// this case the actual instance is irrelevant. Just use the SSA values in
|
||||
/// affine reindexing. Just substitute their Value operands and be done. For
|
||||
/// this case the actual instance is irrelevant. Just use the values in
|
||||
/// substitutionsMap.
|
||||
///
|
||||
/// If the underlying substitution fails, this fails too and returns nullptr.
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
assert(!opStmt->isa<VectorTransferReadOp>() &&
|
||||
"Should call the function specialized for VectorTransferReadOp");
|
||||
assert(!opStmt->isa<VectorTransferWriteOp>() &&
|
||||
"Should call the function specialized for VectorTransferWriteOp");
|
||||
bool fail = false;
|
||||
auto operands = map(
|
||||
[hwVectorType, substitutionsMap, &fail](SSAValue *v) -> SSAValue * {
|
||||
[hwVectorType, substitutionsMap, &fail](Value *v) -> Value * {
|
||||
auto *res =
|
||||
fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
|
||||
fail |= !res;
|
||||
|
@ -481,9 +478,9 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
|
|||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
|
||||
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
SmallVector<SSAValue *, 8> indices =
|
||||
map(makePtrDynCaster<SSAValue>(), read->getIndices());
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
SmallVector<Value *, 8> indices =
|
||||
map(makePtrDynCaster<Value>(), read->getIndices());
|
||||
auto affineIndices =
|
||||
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
|
||||
auto cloned = b->create<VectorTransferReadOp>(
|
||||
|
@ -501,9 +498,9 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
|
|||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
|
||||
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
SmallVector<SSAValue *, 8> indices =
|
||||
map(makePtrDynCaster<SSAValue>(), write->getIndices());
|
||||
DenseMap<const Value *, Value *> *substitutionsMap) {
|
||||
SmallVector<Value *, 8> indices =
|
||||
map(makePtrDynCaster<Value>(), write->getIndices());
|
||||
auto affineIndices =
|
||||
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
|
||||
auto cloned = b->create<VectorTransferWriteOp>(
|
||||
|
@ -555,8 +552,8 @@ static bool instantiateMaterialization(Statement *stmt,
|
|||
} else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
|
||||
auto *clone = instantiate(&b, read, state->hwVectorType,
|
||||
state->hwVectorInstance, state->substitutionsMap);
|
||||
state->substitutionsMap->insert(std::make_pair(
|
||||
cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0))));
|
||||
state->substitutionsMap->insert(
|
||||
std::make_pair(read->getResult(), clone->getResult(0)));
|
||||
return false;
|
||||
}
|
||||
// The only op with 0 results reaching this point must, by construction, be
|
||||
|
@ -571,8 +568,8 @@ static bool instantiateMaterialization(Statement *stmt,
|
|||
if (!clone) {
|
||||
return true;
|
||||
}
|
||||
state->substitutionsMap->insert(std::make_pair(
|
||||
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
|
||||
state->substitutionsMap->insert(
|
||||
std::make_pair(opStmt->getResult(0), clone->getResult(0)));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -610,7 +607,7 @@ static bool emitSlice(MaterializationState *state,
|
|||
// Fresh RAII instanceIndices and substitutionsMap.
|
||||
MaterializationState scopedState = *state;
|
||||
scopedState.hwVectorInstance = delinearize(idx, *ratio);
|
||||
DenseMap<const MLValue *, MLValue *> substitutionMap;
|
||||
DenseMap<const Value *, Value *> substitutionMap;
|
||||
scopedState.substitutionsMap = &substitutionMap;
|
||||
// slice are topologically sorted, we can just clone them in order.
|
||||
for (auto *stmt : *slice) {
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "pipeline-data-transfer"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -80,7 +79,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
|
|||
/// of the old memref by the new one while indexing the newly added dimension by
|
||||
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
|
||||
/// a replacement cannot be performed.
|
||||
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
||||
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
|
||||
auto *forBody = forStmt->getBody();
|
||||
MLFuncBuilder bInner(forBody, forBody->begin());
|
||||
bInner.setInsertionPoint(forBody, forBody->begin());
|
||||
|
@ -103,7 +102,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
|
||||
// Put together alloc operands for the dynamic dimensions of the memref.
|
||||
MLFuncBuilder bOuter(forStmt);
|
||||
SmallVector<SSAValue *, 4> allocOperands;
|
||||
SmallVector<Value *, 4> allocOperands;
|
||||
unsigned dynamicDimCount = 0;
|
||||
for (auto dimSize : oldMemRefType.getShape()) {
|
||||
if (dimSize == -1)
|
||||
|
@ -114,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
// Create and place the alloc right before the 'for' statement.
|
||||
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
|
||||
// inserting a dealloc -- this isn't the right thing.
|
||||
SSAValue *newMemRef =
|
||||
Value *newMemRef =
|
||||
bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
|
||||
|
||||
// Create 'iv mod 2' value to index the leading dimension.
|
||||
|
@ -126,8 +125,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
|
|||
|
||||
// replaceAllMemRefUsesWith will always succeed unless the forStmt body has
|
||||
// non-deferencing uses of the memref.
|
||||
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
|
||||
ivModTwoOp->getResult(0), AffineMap::Null(), {},
|
||||
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
|
||||
AffineMap::Null(), {},
|
||||
&*forStmt->getBody()->begin())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "memref replacement for double buffering failed\n";);
|
||||
|
@ -225,8 +224,7 @@ static void findMatchingStartFinishStmts(
|
|||
continue;
|
||||
|
||||
// We only double buffer if the buffer is not live out of loop.
|
||||
const MLValue *memref =
|
||||
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
|
||||
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
|
||||
bool escapingUses = false;
|
||||
for (const auto &use : memref->getUses()) {
|
||||
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
|
||||
|
@ -280,8 +278,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
// dimension.
|
||||
for (auto &pair : startWaitPairs) {
|
||||
auto *dmaStartStmt = pair.first;
|
||||
MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
|
||||
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
|
||||
Value *oldMemRef = dmaStartStmt->getOperand(
|
||||
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
|
||||
if (!doubleBuffer(oldMemRef, forStmt)) {
|
||||
// Normally, double buffering should not fail because we already checked
|
||||
// that there are no uses outside.
|
||||
|
@ -302,8 +300,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
// Double the buffers for tag memrefs.
|
||||
for (auto &pair : startWaitPairs) {
|
||||
auto *dmaFinishStmt = pair.second;
|
||||
MLValue *oldTagMemRef = cast<MLValue>(
|
||||
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
|
||||
Value *oldTagMemRef =
|
||||
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
|
||||
if (!doubleBuffer(oldTagMemRef, forStmt)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
|
||||
return success();
|
||||
|
@ -332,7 +330,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
// If a slice wasn't created, the reachable affine_apply op's from its
|
||||
// operands are the ones that go with it.
|
||||
SmallVector<OperationStmt *, 4> affineApplyStmts;
|
||||
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
|
||||
SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
|
||||
getReachableAffineApplyOps(operands, affineApplyStmts);
|
||||
for (const auto *stmt : affineApplyStmts) {
|
||||
stmtShiftMap[stmt] = 0;
|
||||
|
|
|
@ -217,7 +217,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
|
||||
// If we already have a canonicalized version of this constant, just
|
||||
// reuse it. Otherwise create a new one.
|
||||
SSAValue *cstValue;
|
||||
Value *cstValue;
|
||||
auto it = uniquedConstants.find({resultConstants[i], res->getType()});
|
||||
if (it != uniquedConstants.end())
|
||||
cstValue = it->second->getResult(0);
|
||||
|
|
|
@ -31,7 +31,6 @@
|
|||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "LoopUtils"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -108,8 +107,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
|
|||
forStmt->replaceAllUsesWith(constOp);
|
||||
} else {
|
||||
const AffineBound lb = forStmt->getLowerBound();
|
||||
SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
|
||||
lb.operand_end());
|
||||
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
|
||||
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
|
||||
auto affineApplyOp = builder.create<AffineApplyOp>(
|
||||
forStmt->getLoc(), lb.getMap(), lbOperands);
|
||||
|
@ -149,8 +147,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
|
||||
&stmtGroupQueue,
|
||||
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
|
||||
SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
|
||||
SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
|
||||
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
|
||||
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
|
||||
|
||||
assert(lbMap.getNumInputs() == lbOperands.size());
|
||||
assert(ubMap.getNumInputs() == ubOperands.size());
|
||||
|
@ -176,7 +174,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
srcForStmt->getStep() * shift)),
|
||||
loopChunk)
|
||||
->getResult(0);
|
||||
operandMap[srcForStmt] = cast<MLValue>(ivRemap);
|
||||
operandMap[srcForStmt] = ivRemap;
|
||||
} else {
|
||||
operandMap[srcForStmt] = loopChunk;
|
||||
}
|
||||
|
@ -380,7 +378,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
|
||||
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
|
||||
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
|
||||
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
|
||||
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
|
@ -414,7 +412,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
|
||||
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
|
||||
for (unsigned i = 1; i < unrollFactor; i++) {
|
||||
DenseMap<const MLValue *, MLValue *> operandMap;
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
|
@ -425,7 +423,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
|
|||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
|
||||
->getResult(0);
|
||||
operandMap[forStmt] = cast<MLValue>(ivUnroll);
|
||||
operandMap[forStmt] = ivUnroll;
|
||||
}
|
||||
|
||||
// Clone the original body of 'forStmt'.
|
||||
|
|
|
@ -32,17 +32,17 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
// Visit affine expressions recursively and build the sequence of instructions
|
||||
// that correspond to it. Visitation functions return an SSAValue of the
|
||||
// that correspond to it. Visitation functions return an Value of the
|
||||
// expression subtree they visited or `nullptr` on error.
|
||||
class AffineApplyExpander
|
||||
: public AffineExprVisitor<AffineApplyExpander, SSAValue *> {
|
||||
: public AffineExprVisitor<AffineApplyExpander, Value *> {
|
||||
public:
|
||||
// This internal clsas expects arguments to be non-null, checks must be
|
||||
// performed at the call site.
|
||||
AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op)
|
||||
: builder(*builder), applyOp(*op), loc(op->getLoc()) {}
|
||||
|
||||
template <typename OpTy> SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) {
|
||||
template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
|
||||
auto lhs = visit(expr.getLHS());
|
||||
auto rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
|
@ -51,33 +51,33 @@ public:
|
|||
return op->getResult();
|
||||
}
|
||||
|
||||
SSAValue *visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
Value *visitAddExpr(AffineBinaryOpExpr expr) {
|
||||
return buildBinaryExpr<AddIOp>(expr);
|
||||
}
|
||||
|
||||
SSAValue *visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
Value *visitMulExpr(AffineBinaryOpExpr expr) {
|
||||
return buildBinaryExpr<MulIOp>(expr);
|
||||
}
|
||||
|
||||
// TODO(zinenko): implement when the standard operators are made available.
|
||||
SSAValue *visitModExpr(AffineBinaryOpExpr) {
|
||||
Value *visitModExpr(AffineBinaryOpExpr) {
|
||||
builder.getContext()->emitError(loc, "unsupported binary operator: mod");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) {
|
||||
Value *visitFloorDivExpr(AffineBinaryOpExpr) {
|
||||
builder.getContext()->emitError(loc,
|
||||
"unsupported binary operator: floor_div");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) {
|
||||
Value *visitCeilDivExpr(AffineBinaryOpExpr) {
|
||||
builder.getContext()->emitError(loc,
|
||||
"unsupported binary operator: ceil_div");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SSAValue *visitConstantExpr(AffineConstantExpr expr) {
|
||||
Value *visitConstantExpr(AffineConstantExpr expr) {
|
||||
auto valueAttr =
|
||||
builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
|
||||
auto op =
|
||||
|
@ -85,7 +85,7 @@ public:
|
|||
return op->getResult();
|
||||
}
|
||||
|
||||
SSAValue *visitDimExpr(AffineDimExpr expr) {
|
||||
Value *visitDimExpr(AffineDimExpr expr) {
|
||||
assert(expr.getPosition() < applyOp.getNumOperands() &&
|
||||
"affine dim position out of range");
|
||||
// FIXME: this assumes a certain order of AffineApplyOp operands, the
|
||||
|
@ -93,7 +93,7 @@ public:
|
|||
return applyOp.getOperand(expr.getPosition());
|
||||
}
|
||||
|
||||
SSAValue *visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
Value *visitSymbolExpr(AffineSymbolExpr expr) {
|
||||
// FIXME: this assumes a certain order of AffineApplyOp operands, the
|
||||
// cleaner interface would be to separate them at the op level.
|
||||
assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() <
|
||||
|
@ -114,8 +114,8 @@ private:
|
|||
// Given an affine expression `expr` extracted from `op`, build the sequence of
|
||||
// primitive instructions that correspond to the affine expression in the
|
||||
// `builder`.
|
||||
static SSAValue *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
|
||||
AffineApplyOp *op) {
|
||||
static mlir::Value *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
|
||||
AffineApplyOp *op) {
|
||||
auto expander = AffineApplyExpander(builder, op);
|
||||
return expander.visit(expr);
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
|
|||
FuncBuilder builder(op->getOperation());
|
||||
auto affineMap = op->getAffineMap();
|
||||
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
|
||||
SSAValue *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
|
||||
Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
|
||||
if (!expanded)
|
||||
return true;
|
||||
op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded);
|
||||
|
|
|
@ -31,7 +31,6 @@
|
|||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Return true if this operation dereferences one or more memref's.
|
||||
|
@ -61,13 +60,12 @@ static bool isMemRefDereferencingOp(const Operation &op) {
|
|||
// extra operands, note that 'indexRemap' would just be applied to the existing
|
||||
// indices (%i, %j).
|
||||
//
|
||||
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
|
||||
// TODO(mlir-team): extend this for Value/ CFGFunctions. Can also be easily
|
||||
// extended to add additional indices at any position.
|
||||
bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
||||
MLValue *newMemRef,
|
||||
ArrayRef<SSAValue *> extraIndices,
|
||||
bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
||||
ArrayRef<Value *> extraIndices,
|
||||
AffineMap indexRemap,
|
||||
ArrayRef<SSAValue *> extraOperands,
|
||||
ArrayRef<Value *> extraOperands,
|
||||
const Statement *domStmtFilter) {
|
||||
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
|
@ -128,16 +126,15 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
// operation.
|
||||
assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
|
||||
"single result op's expected to generate these indices");
|
||||
assert((cast<MLValue>(extraIndex)->isValidDim() ||
|
||||
cast<MLValue>(extraIndex)->isValidSymbol()) &&
|
||||
assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
|
||||
"invalid memory op index");
|
||||
state.operands.push_back(cast<MLValue>(extraIndex));
|
||||
state.operands.push_back(extraIndex);
|
||||
}
|
||||
|
||||
// Construct new indices as a remap of the old ones if a remapping has been
|
||||
// provided. The indices of a memref come right after it, i.e.,
|
||||
// at position memRefOperandPos + 1.
|
||||
SmallVector<SSAValue *, 4> remapOperands;
|
||||
SmallVector<Value *, 4> remapOperands;
|
||||
remapOperands.reserve(oldMemRefRank + extraOperands.size());
|
||||
remapOperands.insert(remapOperands.end(), extraOperands.begin(),
|
||||
extraOperands.end());
|
||||
|
@ -149,11 +146,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
remapOperands);
|
||||
// Remapped indices.
|
||||
for (auto *index : remapOp->getOperation()->getResults())
|
||||
state.operands.push_back(cast<MLValue>(index));
|
||||
state.operands.push_back(index);
|
||||
} else {
|
||||
// No remapping specified.
|
||||
for (auto *index : remapOperands)
|
||||
state.operands.push_back(cast<MLValue>(index));
|
||||
state.operands.push_back(index);
|
||||
}
|
||||
|
||||
// Insert the remaining operands unmodified.
|
||||
|
@ -191,9 +188,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
|
|||
// composed AffineApplyOp are returned in output parameter 'results'.
|
||||
OperationStmt *
|
||||
mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
||||
ArrayRef<MLValue *> operands,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<OperationStmt *> affineApplyOps,
|
||||
SmallVectorImpl<SSAValue *> *results) {
|
||||
SmallVectorImpl<Value *> *results) {
|
||||
// Create identity map with same number of dimensions as number of operands.
|
||||
auto map = builder->getMultiDimIdentityMap(operands.size());
|
||||
// Initialize AffineValueMap with identity map.
|
||||
|
@ -208,7 +205,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
// Compose affine maps from all ancestor AffineApplyOps.
|
||||
// Create new AffineApplyOp from 'valueMap'.
|
||||
unsigned numOperands = valueMap.getNumOperands();
|
||||
SmallVector<SSAValue *, 4> outOperands(numOperands);
|
||||
SmallVector<Value *, 4> outOperands(numOperands);
|
||||
for (unsigned i = 0; i < numOperands; ++i) {
|
||||
outOperands[i] = valueMap.getOperand(i);
|
||||
}
|
||||
|
@ -252,7 +249,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
/// otherwise.
|
||||
OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
|
||||
// Collect all operands that are results of affine apply ops.
|
||||
SmallVector<MLValue *, 4> subOperands;
|
||||
SmallVector<Value *, 4> subOperands;
|
||||
subOperands.reserve(opStmt->getNumOperands());
|
||||
for (auto *operand : opStmt->getOperands()) {
|
||||
auto *defStmt = operand->getDefiningStmt();
|
||||
|
@ -285,7 +282,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
|
|||
return nullptr;
|
||||
|
||||
FuncBuilder builder(opStmt);
|
||||
SmallVector<SSAValue *, 4> results;
|
||||
SmallVector<Value *, 4> results;
|
||||
auto *affineApplyStmt = createComposedAffineApplyOp(
|
||||
&builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
|
||||
assert(results.size() == subOperands.size() &&
|
||||
|
@ -295,7 +292,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
|
|||
// affine apply op above instead of existing ones (subOperands). So, they
|
||||
// differ from opStmt's operands only for those operands in 'subOperands', for
|
||||
// which they will be replaced by the corresponding one from 'results'.
|
||||
SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
|
||||
SmallVector<Value *, 4> newOperands(opStmt->getOperands());
|
||||
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
|
||||
// Replace the subOperands from among the new operands.
|
||||
unsigned j, f;
|
||||
|
@ -304,7 +301,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
|
|||
break;
|
||||
}
|
||||
if (j < subOperands.size()) {
|
||||
newOperands[i] = cast<MLValue>(results[j]);
|
||||
newOperands[i] = results[j];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,7 +323,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
// into any uses which are AffineApplyOps.
|
||||
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
|
||||
++resultIndex) {
|
||||
const MLValue *result = opStmt->getResult(resultIndex);
|
||||
const Value *result = opStmt->getResult(resultIndex);
|
||||
for (auto it = result->use_begin(); it != result->use_end();) {
|
||||
StmtOperand &use = *(it++);
|
||||
auto *useStmt = use.getOwner();
|
||||
|
@ -347,7 +344,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
|
||||
// Create new AffineApplyOp from 'valueMap'.
|
||||
unsigned numOperands = valueMap.getNumOperands();
|
||||
SmallVector<SSAValue *, 4> operands(numOperands);
|
||||
SmallVector<Value *, 4> operands(numOperands);
|
||||
for (unsigned i = 0; i < numOperands; ++i) {
|
||||
operands[i] = valueMap.getOperand(i);
|
||||
}
|
||||
|
|
|
@ -27,8 +27,6 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLValue.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
|
@ -740,8 +738,8 @@ struct VectorizationState {
|
|||
DenseSet<OperationStmt *> vectorizedSet;
|
||||
// Map of old scalar OperationStmt to new vectorized OperationStmt.
|
||||
DenseMap<OperationStmt *, OperationStmt *> vectorizationMap;
|
||||
// Map of old scalar MLValue to new vectorized MLValue.
|
||||
DenseMap<const MLValue *, MLValue *> replacementMap;
|
||||
// Map of old scalar Value to new vectorized Value.
|
||||
DenseMap<const Value *, Value *> replacementMap;
|
||||
// The strategy drives which loop to vectorize by which amount.
|
||||
const VectorizationStrategy *strategy;
|
||||
// Use-def roots. These represent the starting points for the worklist in the
|
||||
|
@ -761,7 +759,7 @@ struct VectorizationState {
|
|||
void registerTerminator(OperationStmt *stmt);
|
||||
|
||||
private:
|
||||
void registerReplacement(const SSAValue *key, SSAValue *value);
|
||||
void registerReplacement(const Value *key, Value *value);
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
@ -802,12 +800,9 @@ void VectorizationState::finishVectorizationPattern() {
|
|||
}
|
||||
}
|
||||
|
||||
void VectorizationState::registerReplacement(const SSAValue *key,
|
||||
SSAValue *value) {
|
||||
assert(replacementMap.count(cast<MLValue>(key)) == 0 &&
|
||||
"replacement already registered");
|
||||
replacementMap.insert(
|
||||
std::make_pair(cast<MLValue>(key), cast<MLValue>(value)));
|
||||
void VectorizationState::registerReplacement(const Value *key, Value *value) {
|
||||
assert(replacementMap.count(key) == 0 && "replacement already registered");
|
||||
replacementMap.insert(std::make_pair(key, value));
|
||||
}
|
||||
|
||||
////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
|
||||
|
@ -825,7 +820,7 @@ void VectorizationState::registerReplacement(const SSAValue *key,
|
|||
/// Such special cases force us to delay the vectorization of the stores
|
||||
/// until the last step. Here we merely register the store operation.
|
||||
template <typename LoadOrStoreOpPointer>
|
||||
static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
|
||||
static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
|
||||
VectorizationState *state) {
|
||||
auto memRefType =
|
||||
memoryOp->getMemRef()->getType().template cast<MemRefType>();
|
||||
|
@ -850,8 +845,7 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
|
|||
MLFuncBuilder b(opStmt);
|
||||
auto transfer = b.create<VectorTransferReadOp>(
|
||||
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
|
||||
map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()),
|
||||
permutationMap);
|
||||
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
|
||||
state->registerReplacement(opStmt,
|
||||
cast<OperationStmt>(transfer->getOperation()));
|
||||
} else {
|
||||
|
@ -970,8 +964,8 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
|
|||
/// element type.
|
||||
/// If `type` is not a valid vector type or if the scalar constant is not a
|
||||
/// valid vector element type, returns nullptr.
|
||||
static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
||||
Type type) {
|
||||
static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
||||
Type type) {
|
||||
if (!type || !type.isa<VectorType>() ||
|
||||
!VectorType::isValidElementType(constant.getType())) {
|
||||
return nullptr;
|
||||
|
@ -988,7 +982,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
|||
{make_pair(Identifier::get("value", b.getContext()), attr)});
|
||||
|
||||
auto *splat = cast<OperationStmt>(b.createOperation(state));
|
||||
return cast<MLValue>(splat->getResult(0));
|
||||
return splat->getResult(0);
|
||||
}
|
||||
|
||||
/// Returns a uniqu'ed VectorType.
|
||||
|
@ -996,7 +990,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
|||
/// vectorizedSet, just returns the type of `v`.
|
||||
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
|
||||
/// and of elemental type the type of `v`.
|
||||
static Type getVectorType(SSAValue *v, const VectorizationState &state) {
|
||||
static Type getVectorType(Value *v, const VectorizationState &state) {
|
||||
if (!VectorType::isValidElementType(v->getType())) {
|
||||
return Type();
|
||||
}
|
||||
|
@ -1028,23 +1022,23 @@ static Type getVectorType(SSAValue *v, const VectorizationState &state) {
|
|||
/// vectorization is possible with the above logic. Returns nullptr otherwise.
|
||||
///
|
||||
/// TODO(ntv): handle more complex cases.
|
||||
static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
|
||||
VectorizationState *state) {
|
||||
static Value *vectorizeOperand(Value *operand, Statement *stmt,
|
||||
VectorizationState *state) {
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
|
||||
LLVM_DEBUG(operand->print(dbgs()));
|
||||
auto *definingStatement = cast<OperationStmt>(operand->getDefiningStmt());
|
||||
// 1. If this value has already been vectorized this round, we are done.
|
||||
if (state->vectorizedSet.count(definingStatement) > 0) {
|
||||
LLVM_DEBUG(dbgs() << " -> already vector operand");
|
||||
return cast<MLValue>(operand);
|
||||
return operand;
|
||||
}
|
||||
// 1.b. Delayed on-demand replacement of a use.
|
||||
// Note that we cannot just call replaceAllUsesWith because it may result
|
||||
// in ops with mixed types, for ops whose operands have not all yet
|
||||
// been vectorized. This would be invalid IR.
|
||||
auto it = state->replacementMap.find(cast<MLValue>(operand));
|
||||
auto it = state->replacementMap.find(operand);
|
||||
if (it != state->replacementMap.end()) {
|
||||
auto *res = cast<MLValue>(it->second);
|
||||
auto *res = it->second;
|
||||
LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
|
||||
LLVM_DEBUG(res->print(dbgs()));
|
||||
return res;
|
||||
|
@ -1089,7 +1083,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
|
|||
auto *memRef = store->getMemRef();
|
||||
auto *value = store->getValueToStore();
|
||||
auto *vectorValue = vectorizeOperand(value, opStmt, state);
|
||||
auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
|
||||
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
|
||||
MLFuncBuilder b(opStmt);
|
||||
auto permutationMap =
|
||||
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
|
||||
|
@ -1104,14 +1098,14 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
|
|||
return res;
|
||||
}
|
||||
|
||||
auto types = map([state](SSAValue *v) { return getVectorType(v, *state); },
|
||||
auto types = map([state](Value *v) { return getVectorType(v, *state); },
|
||||
opStmt->getResults());
|
||||
auto vectorizeOneOperand = [opStmt, state](SSAValue *op) -> SSAValue * {
|
||||
auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
|
||||
return vectorizeOperand(op, opStmt, state);
|
||||
};
|
||||
auto operands = map(vectorizeOneOperand, opStmt->getOperands());
|
||||
// Check whether a single operand is null. If so, vectorization failed.
|
||||
bool success = llvm::all_of(operands, [](SSAValue *op) { return op; });
|
||||
bool success = llvm::all_of(operands, [](Value *op) { return op; });
|
||||
if (!success) {
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
|
||||
return nullptr;
|
||||
|
@ -1207,7 +1201,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
continue;
|
||||
}
|
||||
MLFuncBuilder builder(loop); // builder to insert in place of loop
|
||||
DenseMap<const MLValue *, MLValue *> nomap;
|
||||
DenseMap<const Value *, Value *> nomap;
|
||||
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
|
||||
auto fail = doVectorize(m, &state);
|
||||
/// Sets up error handling for this root loop. This is how the root match
|
||||
|
@ -1229,8 +1223,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
// Form the root operationsthat have been set in the replacementMap.
|
||||
// For now, these roots are the loads for which vector_transfer_read
|
||||
// operations have been inserted.
|
||||
auto getDefiningOperation = [](const MLValue *val) {
|
||||
return const_cast<MLValue *>(val)->getDefiningOperation();
|
||||
auto getDefiningOperation = [](const Value *val) {
|
||||
return const_cast<Value *>(val)->getDefiningOperation();
|
||||
};
|
||||
using ReferenceTy = decltype(*(state.replacementMap.begin()));
|
||||
auto getKey = [](ReferenceTy it) { return it.first; };
|
||||
|
|
|
@ -288,10 +288,10 @@ void OpEmitter::emitAttrGetters() {
|
|||
}
|
||||
|
||||
void OpEmitter::emitNamedOperands() {
|
||||
const auto operandMethods = R"( SSAValue *{0}() {
|
||||
const auto operandMethods = R"( Value *{0}() {
|
||||
return this->getOperation()->getOperand({1});
|
||||
}
|
||||
const SSAValue *{0}() const {
|
||||
const Value *{0}() const {
|
||||
return this->getOperation()->getOperand({1});
|
||||
}
|
||||
)";
|
||||
|
@ -329,7 +329,7 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Emit parameters for all operands
|
||||
for (const auto &pair : operands)
|
||||
os << ", SSAValue* " << pair.first;
|
||||
os << ", Value* " << pair.first;
|
||||
|
||||
// Emit parameters for all attributes
|
||||
// TODO(antiagainst): Support default initializer for attributes
|
||||
|
@ -369,7 +369,7 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Signature
|
||||
os << " static void build(Builder* builder, OperationState* result, "
|
||||
<< "ArrayRef<Type> resultTypes, ArrayRef<SSAValue*> args, "
|
||||
<< "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
|
||||
"ArrayRef<NamedAttribute> attributes) {\n";
|
||||
|
||||
// Result types
|
||||
|
|
Loading…
Reference in New Issue