Merge SSAValue, CFGValue, and MLValue together into a single Value class, which
is the new base of the SSA value hierarchy. This CL also standardizes all the
nomenclature and comments to use 'Value' where appropriate. This also eliminates
a large number of cast<MLValue>(x)'s, which is very soothing.

This is step 11/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227064624
Chris Lattner 2018-12-27 14:35:10 -08:00 committed by jpienaar
parent 776b035646
commit 3f190312f8
61 changed files with 959 additions and 1162 deletions
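To make the call-site impact concrete, here is a minimal sketch (not code from this CL; the helper is hypothetical and the include paths are assumed) of an analysis routine that previously had to cast operands from SSAValue * down to MLValue * before handing them to FlatAffineConstraints::findId, and that now passes the Value directly:

#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Statements.h"

using namespace mlir;

// Hypothetical helper: count how many operands of 'stmt' are tracked as
// identifiers (columns) of the constraint system 'cst'.
static unsigned countTrackedOperands(const FlatAffineConstraints &cst,
                                     const OperationStmt *stmt) {
  unsigned count = 0;
  for (const Value *operand : stmt->getOperands()) {
    unsigned pos;
    // Before this CL: cst.findId(*cast<MLValue>(operand), &pos)
    if (cst.findId(*operand, &pos))
      ++count;
  }
  return count;
}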


@ -37,9 +37,9 @@ class ForStmt;
class MLIRContext;
class FlatAffineConstraints;
class IntegerSet;
class MLValue;
class OperationStmt;
class Statement;
class Value;
/// Simplify an affine expression through flattening and some amount of
/// simple analysis. This has complexity linear in the number of nodes in
@ -78,7 +78,7 @@ AffineExpr composeWithUnboundedMap(AffineExpr e, AffineMap g);
/// 'affineApplyOps', which are reachable via a search starting from 'operands',
/// and ending at operands which are not defined by AffineApplyOps.
void getReachableAffineApplyOps(
llvm::ArrayRef<MLValue *> operands,
llvm::ArrayRef<Value *> operands,
llvm::SmallVectorImpl<OperationStmt *> &affineApplyOps);
/// Forward substitutes into 'valueMap' all AffineApplyOps reachable from the
@ -122,9 +122,9 @@ bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain);
struct MemRefAccess {
const MLValue *memref;
const Value *memref;
const OperationStmt *opStmt;
llvm::SmallVector<MLValue *, 4> indices;
llvm::SmallVector<Value *, 4> indices;
// Populates 'accessMap' with composition of AffineApplyOps reachable from
// 'indices'.
void getAccessMap(AffineValueMap *accessMap) const;


@ -25,7 +25,6 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
@ -37,7 +36,7 @@ class AffineMap;
class ForStmt;
class IntegerSet;
class MLIRContext;
class MLValue;
class Value;
class HyperRectangularSet;
/// A mutable affine map. Its affine expressions are however unique.
@ -132,7 +131,7 @@ public:
AffineValueMap(const AffineApplyOp &op);
AffineValueMap(const AffineBound &bound);
AffineValueMap(AffineMap map);
AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands);
AffineValueMap(AffineMap map, ArrayRef<Value *> operands);
~AffineValueMap();
@ -155,13 +154,13 @@ public:
// substitutions).
// Resets this AffineValueMap with 'map' and 'operands'.
void reset(AffineMap map, ArrayRef<MLValue *> operands);
void reset(AffineMap map, ArrayRef<Value *> operands);
/// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise.
inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the idx^th result depends on 'value', false otherwise.
bool isFunctionOf(unsigned idx, MLValue *value) const;
bool isFunctionOf(unsigned idx, Value *value) const;
/// Return true if the result at 'idx' is a constant, false
/// otherwise.
@ -175,8 +174,8 @@ public:
inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
inline unsigned getNumResults() const { return map.getNumResults(); }
SSAValue *getOperand(unsigned i) const;
ArrayRef<MLValue *> getOperands() const;
Value *getOperand(unsigned i) const;
ArrayRef<Value *> getOperands() const;
AffineMap getAffineMap() const;
private:
@ -187,9 +186,9 @@ private:
// TODO: make these trailing objects?
/// The SSA operands binding to the dim's and symbols of 'map'.
SmallVector<MLValue *, 4> operands;
SmallVector<Value *, 4> operands;
/// The SSA results binding to the results of 'map'.
SmallVector<MLValue *, 4> results;
SmallVector<Value *, 4> results;
};
/// An IntegerValueSet is an integer set plus its operands.
@ -218,7 +217,7 @@ private:
// 'AffineCondition'.
MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'.
SmallVector<MLValue *, 4> operands;
SmallVector<Value *, 4> operands;
};
/// A flat list of affine equalities and inequalities in the form.
@ -250,7 +249,7 @@ public:
unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims = 0,
unsigned numSymbols = 0, unsigned numLocals = 0,
ArrayRef<Optional<MLValue *>> idArgs = {})
ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numReservedCols), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
@ -269,7 +268,7 @@ public:
/// dimensions and symbols.
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0,
ArrayRef<Optional<MLValue *>> idArgs = {})
ArrayRef<Optional<Value *>> idArgs = {})
: numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
@ -309,10 +308,10 @@ public:
// Clears any existing data and reserves memory for the specified constraints.
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
void reset(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0, ArrayRef<MLValue *> idArgs = {});
unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
/// Appends constraints from 'other' into this. This is equivalent to an
/// intersection with no simplification of any sort attempted.
@ -393,7 +392,7 @@ public:
// Returns AffineMap::Null on error (i.e. if coefficient is zero or does
// not divide other coefficients in the equality constraint).
// TODO(andydavis) Remove 'nonZeroDimIds' and 'nonZeroSymbolIds' from this
// API when we can manage the mapping of MLValues and ids in the constraint
// API when we can manage the mapping of Values and ids in the constraint
// system.
AffineMap toAffineMapFromEq(unsigned idx, unsigned pos, MLIRContext *context,
SmallVectorImpl<unsigned> *nonZeroDimIds,
@ -413,10 +412,10 @@ public:
void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
/// Adds constraints (lower and upper bounds) for the specified 'for'
/// statement's MLValue using IR information stored in its bound maps. The
/// right identifier is first looked up using forStmt's MLValue. Returns
/// statement's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forStmt's Value. Returns
/// false for the yet unimplemented/unsupported cases, and true if the
/// information is successfully added. Asserts if the MLValue corresponding to
/// information is successfully added. Asserts if the Value corresponding to
/// the 'for' statement isn't found in the constraint system. Any new
/// identifiers that are found in the bound operands of the 'for' statement
/// are added as trailing identifiers (either dimensional or symbolic
@ -435,28 +434,28 @@ public:
/// Sets the identifier at the specified position to a constant.
void setIdToConstant(unsigned pos, int64_t val);
/// Sets the identifier corresponding to the specified MLValue id to a
/// Sets the identifier corresponding to the specified Value id to a
/// constant. Asserts if the 'id' is not found.
void setIdToConstant(const MLValue &id, int64_t val);
void setIdToConstant(const Value &id, int64_t val);
/// Looks up the identifier with the specified MLValue. Returns false if not
/// Looks up the identifier with the specified Value. Returns false if not
/// found, true if found. pos is set to the (column) position of the
/// identifier.
bool findId(const MLValue &id, unsigned *pos) const;
bool findId(const Value &id, unsigned *pos) const;
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. 'id' is the MLValue corresponding to the
// the kind of identifier. 'id' is the Value corresponding to the
// identifier that can optionally be provided.
void addDimId(unsigned pos, MLValue *id = nullptr);
void addSymbolId(unsigned pos, MLValue *id = nullptr);
void addDimId(unsigned pos, Value *id = nullptr);
void addSymbolId(unsigned pos, Value *id = nullptr);
void addLocalId(unsigned pos);
void addId(IdKind kind, unsigned pos, MLValue *id = nullptr);
void addId(IdKind kind, unsigned pos, Value *id = nullptr);
/// Composes the affine value map with this FlatAffineConstrains, adding the
/// results of the map as dimensions at the front [0, vMap->getNumResults())
/// and with the dimensions set to the equalities specified by the value map.
/// Returns false if the composition fails (when vMap is a semi-affine map).
/// The vMap's operand MLValue's are used to look up the right positions in
/// The vMap's operand Value's are used to look up the right positions in
/// the FlatAffineConstraints with which to associate. The dimensional and
/// symbolic operands of vMap should match 1:1 (in the same order) with those
/// of this constraint system, but the latter could have additional trailing
@ -471,8 +470,8 @@ public:
void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
/// Projects out the identifier that is associated with MLValue *.
void projectOut(MLValue *id);
/// Projects out the identifier that is associated with Value *.
void projectOut(Value *id);
void removeId(IdKind idKind, unsigned pos);
void removeId(unsigned pos);
@ -510,24 +509,24 @@ public:
return numIds - numDims - numSymbols;
}
inline ArrayRef<Optional<MLValue *>> getIds() const {
inline ArrayRef<Optional<Value *>> getIds() const {
return {ids.data(), ids.size()};
}
/// Returns the MLValue's associated with the identifiers. Asserts if
/// no MLValue was associated with an identifier.
inline void getIdValues(SmallVectorImpl<MLValue *> *values) const {
/// Returns the Value's associated with the identifiers. Asserts if
/// no Value was associated with an identifier.
inline void getIdValues(SmallVectorImpl<Value *> *values) const {
values->clear();
values->reserve(numIds);
for (unsigned i = 0; i < numIds; i++) {
assert(ids[i].hasValue() && "identifier's MLValue not set");
assert(ids[i].hasValue() && "identifier's Value not set");
values->push_back(ids[i].getValue());
}
}
/// Returns the MLValue associated with the pos^th identifier. Asserts if
/// no MLValue identifier was associated.
inline MLValue *getIdValue(unsigned pos) const {
/// Returns the Value associated with the pos^th identifier. Asserts if
/// no Value identifier was associated.
inline Value *getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's ML Value not set");
return ids[pos].getValue();
}
@ -630,11 +629,11 @@ private:
/// analysis).
unsigned numSymbols;
/// MLValues corresponding to the (column) identifiers of this constraint
/// Values corresponding to the (column) identifiers of this constraint
/// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any MLValue are to be
/// Temporary ones or those that aren't associated to any Value are to be
/// set to None.
SmallVector<Optional<MLValue *>, 8> ids;
SmallVector<Optional<Value *>, 8> ids;
};
} // end namespace mlir.
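A minimal usage sketch of the renamed accessors above (not part of the diff; it assumes every identifier in the constraint system has a Value attached, since getIdValues asserts otherwise):

// Return the Value bound to the first column, or nullptr if there are no ids.
static Value *firstIdValue(const FlatAffineConstraints &cst) {
  SmallVector<Value *, 8> ids;
  cst.getIdValues(&ids);
  return ids.empty() ? nullptr : cst.getIdValue(0);
}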


@ -58,10 +58,10 @@ public:
}
/// Return true if value A properly dominates instruction B.
bool properlyDominates(const SSAValue *a, const Instruction *b);
bool properlyDominates(const Value *a, const Instruction *b);
/// Return true if instruction A dominates instruction B.
bool dominates(const SSAValue *a, const Instruction *b) {
bool dominates(const Value *a, const Instruction *b) {
return (Instruction *)a->getDefiningInst() == b || properlyDominates(a, b);
}


@ -38,10 +38,10 @@ class AffineCondition;
class AffineMap;
class IntegerSet;
class MLIRContext;
class MLValue;
class MutableIntegerSet;
class FlatAffineConstraints;
class HyperRectangleList;
class Value;
/// A list of affine bounds.
// Not using a MutableAffineMap here since numSymbols is the same as the
@ -152,8 +152,8 @@ private:
// expressions.
std::vector<AffineBoundExprList> upperBounds;
Optional<SmallVector<MLValue *, 8>> dims = None;
Optional<SmallVector<MLValue *, 4>> symbols = None;
Optional<SmallVector<Value *, 8>> dims = None;
Optional<SmallVector<Value *, 4>> symbols = None;
/// Number of real dimensions.
unsigned numDims;


@ -23,7 +23,6 @@
#define MLIR_ANALYSIS_LOOP_ANALYSIS_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
@ -33,8 +32,8 @@ class AffineExpr;
class AffineMap;
class ForStmt;
class MemRefType;
class MLValue;
class OperationStmt;
class Value;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
@ -66,7 +65,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
///
/// Returns false in cases with more than one AffineApplyOp, this is
/// conservative.
bool isAccessInvariant(const MLValue &iv, const MLValue &index);
bool isAccessInvariant(const Value &iv, const Value &index);
/// Given an induction variable `iv` of type ForStmt and `indices` of type
/// IndexType, returns the set of `indices` that are independent of `iv`.
@ -77,9 +76,8 @@ bool isAccessInvariant(const MLValue &iv, const MLValue &index);
///
/// Returns false in cases with more than one AffineApplyOp, this is
/// conservative.
llvm::DenseSet<const MLValue *, llvm::DenseMapInfo<const MLValue *>>
getInvariantAccesses(const MLValue &iv,
llvm::ArrayRef<const MLValue *> indices);
llvm::DenseSet<const Value *, llvm::DenseMapInfo<const Value *>>
getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
/// Checks whether the loop is structurally vectorizable; i.e.:
/// 1. the loop has proper dependence semantics (parallel, reduction, etc);


@ -127,10 +127,10 @@ void getBackwardSlice(
/// **includes** the original statement.
///
/// This allows building a slice (i.e. multi-root DAG where everything
/// that is reachable from an SSAValue in forward and backward direction is
/// that is reachable from a Value in forward and backward direction is
/// contained in the slice).
/// This is the abstraction we need to materialize all the instructions for
/// supervectorization without worrying about orderings and SSAValue
/// supervectorization without worrying about orderings and Value
/// replacements.
///
/// Example starting from any node


@ -34,10 +34,10 @@ namespace mlir {
class FlatAffineConstraints;
class ForStmt;
class MLValue;
class MemRefAccess;
class OperationStmt;
class Statement;
class Value;
/// Returns true if statement 'a' dominates statement b.
bool dominates(const Statement &a, const Statement &b);
@ -92,7 +92,7 @@ struct MemRefRegion {
unsigned getRank() const;
/// Memref that this region corresponds to.
MLValue *memref;
Value *memref;
private:
/// Read or write.


@ -408,8 +408,8 @@ public:
}
// Creates a for statement. When step is not specified, it is set to 1.
ForStmt *createFor(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
ForStmt *createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step = 1);
// Creates a for statement with known (constant) lower and upper bounds.
@ -417,7 +417,7 @@ public:
ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
/// Creates if statement.
IfStmt *createIf(Location location, ArrayRef<MLValue *> operands,
IfStmt *createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set);
private:


@ -30,7 +30,6 @@
namespace mlir {
class Builder;
class MLValue;
class BuiltinDialect : public Dialect {
public:
@ -57,7 +56,7 @@ class AffineApplyOp
public:
/// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<SSAValue *> operands);
ArrayRef<Value *> operands);
/// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const {
@ -101,7 +100,7 @@ public:
static StringRef getOperationName() { return "br"; }
static void build(Builder *builder, OperationState *result, BasicBlock *dest,
ArrayRef<SSAValue *> operands = {});
ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
@ -144,10 +143,9 @@ class CondBranchOp : public Op<CondBranchOp, OpTrait::AtLeastNOperands<1>::Impl,
public:
static StringRef getOperationName() { return "cond_br"; }
static void build(Builder *builder, OperationState *result,
SSAValue *condition, BasicBlock *trueDest,
ArrayRef<SSAValue *> trueOperands, BasicBlock *falseDest,
ArrayRef<SSAValue *> falseOperands);
static void build(Builder *builder, OperationState *result, Value *condition,
BasicBlock *trueDest, ArrayRef<Value *> trueOperands,
BasicBlock *falseDest, ArrayRef<Value *> falseOperands);
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
@ -155,8 +153,8 @@ public:
bool verify() const;
// The condition operand is the first operand in the list.
SSAValue *getCondition() { return getOperand(0); }
const SSAValue *getCondition() const { return getOperand(0); }
Value *getCondition() { return getOperand(0); }
const Value *getCondition() const { return getOperand(0); }
/// Return the destination if the condition is true.
BasicBlock *getTrueDest() const;
@ -165,14 +163,14 @@ public:
BasicBlock *getFalseDest() const;
// Accessors for operands to the 'true' destination.
SSAValue *getTrueOperand(unsigned idx) {
Value *getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
const SSAValue *getTrueOperand(unsigned idx) const {
const Value *getTrueOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getTrueOperand(idx);
}
void setTrueOperand(unsigned idx, SSAValue *value) {
void setTrueOperand(unsigned idx, Value *value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
@ -199,14 +197,14 @@ public:
void eraseTrueOperand(unsigned index);
// Accessors for operands to the 'false' destination.
SSAValue *getFalseOperand(unsigned idx) {
Value *getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
const SSAValue *getFalseOperand(unsigned idx) const {
const Value *getFalseOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
}
void setFalseOperand(unsigned idx, SSAValue *value) {
void setFalseOperand(unsigned idx, Value *value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
@ -361,7 +359,7 @@ public:
static StringRef getOperationName() { return "return"; }
static void build(Builder *builder, OperationState *result,
ArrayRef<SSAValue *> results = {});
ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
@ -380,7 +378,7 @@ void printDimAndSymbolList(Operation::const_operand_iterator begin,
// Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<SSAValue *, 4> &operands,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
} // end namespace mlir
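A minimal sketch (not part of the diff) of the renamed cond_br accessors at a call site; 'op' is assumed to be a valid OpPointer<CondBranchOp>, e.g. obtained from a dyn_cast:

// Returns true if the cond_br branches on the given value.
static bool isConditionalOn(OpPointer<CondBranchOp> op, const Value *v) {
  return op->getCondition() == v; // getCondition() now returns Value *
}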


@ -27,7 +27,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StmtBlock.h"
#include "mlir/IR/Types.h"


@ -1,133 +0,0 @@
//===- MLValue.h - MLValue base class and SSA type decls ------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines SSA manipulation implementations for ML functions.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLVALUE_H
#define MLIR_IR_MLVALUE_H
#include "mlir/IR/SSAValue.h"
namespace mlir {
class ForStmt;
class MLValue;
using MLFunction = Function;
class Statement;
class StmtBlock;
/// This enum contains all of the SSA value kinds that are valid in an ML
/// function. This should be kept as a proper subtype of SSAValueKind,
/// including having all of the values of the enumerators align.
enum class MLValueKind {
BlockArgument = (int)SSAValueKind::BlockArgument,
StmtResult = (int)SSAValueKind::StmtResult,
ForStmt = (int)SSAValueKind::ForStmt,
};
/// The operand of ML function statement contains an MLValue.
using StmtOperand = IROperandImpl<MLValue, Statement>;
/// MLValue is the base class for SSA values in ML functions.
class MLValue : public SSAValueImpl<StmtOperand, Statement, MLValueKind> {
public:
/// Returns true if the given MLValue can be used as a dimension id.
bool isValidDim() const;
/// Returns true if the given MLValue can be used as a symbol.
bool isValidSymbol() const;
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
return true;
}
}
/// Return the function that this MLValue is defined in.
MLFunction *getFunction();
/// Return the function that this MLValue is defined in.
const MLFunction *getFunction() const {
return const_cast<MLValue *>(this)->getFunction();
}
protected:
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
};
/// Block arguments are ML Values.
class BlockArgument : public MLValue {
public:
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::BlockArgument;
}
/// Return the function that this argument is defined in.
MLFunction *getFunction();
const MLFunction *getFunction() const {
return const_cast<BlockArgument *>(this)->getFunction();
}
StmtBlock *getOwner() { return owner; }
const StmtBlock *getOwner() const { return owner; }
private:
friend class StmtBlock; // For access to private constructor.
BlockArgument(Type type, StmtBlock *owner)
: MLValue(MLValueKind::BlockArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
StmtBlock *const owner;
};
/// This is a value defined by a result of an operation instruction.
class StmtResult : public MLValue {
public:
StmtResult(Type type, OperationStmt *owner)
: MLValue(MLValueKind::StmtResult, type), owner(owner) {}
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::StmtResult;
}
OperationStmt *getOwner() { return owner; }
const OperationStmt *getOwner() const { return owner; }
/// Returns the number of this result.
unsigned getResultNumber() const;
private:
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
OperationStmt *const owner;
};
// TODO(clattner) clean all this up.
using CFGValue = MLValue;
using BBArgument = BlockArgument;
using InstResult = StmtResult;
} // namespace mlir
#endif
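With this header gone, the block argument / statement result / induction variable kinds hang directly off Value. A reconstruction sketch (not the verbatim replacement header), inferred from the deleted SSAValueKind/MLValueKind enums above and the Value::Kind::ForStmt reference later in this diff:

class Value : public IRObjectWithUseList {
public:
  enum class Kind {
    BlockArgument, // block argument
    StmtResult,    // statement / instruction result
    ForStmt,       // 'for' statement induction variable
  };
  Kind getKind() const;
  Type getType() const;
  // isValidDim(), isValidSymbol() and the getFunction() accessors that used
  // to live on MLValue are expected to be available on every Value.
};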


@ -27,8 +27,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <type_traits>
namespace mlir {
@ -107,7 +107,7 @@ template <typename OpClass> struct op_matcher {
/// Entry point for matching a pattern over an SSAValue.
template <typename Pattern>
inline bool matchPattern(SSAValue *value, const Pattern &pattern) {
inline bool matchPattern(Value *value, const Pattern &pattern) {
// TODO: handle other cases
if (auto *op = value->getDefiningOperation())
return const_cast<Pattern &>(pattern).match(op);


@ -29,7 +29,7 @@
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Operation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Value.h"
#include <type_traits>
namespace mlir {
@ -78,11 +78,11 @@ public:
}
/// If the OpType operation includes the OneResult trait, then OpPointer can
/// be implicitly converted to an SSAValue*. This yields the value of the
/// be implicitly converted to a Value*. This yields the value of the
/// only result.
template <typename SFINAE = OpType>
operator typename std::enable_if<IsSingleResult<SFINAE>::value,
SSAValue *>::type() {
Value *>::type() {
return value.getResult();
}
@ -114,14 +114,14 @@ public:
}
/// If the OpType operation includes the OneResult trait, then OpPointer can
/// be implicitly converted to an const SSAValue*. This yields the value of
/// be implicitly converted to a const Value*. This yields the value of
/// the only result.
template <typename SFINAE = OpType>
operator typename std::enable_if<
std::is_convertible<
SFINAE *,
OpTrait::OneResult<typename SFINAE::ConcreteOpType> *>::value,
const SSAValue *>::type() const {
const Value *>::type() const {
return value.getResult();
}
@ -346,15 +346,13 @@ private:
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
const SSAValue *getOperand() const {
const Value *getOperand() const {
return this->getOperation()->getOperand(0);
}
SSAValue *getOperand() { return this->getOperation()->getOperand(0); }
Value *getOperand() { return this->getOperation()->getOperand(0); }
void setOperand(SSAValue *value) {
this->getOperation()->setOperand(0, value);
}
void setOperand(Value *value) { this->getOperation()->setOperand(0, value); }
static bool verifyTrait(const Operation *op) {
return impl::verifyOneOperand(op);
@ -371,15 +369,15 @@ public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
public:
const SSAValue *getOperand(unsigned i) const {
const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i);
}
SSAValue *getOperand(unsigned i) {
Value *getOperand(unsigned i) {
return this->getOperation()->getOperand(i);
}
void setOperand(unsigned i, SSAValue *value) {
void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value);
}
@ -402,15 +400,15 @@ public:
unsigned getNumOperands() const {
return this->getOperation()->getNumOperands();
}
const SSAValue *getOperand(unsigned i) const {
const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i);
}
SSAValue *getOperand(unsigned i) {
Value *getOperand(unsigned i) {
return this->getOperation()->getOperand(i);
}
void setOperand(unsigned i, SSAValue *value) {
void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value);
}
@ -453,15 +451,13 @@ public:
return this->getOperation()->getNumOperands();
}
const SSAValue *getOperand(unsigned i) const {
const Value *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i);
}
SSAValue *getOperand(unsigned i) {
return this->getOperation()->getOperand(i);
}
Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
void setOperand(unsigned i, SSAValue *value) {
void setOperand(unsigned i, Value *value) {
this->getOperation()->setOperand(i, value);
}
@ -503,17 +499,15 @@ public:
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
SSAValue *getResult() { return this->getOperation()->getResult(0); }
const SSAValue *getResult() const {
return this->getOperation()->getResult(0);
}
Value *getResult() { return this->getOperation()->getResult(0); }
const Value *getResult() const { return this->getOperation()->getResult(0); }
Type getType() const { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
void replaceAllUsesWith(SSAValue *newValue) {
void replaceAllUsesWith(Value *newValue) {
getResult()->replaceAllUsesWith(newValue);
}
@ -548,13 +542,11 @@ public:
public:
static unsigned getNumResults() { return N; }
const SSAValue *getResult(unsigned i) const {
const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i);
}
SSAValue *getResult(unsigned i) {
return this->getOperation()->getResult(i);
}
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
Type getType(unsigned i) const { return getResult(i)->getType(); }
@ -574,13 +566,11 @@ public:
template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> {
public:
const SSAValue *getResult(unsigned i) const {
const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i);
}
SSAValue *getResult(unsigned i) {
return this->getOperation()->getResult(i);
}
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
Type getType(unsigned i) const { return getResult(i)->getType(); }
@ -599,13 +589,13 @@ public:
return this->getOperation()->getNumResults();
}
const SSAValue *getResult(unsigned i) const {
const Value *getResult(unsigned i) const {
return this->getOperation()->getResult(i);
}
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
void setResult(unsigned i, SSAValue *value) {
void setResult(unsigned i, Value *value) {
this->getOperation()->setResult(i, value);
}
@ -762,10 +752,10 @@ public:
return this->getOperation()->setSuccessor(block, index);
}
void addSuccessorOperand(unsigned index, SSAValue *value) {
void addSuccessorOperand(unsigned index, Value *value) {
return this->getOperation()->addSuccessorOperand(index, value);
}
void addSuccessorOperands(unsigned index, ArrayRef<SSAValue *> values) {
void addSuccessorOperands(unsigned index, ArrayRef<Value *> values) {
return this->getOperation()->addSuccessorOperand(index, values);
}
};
@ -889,8 +879,8 @@ private:
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
void buildBinaryOp(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs);
void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
Value *rhs);
bool parseBinaryOp(OpAsmParser *parser, OperationState *result);
void printBinaryOp(const Operation *op, OpAsmPrinter *p);
} // namespace impl
@ -906,8 +896,8 @@ class BinaryOp
: public Op<ConcreteType, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
OpTrait::SameOperandsAndResultType, Traits...> {
public:
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs) {
static void build(Builder *builder, OperationState *result, Value *lhs,
Value *rhs) {
impl::buildBinaryOp(builder, result, lhs, rhs);
}
static bool parse(OpAsmParser *parser, OperationState *result) {
@ -926,7 +916,7 @@ protected:
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
void buildCastOp(Builder *builder, OperationState *result, Value *source,
Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(const Operation *op, OpAsmPrinter *p);
@ -942,7 +932,7 @@ template <typename ConcreteType, template <typename T> class... Traits>
class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect, Traits...> {
public:
static void build(Builder *builder, OperationState *result, SSAValue *source,
static void build(Builder *builder, OperationState *result, Value *source,
Type destType) {
impl::buildCastOp(builder, result, source, destType);
}


@ -48,7 +48,7 @@ public:
virtual raw_ostream &getStream() const = 0;
/// Print implementations for various things an operation contains.
virtual void printOperand(const SSAValue *value) = 0;
virtual void printOperand(const Value *value) = 0;
/// Print a comma separated list of operands.
template <typename ContainerType>
@ -95,7 +95,7 @@ private:
};
// Make the implementations convenient to use.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const SSAValue &value) {
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Value &value) {
p.printOperand(&value);
return p;
}
@ -119,7 +119,7 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) {
// even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type& version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, SSAValue &>::value &&
!std::is_convertible<T &, Value &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, AffineMap &>::value,
@ -264,9 +264,8 @@ public:
virtual bool parseOperand(OperandType &result) = 0;
/// Parse a single operation successor and its operand list.
virtual bool
parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands) = 0;
virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<Value *> &operands) = 0;
/// These are the supported delimiters around operand lists, used by
/// parseOperandList.
@ -311,13 +310,13 @@ public:
/// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure.
virtual bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) = 0;
SmallVectorImpl<Value *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success.
/// This method should be used when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<SSAValue *> &result) {
SmallVectorImpl<Value *> &result) {
for (auto elt : operands)
if (resolveOperand(elt, type, result))
return true;
@ -329,7 +328,7 @@ public:
/// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<SSAValue *> &result) {
SmallVectorImpl<Value *> &result) {
if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) +
" operands present, but expected " +


@ -64,21 +64,20 @@ public:
/// Return the number of operands this operation has.
unsigned getNumOperands() const;
SSAValue *getOperand(unsigned idx);
const SSAValue *getOperand(unsigned idx) const {
Value *getOperand(unsigned idx);
const Value *getOperand(unsigned idx) const {
return const_cast<Operation *>(this)->getOperand(idx);
}
void setOperand(unsigned idx, SSAValue *value);
void setOperand(unsigned idx, Value *value);
// Support non-const operand iteration.
using operand_iterator = OperandIterator<Operation, SSAValue>;
using operand_iterator = OperandIterator<Operation, Value>;
operand_iterator operand_begin();
operand_iterator operand_end();
llvm::iterator_range<operand_iterator> getOperands();
// Support const operand iteration.
using const_operand_iterator =
OperandIterator<const Operation, const SSAValue>;
using const_operand_iterator = OperandIterator<const Operation, const Value>;
const_operand_iterator operand_begin() const;
const_operand_iterator operand_end() const;
llvm::iterator_range<const_operand_iterator> getOperands() const;
@ -87,26 +86,25 @@ public:
unsigned getNumResults() const;
/// Return the indicated result.
SSAValue *getResult(unsigned idx);
const SSAValue *getResult(unsigned idx) const {
Value *getResult(unsigned idx);
const Value *getResult(unsigned idx) const {
return const_cast<Operation *>(this)->getResult(idx);
}
// Support non-const result iteration.
using result_iterator = ResultIterator<Operation, SSAValue>;
using result_iterator = ResultIterator<Operation, Value>;
result_iterator result_begin();
result_iterator result_end();
llvm::iterator_range<result_iterator> getResults();
// Support const result iteration.
using const_result_iterator = ResultIterator<const Operation, const SSAValue>;
using const_result_iterator = ResultIterator<const Operation, const Value>;
const_result_iterator result_begin() const;
const_result_iterator result_end() const;
llvm::iterator_range<const_result_iterator> getResults() const;
// Support for result type iteration.
using result_type_iterator =
ResultTypeIterator<const Operation, const SSAValue>;
using result_type_iterator = ResultTypeIterator<const Operation, const Value>;
result_type_iterator result_type_begin() const;
result_type_iterator result_type_end() const;
llvm::iterator_range<result_type_iterator> getResultTypes() const;


@ -40,8 +40,8 @@ class OpAsmPrinter;
class Pattern;
class RewritePattern;
class StmtBlock;
class SSAValue;
class Type;
class Value;
using BasicBlock = StmtBlock;
/// This is a vector that owns the patterns inside of it.
@ -203,7 +203,7 @@ struct OperationState {
MLIRContext *const context;
Location location;
OperationName name;
SmallVector<SSAValue *, 4> operands;
SmallVector<Value *, 4> operands;
/// Types of the results of this operation.
SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes;
@ -218,7 +218,7 @@ public:
: context(context), location(location), name(name) {}
OperationState(MLIRContext *context, Location location, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
ArrayRef<Value *> operands, ArrayRef<Type> types,
ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors = {})
: context(context), location(location), name(name, context),
@ -227,7 +227,7 @@ public:
attributes(attributes.begin(), attributes.end()),
successors(successors.begin(), successors.end()) {}
void addOperands(ArrayRef<SSAValue *> newOperands) {
void addOperands(ArrayRef<Value *> newOperands) {
assert(successors.empty() &&
"Non successor operands should be added first.");
operands.append(newOperands.begin(), newOperands.end());
@ -247,7 +247,7 @@ public:
attributes.push_back({name, attr});
}
void addSuccessor(StmtBlock *successor, ArrayRef<SSAValue *> succOperands) {
void addSuccessor(StmtBlock *successor, ArrayRef<Value *> succOperands) {
successors.push_back(successor);
// Insert a sentinel operand to mark a barrier between successor operands.
operands.push_back(nullptr);
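As a usage sketch for the Value-typed OperationState pieces above (not part of the diff; 'lhs'/'rhs' are hypothetical values and "foo.add" is a made-up op name):

// Fill in an operation state whose operands are plain Value pointers.
static void populateFooAdd(OperationState &state, Value *lhs, Value *rhs) {
  state.addOperands({lhs, rhs});         // now an ArrayRef<Value *>
  state.types.push_back(lhs->getType()); // one result of the operand type
}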


@ -222,8 +222,8 @@ public:
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those values are dead, this will
/// remove them as well.
void replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
ArrayRef<SSAValue *> valuesToRemoveIfDead = {});
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead = {});
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
@ -237,8 +237,7 @@ public:
/// The result values of the two ops must be the same types. This allows
/// specifying a list of ops that may be removed if dead.
template <typename OpTy, typename... Args>
void replaceOpWithNewOp(Operation *op,
ArrayRef<SSAValue *> valuesToRemoveIfDead,
void replaceOpWithNewOp(Operation *op, ArrayRef<Value *> valuesToRemoveIfDead,
Args... args) {
auto newOp = create<OpTy>(op->getLoc(), args...);
replaceOpWithResultsOfAnotherOp(op, newOp->getOperation(),
@ -254,7 +253,7 @@ public:
/// rewriter should remove if they are dead at this point.
///
void updatedRootInPlace(Operation *op,
ArrayRef<SSAValue *> valuesToRemoveIfDead = {});
ArrayRef<Value *> valuesToRemoveIfDead = {});
protected:
PatternRewriter(MLIRContext *context) : Builder(context) {}
@ -284,9 +283,8 @@ protected:
private:
/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void
replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
ArrayRef<SSAValue *> valuesToRemoveIfDead);
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
ArrayRef<Value *> valuesToRemoveIfDead);
};
//===----------------------------------------------------------------------===//


@ -1,154 +0,0 @@
//===- SSAValue.h - Base of the value hierarchy -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines generic SSAValue type and manipulation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SSAVALUE_H
#define MLIR_IR_SSAVALUE_H
#include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Function;
class OperationStmt;
class Operation;
class Statement;
using Instruction = Statement;
using OperationInst = OperationStmt;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
BlockArgument, // Block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};
/// This is the common base class for all values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
class SSAValue : public IRObjectWithUseList {
public:
~SSAValue() {}
SSAValueKind getKind() const { return typeAndKind.getInt(); }
Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
void replaceAllUsesWith(SSAValue *newValue) {
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// Return the function that this SSAValue is defined in.
Function *getFunction();
/// Return the function that this SSAValue is defined in.
const Function *getFunction() const {
return const_cast<SSAValue *>(this)->getFunction();
}
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
OperationInst *getDefiningInst();
const OperationInst *getDefiningInst() const {
return const_cast<SSAValue *>(this)->getDefiningInst();
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *getDefiningStmt();
const OperationStmt *getDefiningStmt() const {
return const_cast<SSAValue *>(this)->getDefiningStmt();
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *getDefiningOperation();
const Operation *getDefiningOperation() const {
return const_cast<SSAValue *>(this)->getDefiningOperation();
}
void print(raw_ostream &os) const;
void dump() const;
protected:
SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
private:
const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
};
inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
value.print(os);
return os;
}
/// This template unifies the implementation logic for CFGValue and MLValue
/// while providing more type-specific APIs when walking use lists etc.
///
/// IROperandTy is the concrete instance of IROperand to use (including
/// substituted template arguments).
/// IROwnerTy is the type of the owner of an IROperandTy type.
/// KindTy is the enum 'kind' discriminator that subclasses want to use.
///
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
class SSAValueImpl : public SSAValue {
public:
// Provide more specific implementations of the base class functionality.
KindTy getKind() const { return (KindTy)SSAValue::getKind(); }
using use_iterator = SSAValueUseIterator<IROperandTy, IROwnerTy>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;
inline use_iterator use_end() const;
/// Returns a range of all uses, which is useful for iterating over all uses.
inline use_range getUses() const;
protected:
SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
};
// Utility functions for iterating through SSAValue uses.
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_begin() const
-> use_iterator {
return use_iterator((IROperandTy *)getFirstUse());
}
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::use_end() const
-> use_iterator {
return use_iterator(nullptr);
}
template <typename IROperandTy, typename IROwnerTy, typename KindTy>
inline auto SSAValueImpl<IROperandTy, IROwnerTy, KindTy>::getUses() const
-> llvm::iterator_range<use_iterator> {
return {use_begin(), use_end()};
}
} // namespace mlir
#endif


@ -22,8 +22,8 @@
#ifndef MLIR_IR_STATEMENT_H
#define MLIR_IR_STATEMENT_H
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
#include "llvm/ADT/ilist_node.h"
@ -84,8 +84,8 @@ public:
// This is a verbose type used by the clone method below.
using OperandMapTy =
DenseMap<const MLValue *, MLValue *, llvm::DenseMapInfo<const MLValue *>,
llvm::detail::DenseMapPair<const MLValue *, MLValue *>>;
DenseMap<const Value *, Value *, llvm::DenseMapInfo<const Value *>,
llvm::detail::DenseMapPair<const Value *, Value *>>;
/// Create a deep copy of this statement, remapping any operands that use
/// values outside of the statement using the map that is provided (leaving
@ -136,12 +136,12 @@ public:
unsigned getNumOperands() const;
MLValue *getOperand(unsigned idx);
const MLValue *getOperand(unsigned idx) const;
void setOperand(unsigned idx, MLValue *value);
Value *getOperand(unsigned idx);
const Value *getOperand(unsigned idx) const;
void setOperand(unsigned idx, Value *value);
// Support non-const operand iteration.
using operand_iterator = OperandIterator<Statement, MLValue>;
using operand_iterator = OperandIterator<Statement, Value>;
operand_iterator operand_begin() { return operand_iterator(this, 0); }
@ -149,14 +149,13 @@ public:
return operand_iterator(this, getNumOperands());
}
/// Returns an iterator on the underlying MLValue's (MLValue *).
/// Returns an iterator on the underlying Value's (Value *).
llvm::iterator_range<operand_iterator> getOperands() {
return {operand_begin(), operand_end()};
}
// Support const operand iteration.
using const_operand_iterator =
OperandIterator<const Statement, const MLValue>;
using const_operand_iterator = OperandIterator<const Statement, const Value>;
const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0);
@ -166,7 +165,7 @@ public:
return const_operand_iterator(this, getNumOperands());
}
/// Returns a const iterator on the underlying MLValue's (MLValue *).
/// Returns a const iterator on the underlying Value's (Value *).
llvm::iterator_range<const_operand_iterator> getOperands() const {
return {operand_begin(), operand_end()};
}


@ -43,7 +43,7 @@ class OperationStmt final
public:
/// Create a new OperationStmt with the specific fields.
static OperationStmt *
create(Location location, OperationName name, ArrayRef<MLValue *> operands,
create(Location location, OperationName name, ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors, MLIRContext *context);
@ -69,16 +69,16 @@ public:
unsigned getNumOperands() const { return numOperands; }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const {
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
void setOperand(unsigned idx, MLValue *value) {
void setOperand(unsigned idx, Value *value) {
return getStmtOperand(idx).set(value);
}
// Support non-const operand iteration.
using operand_iterator = OperandIterator<OperationStmt, MLValue>;
using operand_iterator = OperandIterator<OperationStmt, Value>;
operand_iterator operand_begin() { return operand_iterator(this, 0); }
@ -86,14 +86,14 @@ public:
return operand_iterator(this, getNumOperands());
}
/// Returns an iterator on the underlying MLValue's (MLValue *).
/// Returns an iterator on the underlying Value's (Value *).
llvm::iterator_range<operand_iterator> getOperands() {
return {operand_begin(), operand_end()};
}
// Support const operand iteration.
using const_operand_iterator =
OperandIterator<const OperationStmt, const MLValue>;
OperandIterator<const OperationStmt, const Value>;
const_operand_iterator operand_begin() const {
return const_operand_iterator(this, 0);
@ -103,7 +103,7 @@ public:
return const_operand_iterator(this, getNumOperands());
}
/// Returns a const iterator on the underlying MLValue's (MLValue *).
/// Returns a const iterator on the underlying Value's (Value *).
llvm::iterator_range<const_operand_iterator> getOperands() const {
return {operand_begin(), operand_end()};
}
@ -126,11 +126,11 @@ public:
unsigned getNumResults() const { return numResults; }
MLValue *getResult(unsigned idx) { return &getStmtResult(idx); }
const MLValue *getResult(unsigned idx) const { return &getStmtResult(idx); }
Value *getResult(unsigned idx) { return &getStmtResult(idx); }
const Value *getResult(unsigned idx) const { return &getStmtResult(idx); }
// Support non-const result iteration.
using result_iterator = ResultIterator<OperationStmt, MLValue>;
using result_iterator = ResultIterator<OperationStmt, Value>;
result_iterator result_begin() { return result_iterator(this, 0); }
result_iterator result_end() {
return result_iterator(this, getNumResults());
@ -141,7 +141,7 @@ public:
// Support const result iteration.
using const_result_iterator =
ResultIterator<const OperationStmt, const MLValue>;
ResultIterator<const OperationStmt, const Value>;
const_result_iterator result_begin() const {
return const_result_iterator(this, 0);
}
@ -170,7 +170,7 @@ public:
// Support result type iteration.
using result_type_iterator =
ResultTypeIterator<const OperationStmt, const MLValue>;
ResultTypeIterator<const OperationStmt, const Value>;
result_type_iterator result_type_begin() const {
return result_type_iterator(this, 0);
}
@ -290,15 +290,15 @@ private:
};
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public MLValue {
class ForStmt : public Statement, public Value {
public:
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
static ForStmt *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step);
~ForStmt() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
// since child statements need to be destroyed before the MLValue that this
// since child statements need to be destroyed before the Value that this
// for stmt represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted.
getBody()->clear();
@ -308,8 +308,8 @@ public:
using Statement::getFunction;
/// Operand iterators.
using operand_iterator = OperandIterator<ForStmt, MLValue>;
using const_operand_iterator = OperandIterator<const ForStmt, const MLValue>;
using operand_iterator = OperandIterator<ForStmt, Value>;
using const_operand_iterator = OperandIterator<const ForStmt, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
@ -340,9 +340,9 @@ public:
AffineMap getUpperBoundMap() const { return ubMap; }
/// Set lower bound.
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap map);
void setLowerBound(ArrayRef<Value *> operands, AffineMap map);
/// Set upper bound.
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap map);
void setUpperBound(ArrayRef<Value *> operands, AffineMap map);
/// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap map);
@ -385,11 +385,11 @@ public:
unsigned getNumOperands() const { return operands.size(); }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const {
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
void setOperand(unsigned idx, MLValue *value) {
void setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value);
}
@ -439,10 +439,10 @@ public:
}
// For statement implicitly represents the induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// inheriting from Value class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::ForStmt;
static bool classof(const Value *value) {
return value->getKind() == Value::Kind::ForStmt;
}
private:
@ -475,7 +475,7 @@ public:
AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; }
const MLValue *getOperand(unsigned idx) const {
const Value *getOperand(unsigned idx) const {
return stmt.getOperand(opStart + idx);
}
const StmtOperand &getStmtOperand(unsigned idx) const {
@ -486,15 +486,15 @@ public:
using operand_range = ForStmt::operand_range;
operand_iterator operand_begin() const {
// These are iterators over MLValue *. Not casting away const'ness would
// require the caller to use const MLValue *.
// These are iterators over Value *. Not casting away const'ness would
// require the caller to use const Value *.
return operand_iterator(const_cast<ForStmt *>(&stmt), opStart);
}
operand_iterator operand_end() const {
return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd);
}
/// Returns an iterator on the underlying MLValue's (MLValue *).
/// Returns an iterator on the underlying Value's (Value *).
operand_range getOperands() const { return {operand_begin(), operand_end()}; }
ArrayRef<StmtOperand> getStmtOperands() const {
auto ops = stmt.getStmtOperands();
@ -520,7 +520,7 @@ private:
/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement {
public:
static IfStmt *create(Location location, ArrayRef<MLValue *> operands,
static IfStmt *create(Location location, ArrayRef<Value *> operands,
IntegerSet set);
~IfStmt();
@ -556,8 +556,8 @@ public:
//===--------------------------------------------------------------------===//
/// Operand iterators.
using operand_iterator = OperandIterator<IfStmt, MLValue>;
using const_operand_iterator = OperandIterator<const IfStmt, const MLValue>;
using operand_iterator = OperandIterator<IfStmt, Value>;
using const_operand_iterator = OperandIterator<const IfStmt, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
@ -565,11 +565,11 @@ public:
unsigned getNumOperands() const { return operands.size(); }
MLValue *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const MLValue *getOperand(unsigned idx) const {
Value *getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const Value *getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
void setOperand(unsigned idx, MLValue *value) {
void setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value);
}


@ -26,7 +26,6 @@
namespace mlir {
class IfStmt;
class MLValue;
class StmtBlockList;
using CFGFunction = Function;
using MLFunction = Function;
@ -412,7 +411,7 @@ public:
}
private:
using BBUseIterator = SSAValueUseIterator<StmtBlockOperand, OperationStmt>;
using BBUseIterator = ValueUseIterator<StmtBlockOperand, OperationStmt>;
BBUseIterator bbUseIterator;
};

View File

@ -30,7 +30,7 @@ namespace mlir {
class IROperand;
class IROperandOwner;
template <typename OperandType, typename OwnerType> class SSAValueUseIterator;
template <typename OperandType, typename OwnerType> class ValueUseIterator;
class IRObjectWithUseList {
public:
@ -44,7 +44,7 @@ public:
/// Returns true if this value has exactly one use.
inline bool hasOneUse() const;
using use_iterator = SSAValueUseIterator<IROperand, IROperandOwner>;
using use_iterator = ValueUseIterator<IROperand, IROperandOwner>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;
@ -228,33 +228,33 @@ public:
/// An iterator over all uses of a value.
template <typename OperandType, typename OwnerType>
class SSAValueUseIterator
class ValueUseIterator
: public std::iterator<std::forward_iterator_tag, IROperand> {
public:
SSAValueUseIterator() = default;
explicit SSAValueUseIterator(OperandType *current) : current(current) {}
ValueUseIterator() = default;
explicit ValueUseIterator(OperandType *current) : current(current) {}
OperandType *operator->() const { return current; }
OperandType &operator*() const { return *current; }
OwnerType *getUser() const { return current->getOwner(); }
SSAValueUseIterator &operator++() {
ValueUseIterator &operator++() {
assert(current && "incrementing past end()!");
current = (OperandType *)current->getNextOperandUsingThisValue();
return *this;
}
SSAValueUseIterator operator++(int unused) {
SSAValueUseIterator copy = *this;
ValueUseIterator operator++(int unused) {
ValueUseIterator copy = *this;
++*this;
return copy;
}
friend bool operator==(SSAValueUseIterator lhs, SSAValueUseIterator rhs) {
friend bool operator==(ValueUseIterator lhs, ValueUseIterator rhs) {
return lhs.current == rhs.current;
}
friend bool operator!=(SSAValueUseIterator lhs, SSAValueUseIterator rhs) {
friend bool operator!=(ValueUseIterator lhs, ValueUseIterator rhs) {
return !(lhs == rhs);
}
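As a hedged illustration (the helper below is hypothetical and not part of this change), the iterator simply follows getNextOperandUsingThisValue() until the chain ends, so counting uses looks like:

template <typename OperandType, typename OwnerType>
unsigned countUses(ValueUseIterator<OperandType, OwnerType> it,
                   ValueUseIterator<OperandType, OwnerType> end) {
  unsigned numUses = 0;
  // Each increment hops to the next operand that uses the same value.
  for (; it != end; ++it)
    ++numUses;
  return numUses;
}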

View File

@ -0,0 +1,198 @@
//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines generic Value type and manipulation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_VALUE_H
#define MLIR_IR_VALUE_H
#include "mlir/IR/Types.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
class Function;
class OperationStmt;
class Operation;
class Statement;
class StmtBlock;
class Value;
using Instruction = Statement;
using OperationInst = OperationStmt;
/// The operand of an ML function statement contains a Value.
using StmtOperand = IROperandImpl<Value, Statement>;
/// This is the common base class for all values in the MLIR system,
/// representing a computable value that has a type and a set of users.
///
class Value : public IRObjectWithUseList {
public:
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class Kind {
BlockArgument, // block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};
~Value() {}
Kind getKind() const { return typeAndKind.getInt(); }
Type getType() const { return typeAndKind.getPointer(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
void replaceAllUsesWith(Value *newValue) {
IRObjectWithUseList::replaceAllUsesWith(newValue);
}
/// TODO: move isValidDim/isValidSymbol to a utility library specific to the
/// polyhedral operations.
/// Returns true if the given Value can be used as a dimension id.
bool isValidDim() const;
/// Returns true if the given Value can be used as a symbol.
bool isValidSymbol() const;
/// Return the function that this Value is defined in.
Function *getFunction();
/// Return the function that this Value is defined in.
const Function *getFunction() const {
return const_cast<Value *>(this)->getFunction();
}
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
OperationInst *getDefiningInst();
const OperationInst *getDefiningInst() const {
return const_cast<Value *>(this)->getDefiningInst();
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *getDefiningStmt();
const OperationStmt *getDefiningStmt() const {
return const_cast<Value *>(this)->getDefiningStmt();
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *getDefiningOperation();
const Operation *getDefiningOperation() const {
return const_cast<Value *>(this)->getDefiningOperation();
}
using use_iterator = ValueUseIterator<StmtOperand, Statement>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;
inline use_iterator use_end() const;
/// Returns a range of all uses, which is useful for iterating over all uses.
inline use_range getUses() const;
void print(raw_ostream &os) const;
void dump() const;
protected:
Value(Kind kind, Type type) : typeAndKind(type, kind) {}
private:
const llvm::PointerIntPair<Type, 3, Kind> typeAndKind;
};
inline raw_ostream &operator<<(raw_ostream &os, const Value &value) {
value.print(os);
return os;
}
// Utility functions for iterating through Value uses.
inline auto Value::use_begin() const -> use_iterator {
return use_iterator((StmtOperand *)getFirstUse());
}
inline auto Value::use_end() const -> use_iterator {
return use_iterator(nullptr);
}
inline auto Value::getUses() const -> llvm::iterator_range<use_iterator> {
return {use_begin(), use_end()};
}
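A minimal usage sketch (the function and value names are hypothetical; only the accessors declared above are assumed): walk the users of a value, then redirect them all at once.

void redirectUses(Value *oldValue, Value *newValue) {
  // Each use is a StmtOperand; getOwner() yields the using Statement.
  for (auto &operand : oldValue->getUses())
    (void)operand.getOwner();
  // Rewrite every use in one shot; oldValue is left with zero uses.
  oldValue->replaceAllUsesWith(newValue);
}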
/// Block arguments are values.
class BlockArgument : public Value {
public:
static bool classof(const Value *value) {
return value->getKind() == Kind::BlockArgument;
}
/// Return the function that this argument is defined in.
Function *getFunction();
const Function *getFunction() const {
return const_cast<BlockArgument *>(this)->getFunction();
}
StmtBlock *getOwner() { return owner; }
const StmtBlock *getOwner() const { return owner; }
private:
friend class StmtBlock; // For access to private constructor.
BlockArgument(Type type, StmtBlock *owner)
: Value(Value::Kind::BlockArgument, type), owner(owner) {}
/// The owner of this argument.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
StmtBlock *const owner;
};
/// This is a value defined by a result of an operation instruction.
class StmtResult : public Value {
public:
StmtResult(Type type, OperationStmt *owner)
: Value(Value::Kind::StmtResult, type), owner(owner) {}
static bool classof(const Value *value) {
return value->getKind() == Kind::StmtResult;
}
OperationStmt *getOwner() { return owner; }
const OperationStmt *getOwner() const { return owner; }
/// Returns the number of this result.
unsigned getResultNumber() const;
private:
/// The owner of this result.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
OperationStmt *const owner;
};
// TODO(clattner) clean all this up.
using BBArgument = BlockArgument;
using InstResult = StmtResult;
} // namespace mlir
#endif
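A hedged sketch of how the classof hooks above are consumed (the function itself is hypothetical): they let the usual llvm::isa/dyn_cast machinery dispatch on Value::Kind.

#include "llvm/Support/Casting.h"

const char *describeValueKind(const mlir::Value *value) {
  if (llvm::isa<mlir::BlockArgument>(value))
    return "block argument";
  if (llvm::isa<mlir::StmtResult>(value))
    return "statement result";
  // The only remaining kind is the 'for' statement induction variable.
  return "for statement induction variable";
}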

View File

@ -157,14 +157,14 @@ class Op<string mnemonic, list<OpProperty> props = []> {
//
// static void build(Builder* builder, OperationState* result,
// Type resultType0, Type resultType1, ...,
// SSAValue* arg0, SSAValue* arg1, ...,
// Value arg0, Value arg1, ...,
// Attribute <attr0-name>, Attribute <attr1-name>, ...);
//
// * where the attributes follow the same declaration order as in the op.
//
// static void build(Builder* builder, OperationState* result,
// ArrayRef<Type> resultTypes,
// ArrayRef<SSAValue*> args,
// ArrayRef<Value> args,
// ArrayRef<NamedAttribute> attributes);
code builder = ?;

View File

@ -30,7 +30,6 @@
namespace mlir {
class AffineMap;
class Builder;
class MLValue;
class StandardOpsDialect : public Dialect {
public:
@ -48,8 +47,8 @@ class AddFOp
: public BinaryOp<AddFOp, OpTrait::ResultsAreFloatLike,
OpTrait::IsCommutative, OpTrait::HasNoSideEffect> {
public:
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs);
static void build(Builder *builder, OperationState *result, Value *lhs,
Value *rhs);
static StringRef getOperationName() { return "addf"; }
@ -116,7 +115,7 @@ public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
MemRefType memrefType, ArrayRef<Value *> operands = {});
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@ -140,7 +139,7 @@ public:
static StringRef getOperationName() { return "call"; }
static void build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands);
ArrayRef<Value *> operands);
Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee").getValue();
@ -169,11 +168,11 @@ class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
public:
static StringRef getOperationName() { return "call_indirect"; }
static void build(Builder *builder, OperationState *result, SSAValue *callee,
ArrayRef<SSAValue *> operands);
static void build(Builder *builder, OperationState *result, Value *callee,
ArrayRef<Value *> operands);
const SSAValue *getCallee() const { return getOperand(0); }
SSAValue *getCallee() { return getOperand(0); }
const Value *getCallee() const { return getOperand(0); }
Value *getCallee() { return getOperand(0); }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
@ -240,7 +239,7 @@ public:
static CmpIPredicate getPredicateByName(StringRef name);
static void build(Builder *builder, OperationState *result, CmpIPredicate,
SSAValue *lhs, SSAValue *rhs);
Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
@ -263,14 +262,14 @@ private:
class DeallocOp
: public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public:
SSAValue *getMemRef() { return getOperand(); }
const SSAValue *getMemRef() const { return getOperand(); }
void setMemRef(SSAValue *value) { setOperand(value); }
Value *getMemRef() { return getOperand(); }
const Value *getMemRef() const { return getOperand(); }
void setMemRef(Value *value) { setOperand(value); }
static StringRef getOperationName() { return "dealloc"; }
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, SSAValue *memref);
static void build(Builder *builder, OperationState *result, Value *memref);
bool verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@ -292,7 +291,7 @@ class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
OpTrait::HasNoSideEffect> {
public:
static void build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index);
Value *memrefOrTensor, unsigned index);
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
@ -354,15 +353,15 @@ private:
class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
static void build(Builder *builder, OperationState *result,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
SSAValue *numElements, SSAValue *tagMemRef,
ArrayRef<SSAValue *> tagIndices, SSAValue *stride = nullptr,
SSAValue *elementsPerStride = nullptr);
static void build(Builder *builder, OperationState *result, Value *srcMemRef,
ArrayRef<Value *> srcIndices, Value *destMemRef,
ArrayRef<Value *> destIndices, Value *numElements,
Value *tagMemRef, ArrayRef<Value *> tagIndices,
Value *stride = nullptr,
Value *elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
const SSAValue *getSrcMemRef() const { return getOperand(0); }
const Value *getSrcMemRef() const { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() const {
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
@ -375,7 +374,7 @@ public:
}
// Returns the destination MemRefType for this DMA operations.
const SSAValue *getDstMemRef() const {
const Value *getDstMemRef() const {
return getOperand(1 + getSrcMemRefRank());
}
// Returns the rank (number of indices) of the destination MemRefType.
@ -398,12 +397,12 @@ public:
}
// Returns the number of elements being transferred by this DMA operation.
const SSAValue *getNumElements() const {
const Value *getNumElements() const {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
const SSAValue *getTagMemRef() const {
const Value *getTagMemRef() const {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
@ -453,21 +452,21 @@ public:
1 + 1 + getTagMemRefRank();
}
SSAValue *getStride() {
Value *getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
const SSAValue *getStride() const {
const Value *getStride() const {
return const_cast<DmaStartOp *>(this)->getStride();
}
SSAValue *getNumElementsPerStride() {
Value *getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
const SSAValue *getNumElementsPerStride() const {
const Value *getNumElementsPerStride() const {
return const_cast<DmaStartOp *>(this)->getNumElementsPerStride();
}
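As a hedged worked example of the operand layout these accessors decode (ranks chosen arbitrarily: 2-d source, 1-d destination, 1-d tag memref):

//   operand 0      : source memref
//   operands 1-2   : source indices
//   operand 3      : destination memref   (1 + srcRank)
//   operand 4      : destination index
//   operand 5      : number of elements   (1 + srcRank + 1 + dstRank)
//   operand 6      : tag memref           (1 + srcRank + 1 + dstRank + 1)
//   operand 7      : tag index
//   operands 8-9   : optional stride and elements per stride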
@ -493,15 +492,14 @@ protected:
class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
static void build(Builder *builder, OperationState *result,
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices,
SSAValue *numElements);
static void build(Builder *builder, OperationState *result, Value *tagMemRef,
ArrayRef<Value *> tagIndices, Value *numElements);
static StringRef getOperationName() { return "dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
const SSAValue *getTagMemRef() const { return getOperand(0); }
SSAValue *getTagMemRef() { return getOperand(0); }
const Value *getTagMemRef() const { return getOperand(0); }
Value *getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator>
@ -516,7 +514,7 @@ public:
}
// Returns the number of elements transferred in the associated DMA operation.
const SSAValue *getNumElements() const {
const Value *getNumElements() const {
return getOperand(1 + getTagMemRefRank());
}
@ -545,11 +543,11 @@ class ExtractElementOp
: public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult,
OpTrait::HasNoSideEffect> {
public:
static void build(Builder *builder, OperationState *result,
SSAValue *aggregate, ArrayRef<SSAValue *> indices = {});
static void build(Builder *builder, OperationState *result, Value *aggregate,
ArrayRef<Value *> indices = {});
SSAValue *getAggregate() { return getOperand(0); }
const SSAValue *getAggregate() const { return getOperand(0); }
Value *getAggregate() { return getOperand(0); }
const Value *getAggregate() const { return getOperand(0); }
llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
@ -583,12 +581,12 @@ class LoadOp
: public Op<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices = {});
static void build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<Value *> indices = {});
SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); }
Value *getMemRef() { return getOperand(0); }
const Value *getMemRef() const { return getOperand(0); }
void setMemRef(Value *value) { setOperand(0, value); }
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}
@ -705,19 +703,18 @@ class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
static StringRef getOperationName() { return "select"; }
static void build(Builder *builder, OperationState *result,
SSAValue *condition, SSAValue *trueValue,
SSAValue *falseValue);
static void build(Builder *builder, OperationState *result, Value *condition,
Value *trueValue, Value *falseValue);
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
SSAValue *getCondition() { return getOperand(0); }
const SSAValue *getCondition() const { return getOperand(0); }
SSAValue *getTrueValue() { return getOperand(1); }
const SSAValue *getTrueValue() const { return getOperand(1); }
SSAValue *getFalseValue() { return getOperand(2); }
const SSAValue *getFalseValue() const { return getOperand(2); }
Value *getCondition() { return getOperand(0); }
const Value *getCondition() const { return getOperand(0); }
Value *getTrueValue() { return getOperand(1); }
const Value *getTrueValue() const { return getOperand(1); }
Value *getFalseValue() { return getOperand(2); }
const Value *getFalseValue() const { return getOperand(2); }
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
@ -742,15 +739,15 @@ class StoreOp
public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result,
SSAValue *valueToStore, SSAValue *memref,
ArrayRef<SSAValue *> indices = {});
Value *valueToStore, Value *memref,
ArrayRef<Value *> indices = {});
SSAValue *getValueToStore() { return getOperand(0); }
const SSAValue *getValueToStore() const { return getOperand(0); }
Value *getValueToStore() { return getOperand(0); }
const Value *getValueToStore() const { return getOperand(0); }
SSAValue *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); }
Value *getMemRef() { return getOperand(1); }
const Value *getMemRef() const { return getOperand(1); }
void setMemRef(Value *value) { setOperand(1, value); }
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}

View File

@ -98,26 +98,24 @@ public:
static StringRef getOperationName() { return "vector_transfer_read"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
static void build(Builder *builder, OperationState *result,
VectorType vectorType, SSAValue *srcMemRef,
ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap,
Optional<SSAValue *> paddingValue = None);
VectorType vectorType, Value *srcMemRef,
ArrayRef<Value *> srcIndices, AffineMap permutationMap,
Optional<Value *> paddingValue = None);
VectorType getResultType() const {
return getResult()->getType().cast<VectorType>();
}
SSAValue *getVector() { return getResult(); }
const SSAValue *getVector() const { return getResult(); }
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const SSAValue *getMemRef() const {
return getOperand(Offsets::MemRefOffset);
}
Value *getVector() { return getResult(); }
const Value *getVector() const { return getResult(); }
Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
VectorType getVectorType() const { return getResultType(); }
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}
llvm::iterator_range<Operation::operand_iterator> getIndices();
llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
Optional<SSAValue *> getPaddingValue();
Optional<const SSAValue *> getPaddingValue() const;
Optional<Value *> getPaddingValue();
Optional<const Value *> getPaddingValue() const;
AffineMap getPermutationMap() const;
static bool parse(OpAsmParser *parser, OperationState *result);
@ -169,20 +167,16 @@ class VectorTransferWriteOp
public:
static StringRef getOperationName() { return "vector_transfer_write"; }
static StringRef getPermutationMapAttrName() { return "permutation_map"; }
static void build(Builder *builder, OperationState *result,
SSAValue *srcVector, SSAValue *dstMemRef,
ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap);
SSAValue *getVector() { return getOperand(Offsets::VectorOffset); }
const SSAValue *getVector() const {
return getOperand(Offsets::VectorOffset);
}
static void build(Builder *builder, OperationState *result, Value *srcVector,
Value *dstMemRef, ArrayRef<Value *> dstIndices,
AffineMap permutationMap);
Value *getVector() { return getOperand(Offsets::VectorOffset); }
const Value *getVector() const { return getOperand(Offsets::VectorOffset); }
VectorType getVectorType() const {
return getVector()->getType().cast<VectorType>();
}
SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const SSAValue *getMemRef() const {
return getOperand(Offsets::MemRefOffset);
}
Value *getMemRef() { return getOperand(Offsets::MemRefOffset); }
const Value *getMemRef() const { return getOperand(Offsets::MemRefOffset); }
MemRefType getMemRefType() const {
return getMemRef()->getType().cast<MemRefType>();
}
@ -212,8 +206,8 @@ class VectorTypeCastOp
: public Op<VectorTypeCastOp, OpTrait::OneOperand, OpTrait::OneResult> {
public:
static StringRef getOperationName() { return "vector_type_cast"; }
static void build(Builder *builder, OperationState *result,
SSAValue *srcVector, Type dstType);
static void build(Builder *builder, OperationState *result, Value *srcVector,
Type dstType);
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;

View File

@ -35,10 +35,9 @@ namespace mlir {
class ForStmt;
class FuncBuilder;
class Location;
class MLValue;
class Module;
class OperationStmt;
class SSAValue;
class Function;
using CFGFunction = Function;
@ -52,12 +51,12 @@ using CFGFunction = Function;
/// Returns true on success and false if the replacement is not possible
/// (whenever a memref is used as an operand in a non-deferencing scenario). See
/// comments at function definition for an example.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// TODO(mlir-team): extend this for Value / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
ArrayRef<SSAValue *> extraIndices = {},
bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap::Null(),
ArrayRef<SSAValue *> extraOperands = {},
ArrayRef<Value *> extraOperands = {},
const Statement *domStmtFilter = nullptr);
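A hedged call sketch (the wrapper and its argument names are hypothetical; only the declaration above is assumed): prepend a single extra index, as a double-buffering style rewrite would.

static bool prependBufferIndex(Value *oldMemRef, Value *newMemRef,
                               Value *bufferIndex) {
  // Every dereferencing use of oldMemRef becomes an access of newMemRef with
  // bufferIndex inserted before the existing subscripts; returns false if a
  // non-dereferencing use of the memref is found.
  return replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                  /*extraIndices=*/{bufferIndex});
}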
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
@ -69,9 +68,9 @@ bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
/// parameter 'results'. Returns the affine apply op created.
OperationStmt *
createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<MLValue *> operands,
ArrayRef<Value *> operands,
ArrayRef<OperationStmt *> affineApplyOps,
SmallVectorImpl<SSAValue *> *results);
SmallVectorImpl<Value *> *results);
/// Given an operation statement, inserts a new single affine apply operation,
/// that is exclusively used by this operation statement, and that provides all
@ -104,7 +103,7 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt);
/// Forward substitutes results from 'AffineApplyOp' into any users which
/// are also AffineApplyOps.
// NOTE: This method may modify users of results of this operation.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions.
// TODO(mlir-team): extend this for Value / CFGFunctions.
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.

View File

@ -489,11 +489,11 @@ bool mlir::getFlattenedAffineExprs(
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
// the AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
ArrayRef<MLValue *> operands,
ArrayRef<Value *> operands,
SmallVectorImpl<OperationStmt *> &affineApplyOps) {
struct State {
// The ssa value for this node in the DFS traversal.
MLValue *value;
Value *value;
// The operand index of 'value' to explore next during DFS traversal.
unsigned operandIndex;
};
@ -557,8 +557,8 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
FlatAffineConstraints *domain) {
SmallVector<MLValue *, 4> indices(forStmts.begin(), forStmts.end());
// Reset while associated MLValues in 'indices' to the domain.
SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
// Reset 'domain' with the associated Values in 'indices'.
domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forStmt : forStmts) {
// Add constraints from forStmt's bounds.
@ -583,10 +583,10 @@ static bool getStmtIndexSet(const Statement *stmt,
return getIndexSet(loops, indexSet);
}
// ValuePositionMap manages the mapping from MLValues which represent dimension
// ValuePositionMap manages the mapping from Values which represent dimension
// and symbol identifiers from 'src' and 'dst' access functions to positions
// in new space where some MLValues are kept separate (using addSrc/DstValue)
// and some MLValues are merged (addSymbolValue).
// in new space where some Values are kept separate (using addSrc/DstValue)
// and some Values are merged (addSymbolValue).
// Position lookups return the absolute position in the new space which
// has the following format:
//
@ -595,7 +595,7 @@ static bool getStmtIndexSet(const Statement *stmt,
// Note: access function non-IV dimension identifiers (that have 'dimension'
// positions in the access function position space) are assigned as symbols
// in the output position space. Convenience access functions which look up
// an MLValue in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
// a Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
// the common case of resolving positions for all access function operands.
//
// TODO(andydavis) Generalize this: could take a template parameter for
@ -603,25 +603,25 @@ static bool getStmtIndexSet(const Statement *stmt,
// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
class ValuePositionMap {
public:
void addSrcValue(const MLValue *value) {
void addSrcValue(const Value *value) {
if (addValueAt(value, &srcDimPosMap, numSrcDims))
++numSrcDims;
}
void addDstValue(const MLValue *value) {
void addDstValue(const Value *value) {
if (addValueAt(value, &dstDimPosMap, numDstDims))
++numDstDims;
}
void addSymbolValue(const MLValue *value) {
void addSymbolValue(const Value *value) {
if (addValueAt(value, &symbolPosMap, numSymbols))
++numSymbols;
}
unsigned getSrcDimOrSymPos(const MLValue *value) const {
unsigned getSrcDimOrSymPos(const Value *value) const {
return getDimOrSymPos(value, srcDimPosMap, 0);
}
unsigned getDstDimOrSymPos(const MLValue *value) const {
unsigned getDstDimOrSymPos(const Value *value) const {
return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
}
unsigned getSymPos(const MLValue *value) const {
unsigned getSymPos(const Value *value) const {
auto it = symbolPosMap.find(value);
assert(it != symbolPosMap.end());
return numSrcDims + numDstDims + it->second;
@ -633,8 +633,7 @@ public:
unsigned getNumSymbols() const { return numSymbols; }
private:
bool addValueAt(const MLValue *value,
DenseMap<const MLValue *, unsigned> *posMap,
bool addValueAt(const Value *value, DenseMap<const Value *, unsigned> *posMap,
unsigned position) {
auto it = posMap->find(value);
if (it == posMap->end()) {
@ -643,8 +642,8 @@ private:
}
return false;
}
unsigned getDimOrSymPos(const MLValue *value,
const DenseMap<const MLValue *, unsigned> &dimPosMap,
unsigned getDimOrSymPos(const Value *value,
const DenseMap<const Value *, unsigned> &dimPosMap,
unsigned dimPosOffset) const {
auto it = dimPosMap.find(value);
if (it != dimPosMap.end()) {
@ -658,25 +657,25 @@ private:
unsigned numSrcDims = 0;
unsigned numDstDims = 0;
unsigned numSymbols = 0;
DenseMap<const MLValue *, unsigned> srcDimPosMap;
DenseMap<const MLValue *, unsigned> dstDimPosMap;
DenseMap<const MLValue *, unsigned> symbolPosMap;
DenseMap<const Value *, unsigned> srcDimPosMap;
DenseMap<const Value *, unsigned> dstDimPosMap;
DenseMap<const Value *, unsigned> symbolPosMap;
};
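A hedged worked example (value names are for illustration only): registering one source IV, one destination IV and one shared symbol yields positions laid out as [src dims | dst dims | symbols].

//   posMap.addSrcValue(srcIV);    // src dim 0
//   posMap.addDstValue(dstIV);    // dst dim 0
//   posMap.addSymbolValue(n);     // symbol 0
//
//   posMap.getSrcDimOrSymPos(srcIV)  -> 0
//   posMap.getDstDimOrSymPos(dstIV)  -> numSrcDims + 0 == 1
//   posMap.getSymPos(n)              -> numSrcDims + numDstDims + 0 == 2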
// Builds a map from MLValue to identifier position in a new merged identifier
// Builds a map from Value to identifier position in a new merged identifier
// list, which is the result of merging dim/symbol lists from src/dst
// iteration domains. The format of the new merged list is as follows:
//
// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers]
//
// This method populates 'valuePosMap' with mappings from operand MLValues in
// This method populates 'valuePosMap' with mappings from operand Values in
// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
// to the position of these values in the merged list.
static void buildDimAndSymbolPositionMaps(
const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap) {
auto updateValuePosMap = [&](ArrayRef<MLValue *> values, bool isSrc) {
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isa<ForStmt>(values[i]))
@ -688,7 +687,7 @@ static void buildDimAndSymbolPositionMaps(
}
};
SmallVector<MLValue *, 4> srcValues, destValues;
SmallVector<Value *, 4> srcValues, destValues;
srcDomain.getIdValues(&srcValues);
dstDomain.getIdValues(&destValues);
@ -702,17 +701,10 @@ static void buildDimAndSymbolPositionMaps(
updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
}
static unsigned getPos(const DenseMap<const MLValue *, unsigned> &posMap,
const MLValue *value) {
auto it = posMap.find(value);
assert(it != posMap.end());
return it->second;
}
// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
// 'dependenceDomain'.
// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
// srcDomain/dstDomain MLValue maps.
// srcDomain/dstDomain Value maps.
static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain,
const ValuePositionMap &valuePosMap,
@ -790,10 +782,10 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
unsigned numResults = srcMap.getNumResults();
unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
ArrayRef<MLValue *> srcOperands = srcAccessMap.getOperands();
ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
ArrayRef<MLValue *> dstOperands = dstAccessMap.getOperands();
ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
std::vector<SmallVector<int64_t, 8>> destFlatExprs;
@ -848,7 +840,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
}
// Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const MLValue *> operands) {
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (isa<ForStmt>(operands[i]))
continue;
@ -1095,7 +1087,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// upper/lower loop bounds for each ForStmt in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// MLValues from access functions and iteration domains to their position
// Values from access functions and iteration domains to their position
// in the merged constraint system built by this method.
//
// This method builds a constraint system with the following column format:
@ -1202,7 +1194,7 @@ bool mlir::checkMemrefAccessDependence(
return false;
}
// Build dim and symbol position maps for each access from access operand
// MLValue to position in merged contstraint system.
// Value to position in the merged constraint system.
ValuePositionMap valuePosMap;
buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
dstAccessMap, &valuePosMap);

View File

@ -25,7 +25,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
@ -238,23 +237,23 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
AffineValueMap::AffineValueMap(const AffineApplyOp &op)
: map(op.getAffineMap()) {
for (auto *operand : op.getOperands())
operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand)));
operands.push_back(const_cast<Value *>(operand));
for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i))));
results.push_back(const_cast<Value *>(op.getResult(i)));
}
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands)
: map(map) {
for (MLValue *operand : operands) {
for (Value *operand : operands) {
this->operands.push_back(operand);
}
}
void AffineValueMap::reset(AffineMap map, ArrayRef<MLValue *> operands) {
void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands) {
this->operands.clear();
this->results.clear();
this->map.reset(map);
for (MLValue *operand : operands) {
for (Value *operand : operands) {
this->operands.push_back(operand);
}
}
@ -275,7 +274,7 @@ void AffineValueMap::forwardSubstituteSingle(const AffineApplyOp &inputOp,
// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
static bool findIndex(MLValue *valueToMatch, ArrayRef<MLValue *> valuesToSearch,
static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
unsigned indexStart, unsigned *indexOfMatch) {
unsigned size = valuesToSearch.size();
for (unsigned i = indexStart; i < size; ++i) {
@ -324,8 +323,7 @@ void AffineValueMap::forwardSubstitute(
for (unsigned j = 0; j < inputNumResults; ++j) {
if (!inputResultsToSubstitute[j])
continue;
if (operands[i] ==
cast<MLValue>(const_cast<SSAValue *>(inputOp.getResult(j)))) {
if (operands[i] == const_cast<Value *>(inputOp.getResult(j))) {
currOperandToInputResult[i] = j;
inputResultsUsed.insert(j);
}
@ -365,7 +363,7 @@ void AffineValueMap::forwardSubstitute(
}
// Build new output operands list and map update.
SmallVector<MLValue *, 4> outputOperands;
SmallVector<Value *, 4> outputOperands;
unsigned outputOperandPosition = 0;
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
@ -385,8 +383,7 @@ void AffineValueMap::forwardSubstitute(
if (inputPositionsUsed.count(i) == 0)
continue;
// Check if input operand has a dup in current operand list.
auto *inputOperand =
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
unsigned outputIndex;
if (findIndex(inputOperand, outputOperands, /*indexStart=*/0,
&outputIndex)) {
@ -418,8 +415,7 @@ void AffineValueMap::forwardSubstitute(
continue;
unsigned inputSymbolPosition = i - inputNumDims;
// Check if input operand has a dup in current operand list.
auto *inputOperand =
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
auto *inputOperand = const_cast<Value *>(inputOp.getOperand(i));
// Find output operand index of 'inputOperand' dup.
unsigned outputIndex;
// Start at index 'outputNumDims' so that only symbol operands are searched.
@ -451,7 +447,7 @@ inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
/// This method uses the invariant that operands are always positionally aligned
/// with the AffineDimExpr in the underlying AffineMap.
bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
unsigned index;
findIndex(value, operands, /*indexStart=*/0, &index);
auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
@ -460,12 +456,12 @@ bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
return expr.isFunctionOfDim(index);
}
SSAValue *AffineValueMap::getOperand(unsigned i) const {
return static_cast<SSAValue *>(operands[i]);
Value *AffineValueMap::getOperand(unsigned i) const {
return static_cast<Value *>(operands[i]);
}
ArrayRef<MLValue *> AffineValueMap::getOperands() const {
return ArrayRef<MLValue *>(operands);
ArrayRef<Value *> AffineValueMap::getOperands() const {
return ArrayRef<Value *>(operands);
}
AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
@ -546,7 +542,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
unsigned newNumReservedCols,
unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) {
ArrayRef<Value *> idArgs) {
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
"minimum 1 column");
numReservedCols = newNumReservedCols;
@ -570,7 +566,7 @@ void FlatAffineConstraints::reset(unsigned numReservedInequalities,
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals,
ArrayRef<MLValue *> idArgs) {
ArrayRef<Value *> idArgs) {
reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
newNumSymbols, newNumLocals, idArgs);
}
@ -597,17 +593,17 @@ void FlatAffineConstraints::addLocalId(unsigned pos) {
addId(IdKind::Local, pos);
}
void FlatAffineConstraints::addDimId(unsigned pos, MLValue *id) {
void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
addId(IdKind::Dimension, pos, id);
}
void FlatAffineConstraints::addSymbolId(unsigned pos, MLValue *id) {
void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
addId(IdKind::Symbol, pos, id);
}
/// Adds a dimensional identifier. The added column is initialized to
/// zero.
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, MLValue *id) {
void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
if (kind == IdKind::Dimension) {
assert(pos <= getNumDimIds());
} else if (kind == IdKind::Symbol) {
@ -755,7 +751,7 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
// Dims and symbols.
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
unsigned loc;
bool ret = findId(*cast<MLValue>(vMap->getOperand(i)), &loc);
bool ret = findId(*vMap->getOperand(i), &loc);
assert(ret && "value map's id can't be found");
(void)ret;
// We need to negate 'eq[r]' since the newly added dimension is going to
@ -1231,7 +1227,7 @@ void FlatAffineConstraints::addUpperBound(ArrayRef<int64_t> expr,
}
}
bool FlatAffineConstraints::findId(const MLValue &id, unsigned *pos) const {
bool FlatAffineConstraints::findId(const Value &id, unsigned *pos) const {
unsigned i = 0;
for (const auto &mayBeId : ids) {
if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
@ -1253,8 +1249,8 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
unsigned pos;
// Pre-condition for this method.
if (!findId(*cast<MLValue>(&forStmt), &pos)) {
assert(0 && "MLValue not found");
if (!findId(forStmt, &pos)) {
assert(0 && "Value not found");
return false;
}
@ -1270,7 +1266,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
unsigned loc;
if (!findId(*operand, &loc)) {
if (operand->isValidSymbol()) {
addSymbolId(getNumSymbolIds(), const_cast<MLValue *>(operand));
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.
if (auto *opStmt = operand->getDefiningStmt()) {
@ -1279,7 +1275,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
}
}
} else {
addDimId(getNumDimIds(), const_cast<MLValue *>(operand));
addDimId(getNumDimIds(), const_cast<Value *>(operand));
loc = getNumDimIds() - 1;
}
}
@ -1352,7 +1348,7 @@ void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
/// Sets the specified identifier to a constant value; asserts if the id is not
/// found.
void FlatAffineConstraints::setIdToConstant(const MLValue &id, int64_t val) {
void FlatAffineConstraints::setIdToConstant(const Value &id, int64_t val) {
unsigned pos;
if (!findId(id, &pos))
// This is a pre-condition for this method.
@ -1572,7 +1568,7 @@ void FlatAffineConstraints::print(raw_ostream &os) const {
if (ids[i] == None)
os << "None ";
else
os << "MLValue ";
os << "Value ";
}
os << " const)\n";
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
@ -1779,7 +1775,7 @@ void FlatAffineConstraints::FourierMotzkinEliminate(
unsigned newNumDims = dimsSymbols.first;
unsigned newNumSymbols = dimsSymbols.second;
SmallVector<Optional<MLValue *>, 8> newIds;
SmallVector<Optional<Value *>, 8> newIds;
newIds.reserve(numIds - 1);
newIds.insert(newIds.end(), ids.begin(), ids.begin() + pos);
newIds.insert(newIds.end(), ids.begin() + pos + 1, ids.end());
@ -1942,7 +1938,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
normalizeConstraintsByGCD();
}
void FlatAffineConstraints::projectOut(MLValue *id) {
void FlatAffineConstraints::projectOut(Value *id) {
unsigned pos;
bool ret = findId(*id, &pos);
assert(ret);

View File

@ -70,7 +70,7 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
}
/// Return true if value A properly dominates instruction B.
bool DominanceInfo::properlyDominates(const SSAValue *a, const Instruction *b) {
bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
if (auto *aInst = a->getDefiningInst())
return properlyDominates(aInst, b);

View File

@ -124,14 +124,14 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
return tripCountExpr.getLargestKnownDivisor();
}
bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<MLValue *>(&index)}, affineApplyOps);
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
if (affineApplyOps.empty()) {
// Pointer equality test because of MLValue pointer semantics.
// Pointer equality test because of Value pointer semantics.
return &index != &iv;
}
@ -155,13 +155,13 @@ bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
}
assert(idx < std::numeric_limits<unsigned>::max());
return !AffineValueMap(*composeOp)
.isFunctionOf(idx, &const_cast<MLValue &>(iv));
.isFunctionOf(idx, &const_cast<Value &>(iv));
}
llvm::DenseSet<const MLValue *>
mlir::getInvariantAccesses(const MLValue &iv,
llvm::ArrayRef<const MLValue *> indices) {
llvm::DenseSet<const MLValue *> res;
llvm::DenseSet<const Value *>
mlir::getInvariantAccesses(const Value &iv,
llvm::ArrayRef<const Value *> indices) {
llvm::DenseSet<const Value *> res;
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
auto *val = indices[idx];
if (isAccessInvariant(iv, *val)) {
@ -191,7 +191,7 @@ mlir::getInvariantAccesses(const MLValue &iv,
///
// TODO(ntv): check strides.
template <typename LoadOrStoreOp>
static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
static bool isContiguousAccess(const Value &iv, const LoadOrStoreOp &memoryOp,
unsigned fastestVaryingDim) {
static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value ||
std::is_same<LoadOrStoreOp, StoreOp>::value,
@ -220,7 +220,7 @@ static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
if (fastestVaryingDim == (numIndices - 1) - d++) {
continue;
}
if (!isAccessInvariant(iv, cast<MLValue>(*index))) {
if (!isAccessInvariant(iv, *index)) {
return false;
}
}
@ -316,7 +316,7 @@ bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
// outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const MLValue *result = opStmt->getResult(i);
const Value *result = opStmt->getResult(i);
for (const StmtOperand &use : result->getUses()) {
// If an ancestor statement doesn't lie in the block of forStmt, there
// is no shift to check.

View File

@ -70,7 +70,7 @@ static void addMemRefAccessIndices(
MemRefType memrefType, MemRefAccess *access) {
access->indices.reserve(memrefType.getRank());
for (auto *index : opIndices) {
access->indices.push_back(cast<MLValue>(const_cast<SSAValue *>(index)));
access->indices.push_back(const_cast<mlir::Value *>(index));
}
}
@ -79,13 +79,13 @@ static void getMemRefAccess(const OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) {
access->opStmt = loadOrStoreOpStmt;
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = cast<MLValue>(loadOp->getMemRef());
access->memref = loadOp->getMemRef();
addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
access);
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->memref = cast<MLValue>(storeOp->getMemRef());
access->memref = storeOp->getMemRef();
addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
access);
}

View File

@ -150,21 +150,21 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
unsigned rank;
SmallVector<MLValue *, 4> indices;
SmallVector<Value *, 4> indices;
if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) {
indices.push_back(cast<MLValue>(index));
indices.push_back(index);
}
region->memref = cast<MLValue>(loadOp->getMemRef());
region->memref = loadOp->getMemRef();
region->setWrite(false);
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) {
indices.push_back(cast<MLValue>(index));
indices.push_back(index);
}
region->memref = cast<MLValue>(storeOp->getMemRef());
region->memref = storeOp->getMemRef();
region->setWrite(true);
} else {
return false;
@ -201,7 +201,7 @@ bool mlir::getMemRefRegion(OperationStmt *opStmt, unsigned loopDepth,
return false;
} else {
// Has to be a valid symbol.
auto *symbol = cast<MLValue>(accessValueMap.getOperand(i));
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
if (auto *opStmt = symbol->getDefiningStmt()) {
@ -405,7 +405,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// Solve for src IVs in terms of dst IVs, symbols and constants.
SmallVector<AffineMap, 4> srcIvMaps(srcLoopNestSize, AffineMap::Null());
std::vector<SmallVector<MLValue *, 2>> srcIvOperands(srcLoopNestSize);
std::vector<SmallVector<Value *, 2>> srcIvOperands(srcLoopNestSize);
for (unsigned i = 0; i < srcLoopNestSize; ++i) {
// Skip IVs which are greater than requested loop depth.
if (i >= srcLoopDepth) {
@ -442,7 +442,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
srcIvOperands[i].push_back(dstLoopNest[dimId - 1]);
}
// TODO(andydavis) Add symbols from the access function. Ideally, we
// should be able to query the constaint system for the MLValue associated
// should be able to query the constraint system for the Value associated
// with the symbol identifiers in 'nonZeroSymbolIds'.
}
@ -454,7 +454,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const MLValue *, MLValue *> operandMap;
DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.

View File

@ -108,7 +108,7 @@ static AffineMap makePermutationMap(
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<SSAValue, MLValue>(), indices);
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {

View File

@ -277,7 +277,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
/// Walk all of the code in this MLFunc and verify that the operands of any
/// operations are properly dominated by their definitions.
bool MLFuncVerifier::verifyDominance() {
using HashTable = llvm::ScopedHashTable<const SSAValue *, bool>;
using HashTable = llvm::ScopedHashTable<const Value *, bool>;
HashTable liveValues;
HashTable::ScopeTy topScope(liveValues);

View File

@ -38,7 +38,6 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); }
@ -967,7 +966,7 @@ public:
void printFunctionAttributes(const Function *func) {
return ModulePrinter::printFunctionAttributes(func);
}
void printOperand(const SSAValue *value) { printValueID(value); }
void printOperand(const Value *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) {
@ -977,7 +976,7 @@ public:
enum { nameSentinel = ~0U };
protected:
void numberValueID(const SSAValue *value) {
void numberValueID(const Value *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
SmallString<32> specialNameBuffer;
@ -1004,7 +1003,7 @@ protected:
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case SSAValueKind::BlockArgument:
case Value::Kind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->getFunction())
@ -1015,12 +1014,12 @@ protected:
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::StmtResult:
case Value::Kind::StmtResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::ForStmt:
case Value::Kind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
}
@ -1052,7 +1051,7 @@ protected:
}
}
void printValueID(const SSAValue *value, bool printResultNo = true) const {
void printValueID(const Value *value, bool printResultNo = true) const {
int resultNo = -1;
auto lookupValue = value;
@ -1093,8 +1092,8 @@ protected:
private:
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
DenseMap<const SSAValue *, unsigned> valueIDs;
DenseMap<const SSAValue *, StringRef> valueNames;
DenseMap<const Value *, unsigned> valueIDs;
DenseMap<const Value *, StringRef> valueNames;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
@ -1135,7 +1134,7 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
os << "\"(";
interleaveComma(op->getOperands(),
[&](const SSAValue *value) { printValueID(value); });
[&](const Value *value) { printValueID(value); });
os << ')';
auto attrs = op->getAttrs();
@ -1144,16 +1143,15 @@ void FunctionPrinter::printDefaultOp(const Operation *op) {
// Print the type signature of the operation.
os << " : (";
interleaveComma(op->getOperands(),
[&](const SSAValue *value) { printType(value->getType()); });
[&](const Value *value) { printType(value->getType()); });
os << ") -> ";
if (op->getNumResults() == 1) {
printType(op->getResult(0)->getType());
} else {
os << '(';
interleaveComma(op->getResults(), [&](const SSAValue *result) {
printType(result->getType());
});
interleaveComma(op->getResults(),
[&](const Value *result) { printType(result->getType()); });
os << ')';
}
}
@ -1297,11 +1295,10 @@ void CFGFunctionPrinter::printBranchOperands(const Range &range) {
os << '(';
interleaveComma(range,
[this](const SSAValue *operand) { printValueID(operand); });
[this](const Value *operand) { printValueID(operand); });
os << " : ";
interleaveComma(range, [this](const SSAValue *operand) {
printType(operand->getType());
});
interleaveComma(
range, [this](const Value *operand) { printType(operand->getType()); });
os << ')';
}
@ -1576,20 +1573,20 @@ void IntegerSet::print(raw_ostream &os) const {
ModulePrinter(os, state).printIntegerSet(*this);
}
void SSAValue::print(raw_ostream &os) const {
void Value::print(raw_ostream &os) const {
switch (getKind()) {
case SSAValueKind::BlockArgument:
case Value::Kind::BlockArgument:
// TODO: Improve this.
os << "<block argument>\n";
return;
case SSAValueKind::StmtResult:
case Value::Kind::StmtResult:
return getDefiningStmt()->print(os);
case SSAValueKind::ForStmt:
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->print(os);
}
}
void SSAValue::dump() const { print(llvm::errs()); }
void Value::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
auto *function = getFunction();

View File

@ -281,7 +281,7 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
// If we are supposed to insert before a specific block, do so, otherwise add
// the block to the end of the function.
if (insertBefore)
function->getBlocks().insert(CFGFunction::iterator(insertBefore), b);
function->getBlocks().insert(Function::iterator(insertBefore), b);
else
function->push_back(b);
@ -291,16 +291,9 @@ BasicBlock *CFGFuncBuilder::createBlock(BasicBlock *insertBefore) {
/// Create an operation given the fields represented as an OperationState.
OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
SmallVector<CFGValue *, 8> operands;
operands.reserve(state.operands.size());
// Allow null operands as they act as sentinal barriers between successor
// operand lists.
for (auto elt : state.operands)
operands.push_back(cast_or_null<CFGValue>(elt));
auto *op =
OperationInst::create(state.location, state.name, operands, state.types,
state.attributes, state.successors, context);
auto *op = OperationInst::create(state.location, state.name, state.operands,
state.types, state.attributes,
state.successors, context);
block->getStatements().insert(insertPoint, op);
return op;
}
@ -311,23 +304,17 @@ OperationStmt *CFGFuncBuilder::createOperation(const OperationState &state) {
/// Create an operation given the fields represented as an OperationState.
OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
SmallVector<MLValue *, 8> operands;
operands.reserve(state.operands.size());
for (auto elt : state.operands)
operands.push_back(cast<MLValue>(elt));
auto *op =
OperationStmt::create(state.location, state.name, operands, state.types,
state.attributes, state.successors, context);
auto *op = OperationStmt::create(state.location, state.name, state.operands,
state.types, state.attributes,
state.successors, context);
block->getStatements().insert(insertPoint, op);
return op;
}
ForStmt *MLFuncBuilder::createFor(Location location,
ArrayRef<MLValue *> lbOperands,
AffineMap lbMap,
ArrayRef<MLValue *> ubOperands,
AffineMap ubMap, int64_t step) {
ArrayRef<Value *> lbOperands, AffineMap lbMap,
ArrayRef<Value *> ubOperands, AffineMap ubMap,
int64_t step) {
auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getStatements().insert(insertPoint, stmt);
@ -341,7 +328,7 @@ ForStmt *MLFuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
return createFor(location, {}, lbMap, {}, ubMap, step);
}
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<MLValue *> operands,
IfStmt *MLFuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
auto *stmt = IfStmt::create(location, operands, set);
block->getStatements().insert(insertPoint, stmt);
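As a usage sketch of the Value-based builder API (buildLoopNest is a hypothetical name; the integer-bound createFor overload and its step argument are taken from the surrounding code):
// Hypothetical sketch: build 'for %i = 0 to 128 step 2'. No MLValue casts are
// needed, and the ForStmt itself can be used as a Value (the loop induction
// variable) after this change.
static void buildLoopNest(MLFuncBuilder &b, Location loc) {
  ForStmt *loop = b.createFor(loc, /*lb=*/0, /*ub=*/128, /*step=*/2);
  Value *iv = loop; // ForStmt is-a Value.
  (void)iv;
}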

View File

@ -20,8 +20,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
@ -54,7 +54,7 @@ void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<SSAValue *, 4> &operands,
SmallVector<Value *, 4> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
@ -76,7 +76,7 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
//===----------------------------------------------------------------------===//
void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap map, ArrayRef<SSAValue *> operands) {
AffineMap map, ArrayRef<Value *> operands) {
result->addOperands(operands);
result->types.append(map.getNumResults(), builder->getIndexType());
result->addAttribute("map", builder->getAffineMapAttr(map));
@ -133,24 +133,22 @@ bool AffineApplyOp::verify() const {
}
// The result of the affine apply operation can be used as a dimension id if it
// is a CFG value or if it is an MLValue, and all the operands are valid
// is a CFG value or if it is a Value, and all the operands are valid
// dimension ids.
bool AffineApplyOp::isValidDim() const {
for (auto *op : getOperands()) {
if (auto *v = dyn_cast<MLValue>(op))
if (!v->isValidDim())
return false;
if (!op->isValidDim())
return false;
}
return true;
}
// The result of the affine apply operation can be used as a symbol if it is
// a CFG value or if it is an MLValue, and all the operands are symbols.
// a CFG value or if it is a Value, and all the operands are symbols.
bool AffineApplyOp::isValidSymbol() const {
for (auto *op : getOperands()) {
if (auto *v = dyn_cast<MLValue>(op))
if (!v->isValidSymbol())
return false;
if (!op->isValidSymbol())
return false;
}
return true;
}
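The same validity checks are now available on any Value, so generic code no longer needs a dyn_cast to MLValue first. A small sketch (allOperandsAreSymbols is a hypothetical helper; getNumOperands, getOperand, and isValidSymbol are used as shown above):
// Hypothetical helper: true if every operand of 'op' may be used as a symbol
// in affine constructs, mirroring the per-operand check above.
static bool allOperandsAreSymbols(Operation *op) {
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    if (!op->getOperand(i)->isValidSymbol())
      return false;
  return true;
}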
@ -170,13 +168,13 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
//===----------------------------------------------------------------------===//
void BranchOp::build(Builder *builder, OperationState *result, BasicBlock *dest,
ArrayRef<SSAValue *> operands) {
ArrayRef<Value *> operands) {
result->addSuccessor(dest, operands);
}
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
BasicBlock *dest;
SmallVector<SSAValue *, 4> destOperands;
SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
result->addSuccessor(dest, destOperands);
@ -212,17 +210,16 @@ void BranchOp::eraseOperand(unsigned index) {
//===----------------------------------------------------------------------===//
void CondBranchOp::build(Builder *builder, OperationState *result,
SSAValue *condition, BasicBlock *trueDest,
ArrayRef<SSAValue *> trueOperands,
BasicBlock *falseDest,
ArrayRef<SSAValue *> falseOperands) {
Value *condition, BasicBlock *trueDest,
ArrayRef<Value *> trueOperands, BasicBlock *falseDest,
ArrayRef<Value *> falseOperands) {
result->addOperands(condition);
result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands);
}
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<SSAValue *, 4> destOperands;
SmallVector<Value *, 4> destOperands;
BasicBlock *dest;
OpAsmParser::OperandType condInfo;
@ -446,7 +443,7 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result,
//===----------------------------------------------------------------------===//
void ReturnOp::build(Builder *builder, OperationState *result,
ArrayRef<SSAValue *> results) {
ArrayRef<Value *> results) {
result->addOperands(results);
}
@ -465,9 +462,10 @@ void ReturnOp::print(OpAsmPrinter *p) const {
*p << ' ';
p->printOperands(operand_begin(), operand_end());
*p << " : ";
interleave(operand_begin(), operand_end(),
[&](const SSAValue *e) { p->printType(e->getType()); },
[&]() { *p << ", "; });
interleave(
operand_begin(), operand_end(),
[&](const Value *e) { p->printType(e->getType()); },
[&]() { *p << ", "; });
}
}

View File

@ -23,7 +23,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is
@ -96,13 +95,13 @@ unsigned Operation::getNumOperands() const {
return llvm::cast<OperationStmt>(this)->getNumOperands();
}
SSAValue *Operation::getOperand(unsigned idx) {
Value *Operation::getOperand(unsigned idx) {
return llvm::cast<OperationStmt>(this)->getOperand(idx);
}
void Operation::setOperand(unsigned idx, SSAValue *value) {
void Operation::setOperand(unsigned idx, Value *value) {
auto *stmt = llvm::cast<OperationStmt>(this);
stmt->setOperand(idx, llvm::cast<MLValue>(value));
stmt->setOperand(idx, value);
}
/// Return the number of results this operation has.
@ -111,7 +110,7 @@ unsigned Operation::getNumResults() const {
}
/// Return the indicated result.
SSAValue *Operation::getResult(unsigned idx) {
Value *Operation::getResult(unsigned idx) {
return llvm::cast<OperationStmt>(this)->getResult(idx);
}
@ -585,8 +584,8 @@ bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
void impl::buildBinaryOp(Builder *builder, OperationState *result,
SSAValue *lhs, SSAValue *rhs) {
void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
Value *rhs) {
assert(lhs->getType() == rhs->getType());
result->addOperands({lhs, rhs});
result->types.push_back(lhs->getType());
@ -613,8 +612,8 @@ void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
// CastOp implementation
//===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result,
SSAValue *source, Type destType) {
void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
Type destType) {
result->addOperands(source);
result->addTypes(destType);
}

View File

@ -16,8 +16,8 @@
// =============================================================================
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Value.h"
using namespace mlir;
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
@ -77,8 +77,8 @@ PatternRewriter::~PatternRewriter() {
/// clients can specify a list of other nodes that this replacement may make
/// (perhaps transitively) dead. If any of those ops are dead, this will
/// remove them as well.
void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
ArrayRef<SSAValue *> valuesToRemoveIfDead) {
void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
@ -97,15 +97,14 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<SSAValue *> newValues,
/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
Operation *op, Operation *newOp,
ArrayRef<SSAValue *> valuesToRemoveIfDead) {
Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
assert(op->getNumResults() == newOp->getNumResults() &&
"replacement op doesn't match results of original op");
if (op->getNumResults() == 1)
return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
SmallVector<SSAValue *, 8> newResults(newOp->getResults().begin(),
newOp->getResults().end());
SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
newOp->getResults().end());
return replaceOp(op, newResults, valuesToRemoveIfDead);
}
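A usage sketch of the updated signature (replaceWithExisting is hypothetical; the single-value form relies on the implicit ArrayRef conversion, as in the call just above):
// Hypothetical sketch: replace a single-result operation with a value that
// already exists, and let the rewriter erase the op along with any producers
// that become dead.
static void replaceWithExisting(PatternRewriter &rewriter, Operation *op,
                                Value *existing) {
  rewriter.replaceOp(op, existing, /*valuesToRemoveIfDead=*/{});
}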
@ -118,7 +117,7 @@ void PatternRewriter::replaceOpWithResultsOfAnotherOp(
/// should remove if they are dead at this point.
///
void PatternRewriter::updatedRootInPlace(
Operation *op, ArrayRef<SSAValue *> valuesToRemoveIfDead) {
Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootUpdated(op);

View File

@ -1,4 +1,4 @@
//===- SSAValue.cpp - MLIR SSAValue Classes ------------===//
//===- SSAValue.cpp - MLIR Value Classes ------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,15 +15,15 @@
// limitations under the License.
// =============================================================================
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Value.h"
using namespace mlir;
/// If this value is the result of an Instruction, return the instruction
/// that defines it.
OperationInst *SSAValue::getDefiningInst() {
OperationInst *Value::getDefiningInst() {
if (auto *result = dyn_cast<InstResult>(this))
return result->getOwner();
return nullptr;
@ -31,13 +31,13 @@ OperationInst *SSAValue::getDefiningInst() {
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *SSAValue::getDefiningStmt() {
OperationStmt *Value::getDefiningStmt() {
if (auto *result = dyn_cast<StmtResult>(this))
return result->getOwner();
return nullptr;
}
Operation *SSAValue::getDefiningOperation() {
Operation *Value::getDefiningOperation() {
if (auto *inst = getDefiningInst())
return inst;
if (auto *stmt = getDefiningStmt())
@ -45,14 +45,14 @@ Operation *SSAValue::getDefiningOperation() {
return nullptr;
}
/// Return the function that this SSAValue is defined in.
Function *SSAValue::getFunction() {
/// Return the function that this Value is defined in.
Function *Value::getFunction() {
switch (getKind()) {
case SSAValueKind::BlockArgument:
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult:
case Value::Kind::StmtResult:
return getDefiningStmt()->getFunction();
case SSAValueKind::ForStmt:
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->getFunction();
}
}
@ -89,15 +89,6 @@ MLIRContext *IROperandOwner::getContext() const {
}
}
//===----------------------------------------------------------------------===//
// MLValue implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this MLValue is defined in.
MLFunction *MLValue::getFunction() {
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
}
//===----------------------------------------------------------------------===//
// BlockArgument implementation.
//===----------------------------------------------------------------------===//

View File

@ -85,18 +85,16 @@ MLFunction *Statement::getFunction() const {
return block ? block->getFunction() : nullptr;
}
MLValue *Statement::getOperand(unsigned idx) {
Value *Statement::getOperand(unsigned idx) { return getStmtOperand(idx).get(); }
const Value *Statement::getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
const MLValue *Statement::getOperand(unsigned idx) const {
return getStmtOperand(idx).get();
}
// MLValue can be used as a dimension id if it is valid as a symbol, or
// A Value can be used as a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of an affine apply operation
// with dimension id arguments.
bool MLValue::isValidDim() const {
bool Value::isValidDim() const {
if (auto *stmt = getDefiningStmt()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@ -111,10 +109,10 @@ bool MLValue::isValidDim() const {
return true;
}
// MLValue can be used as a symbol if it is a constant, or it is defined at
// A Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of an affine apply operation with symbol
// arguments.
bool MLValue::isValidSymbol() const {
bool Value::isValidSymbol() const {
if (auto *stmt = getDefiningStmt()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
@ -129,7 +127,7 @@ bool MLValue::isValidSymbol() const {
return isa<BlockArgument>(this);
}
void Statement::setOperand(unsigned idx, MLValue *value) {
void Statement::setOperand(unsigned idx, Value *value) {
getStmtOperand(idx).set(value);
}
@ -271,7 +269,7 @@ void Statement::dropAllReferences() {
/// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location location, OperationName name,
ArrayRef<MLValue *> operands,
ArrayRef<Value *> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
ArrayRef<StmtBlock *> successors,
@ -420,8 +418,8 @@ void OperationInst::eraseOperand(unsigned index) {
// ForStmt
//===----------------------------------------------------------------------===//
ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
assert(lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map");
@ -444,9 +442,9 @@ ForStmt *ForStmt::create(Location location, ArrayRef<MLValue *> lbOperands,
ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
: Statement(Statement::Kind::For, location),
Value(Value::Kind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
// The body of a for stmt always has one block.
@ -462,11 +460,11 @@ const AffineBound ForStmt::getUpperBound() const {
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
SmallVector<Value *, 4> ubOperands(getUpperBoundOperands());
operands.clear();
operands.reserve(lbOperands.size() + ubMap.getNumInputs());
@ -479,11 +477,11 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
this->lbMap = map;
}
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
SmallVector<Value *, 4> lbOperands(getLowerBoundOperands());
operands.clear();
operands.reserve(lbOperands.size() + ubOperands.size());
@ -553,7 +551,7 @@ bool ForStmt::matchingBoundOperandList() const {
unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare MLValue *'s.
// Compare Value *'s.
if (getOperand(i) != getOperand(numOperands + i))
return false;
}
@ -581,7 +579,7 @@ IfStmt::~IfStmt() {
// allocated through MLIRContext's bump pointer allocator.
}
IfStmt *IfStmt::create(Location location, ArrayRef<MLValue *> operands,
IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
unsigned numOperands = operands.size();
assert(numOperands == set.getNumOperands() &&
@ -617,16 +615,16 @@ MLIRContext *IfStmt::getContext() const {
/// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map.
Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
MLIRContext *context) const {
// If the specified value is in operandMap, return the remapped value.
// Otherwise return the value itself.
auto remapOperand = [&](const MLValue *value) -> MLValue * {
auto remapOperand = [&](const Value *value) -> Value * {
auto it = operandMap.find(value);
return it != operandMap.end() ? it->second : const_cast<MLValue *>(value);
return it != operandMap.end() ? it->second : const_cast<Value *>(value);
};
SmallVector<MLValue *, 8> operands;
SmallVector<Value *, 8> operands;
SmallVector<StmtBlock *, 2> successors;
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
@ -683,10 +681,9 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
auto ubMap = forStmt->getUpperBoundMap();
auto *newFor = ForStmt::create(
getLoc(),
ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
forStmt->getStep());
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
ubMap, forStmt->getStep());
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
@ -716,6 +713,6 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
}
Statement *Statement::clone(MLIRContext *context) const {
DenseMap<const MLValue *, MLValue *> operandMap;
DenseMap<const Value *, Value *> operandMap;
return clone(operandMap, context);
}
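A usage sketch of clone() with an operand map (cloneAndRemap is a hypothetical helper): uses of a chosen value are redirected in the clone, while operands absent from the map keep referring to the original values.
// Hypothetical sketch: clone 'stmt', substituting 'newValue' for every use of
// 'oldValue' in the cloned statement.
static Statement *cloneAndRemap(const Statement &stmt, const Value *oldValue,
                                Value *newValue, MLIRContext *context) {
  DenseMap<const Value *, Value *> operandMap;
  operandMap[oldValue] = newValue;
  return stmt.clone(operandMap, context);
}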

View File

@ -42,7 +42,6 @@
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include <algorithm>
using namespace mlir;
using llvm::MemoryBuffer;
using llvm::SMLoc;
@ -1890,10 +1889,10 @@ public:
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
/// Register a definition of a value with the symbol table.
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
// SSA parsing productions.
ParseResult parseSSAUse(SSAUseInfo &result);
@ -1903,9 +1902,9 @@ public:
ResultType parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type)> &action);
SSAValue *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>(
[&](SSAUseInfo useInfo, Type type) -> SSAValue * {
Value *parseSSAUseAndType() {
return parseSSADefOrUseAndType<Value *>(
[&](SSAUseInfo useInfo, Type type) -> Value * {
return resolveSSAUse(useInfo, type);
});
}
@ -1920,9 +1919,8 @@ public:
Operation *parseCustomOperation(const CreateOperationFunction &createOpFunc);
/// Parse a single operation successor and its operand list.
virtual bool
parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands) = 0;
virtual bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<Value *> &operands) = 0;
protected:
FunctionParser(ParserState &state, Kind kind) : Parser(state), kind(kind) {}
@ -1934,24 +1932,23 @@ private:
Kind kind;
/// This keeps track of all of the SSA values we are tracking, indexed by
/// their name. This has one entry per result number.
llvm::StringMap<SmallVector<std::pair<SSAValue *, SMLoc>, 1>> values;
llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
/// These are all of the placeholders we've made along with the location of
/// their first reference, to allow checking for use of undefined values.
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
DenseMap<Value *, SMLoc> forwardReferencePlaceholders;
SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
Value *createForwardReferencePlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference.
bool isForwardReferencePlaceholder(SSAValue *value) {
bool isForwardReferencePlaceholder(Value *value) {
return forwardReferencePlaceholders.count(value);
}
};
} // end anonymous namespace
/// Create and remember a new placeholder for a forward reference.
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
Type type) {
Value *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, Type type) {
// Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain.
//
@ -1969,7 +1966,7 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
/// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure.
SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
Value *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = values[useInfo.name];
// If we have already seen a value of this name, return it.
@ -2010,7 +2007,7 @@ SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
}
/// Register a definition of a value with the symbol table.
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, Value *value) {
auto &entries = values[useInfo.name];
// Make sure there is a slot for this value.
@ -2046,7 +2043,7 @@ ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
// Check for any forward references that are left. If we find any, error
// out.
if (!forwardReferencePlaceholders.empty()) {
SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
SmallVector<std::pair<const char *, Value *>, 4> errors;
// Iteration over the map isn't deterministic, so sort by source location.
for (auto entry : forwardReferencePlaceholders)
errors.push_back({entry.second.getPointer(), entry.first});
@ -2399,9 +2396,8 @@ public:
return false;
}
bool
parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands) override {
bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<Value *> &operands) override {
// Defer successor parsing to the function parsers.
return parser.parseSuccessorAndUseList(dest, operands);
}
@ -2493,7 +2489,7 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) override {
SmallVectorImpl<Value *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
@ -2573,7 +2569,7 @@ public:
ParseResult parseFunctionBody();
bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands);
SmallVectorImpl<Value *> &operands);
private:
CFGFunction *function;
@ -2636,7 +2632,7 @@ private:
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
///
bool CFGFunctionParser::parseSuccessorAndUseList(
BasicBlock *&dest, SmallVectorImpl<SSAValue *> &operands) {
BasicBlock *&dest, SmallVectorImpl<Value *> &operands) {
// Verify the branch target is a bare identifier and get the matching block.
if (!getToken().is(Token::bare_identifier))
return emitError("expected basic block name");
@ -2790,10 +2786,10 @@ private:
ParseResult parseForStmt();
ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfStmt();
ParseResult parseElseClause(StmtBlock *elseClause);
@ -2801,7 +2797,7 @@ private:
ParseResult parseStmtBlock(StmtBlock *block);
bool parseSuccessorAndUseList(BasicBlock *&dest,
SmallVectorImpl<SSAValue *> &operands) {
SmallVectorImpl<Value *> &operands) {
assert(false && "MLFunctions do not have terminators with successors.");
return true;
}
@ -2838,7 +2834,7 @@ ParseResult MLFunctionParser::parseForStmt() {
return ParseFailure;
// Parse lower bound.
SmallVector<MLValue *, 4> lbOperands;
SmallVector<Value *, 4> lbOperands;
AffineMap lbMap;
if (parseBound(lbOperands, lbMap, /*isLower*/ true))
return ParseFailure;
@ -2847,7 +2843,7 @@ ParseResult MLFunctionParser::parseForStmt() {
return ParseFailure;
// Parse upper bound.
SmallVector<MLValue *, 4> ubOperands;
SmallVector<Value *, 4> ubOperands;
AffineMap ubMap;
if (parseBound(ubOperands, ubMap, /*isLower*/ false))
return ParseFailure;
@ -2913,7 +2909,7 @@ ParseResult MLFunctionParser::parseIntConstant(int64_t &val) {
/// dim-and-symbol-use-list ::= dim-use-list symbol-use-list?
///
ParseResult
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName) {
if (parseToken(Token::l_paren, "expected '('"))
@ -2942,18 +2938,17 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
// Resolve SSA uses.
Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
Value *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval)
return ParseFailure;
auto *v = cast<MLValue>(sval);
if (i < numDims && !v->isValidDim())
if (i < numDims && !sval->isValidDim())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a dimension id");
if (i >= numDims && !v->isValidSymbol())
if (i >= numDims && !sval->isValidSymbol())
return emitError(opInfo[i].loc, "value '" + opInfo[i].name.str() +
"' cannot be used as a symbol");
operands.push_back(v);
operands.push_back(sval);
}
return ParseSuccess;
@ -2965,7 +2960,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
///  lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
///  upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
///  shorthand-bound ::= ssa-id | `-`? integer-literal
///
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
AffineMap &map, bool isLower) {
// 'min' / 'max' prefixes are syntactic sugar. Ignore them.
if (isLower)
@ -3003,7 +2998,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
// TODO: improve error message when SSA value is not an affine integer.
// Currently it is 'use of value ... expects different type than prior uses'
if (auto *value = resolveSSAUse(opInfo, builder.getIndexType()))
operands.push_back(cast<MLValue>(value));
operands.push_back(value);
else
return ParseFailure;
@ -3113,7 +3108,7 @@ ParseResult MLFunctionParser::parseIfStmt() {
if (!set)
return ParseFailure;
SmallVector<MLValue *, 4> operands;
SmallVector<Value *, 4> operands;
if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(),
"integer set"))
return ParseFailure;

View File

@ -23,8 +23,8 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
@ -78,8 +78,8 @@ struct MemRefCastFolder : public RewritePattern {
// AddFOp
//===----------------------------------------------------------------------===//
void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs) {
void AddFOp::build(Builder *builder, OperationState *result, Value *lhs,
Value *rhs) {
assert(lhs->getType() == rhs->getType());
result->addOperands({lhs, rhs});
result->types.push_back(lhs->getType());
@ -146,7 +146,7 @@ void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<SSAValue *> operands) {
MemRefType memrefType, ArrayRef<Value *> operands) {
result->addOperands(operands);
result->types.push_back(memrefType);
}
@ -247,8 +247,8 @@ struct SimplifyAllocConst : public RewritePattern {
// and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants;
newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands;
SmallVector<Value *, 4> newOperands;
SmallVector<Value *, 4> droppedOperands;
unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@ -301,7 +301,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) {
ArrayRef<Value *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
@ -370,7 +370,7 @@ bool CallOp::verify() const {
//===----------------------------------------------------------------------===//
void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) {
Value *callee, ArrayRef<Value *> operands) {
auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee);
result->addOperands(operands);
@ -507,7 +507,7 @@ CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
}
void CmpIOp::build(Builder *build, OperationState *result,
CmpIPredicate predicate, SSAValue *lhs, SSAValue *rhs) {
CmpIPredicate predicate, Value *lhs, Value *rhs) {
result->addOperands({lhs, rhs});
result->types.push_back(getI1SameShape(build, lhs->getType()));
result->addAttribute(getPredicateAttrName(),
@ -580,8 +580,7 @@ bool CmpIOp::verify() const {
// DeallocOp
//===----------------------------------------------------------------------===//
void DeallocOp::build(Builder *builder, OperationState *result,
SSAValue *memref) {
void DeallocOp::build(Builder *builder, OperationState *result, Value *memref) {
result->addOperands(memref);
}
@ -615,7 +614,7 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void DimOp::build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index) {
Value *memrefOrTensor, unsigned index) {
result->addOperands(memrefOrTensor);
auto type = builder->getIndexType();
result->addAttribute("index", builder->getIntegerAttr(type, index));
@ -689,11 +688,11 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
// ---------------------------------------------------------------------------
void DmaStartOp::build(Builder *builder, OperationState *result,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
SSAValue *numElements, SSAValue *tagMemRef,
ArrayRef<SSAValue *> tagIndices, SSAValue *stride,
SSAValue *elementsPerStride) {
Value *srcMemRef, ArrayRef<Value *> srcIndices,
Value *destMemRef, ArrayRef<Value *> destIndices,
Value *numElements, Value *tagMemRef,
ArrayRef<Value *> tagIndices, Value *stride,
Value *elementsPerStride) {
result->addOperands(srcMemRef);
result->addOperands(srcIndices);
result->addOperands(destMemRef);
@ -836,8 +835,8 @@ void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// ---------------------------------------------------------------------------
void DmaWaitOp::build(Builder *builder, OperationState *result,
SSAValue *tagMemRef, ArrayRef<SSAValue *> tagIndices,
SSAValue *numElements) {
Value *tagMemRef, ArrayRef<Value *> tagIndices,
Value *numElements) {
result->addOperands(tagMemRef);
result->addOperands(tagIndices);
result->addOperands(numElements);
@ -896,8 +895,7 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate,
ArrayRef<SSAValue *> indices) {
Value *aggregate, ArrayRef<Value *> indices) {
auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate);
result->addOperands(indices);
@ -955,8 +953,8 @@ bool ExtractElementOp::verify() const {
// LoadOp
//===----------------------------------------------------------------------===//
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices) {
void LoadOp::build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<Value *> indices) {
auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref);
result->addOperands(indices);
@ -1130,9 +1128,8 @@ void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SelectOp
//===----------------------------------------------------------------------===//
void SelectOp::build(Builder *builder, OperationState *result,
SSAValue *condition, SSAValue *trueValue,
SSAValue *falseValue) {
void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
Value *trueValue, Value *falseValue) {
result->addOperands({condition, trueValue, falseValue});
result->addTypes(trueValue->getType());
}
@ -1201,8 +1198,8 @@ Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
//===----------------------------------------------------------------------===//
void StoreOp::build(Builder *builder, OperationState *result,
SSAValue *valueToStore, SSAValue *memref,
ArrayRef<SSAValue *> indices) {
Value *valueToStore, Value *memref,
ArrayRef<Value *> indices) {
result->addOperands(valueToStore);
result->addOperands(memref);
result->addOperands(indices);

View File

@ -72,10 +72,10 @@ static bool verifyPermutationMap(AffineMap permutationMap,
}
void VectorTransferReadOp::build(Builder *builder, OperationState *result,
VectorType vectorType, SSAValue *srcMemRef,
ArrayRef<SSAValue *> srcIndices,
VectorType vectorType, Value *srcMemRef,
ArrayRef<Value *> srcIndices,
AffineMap permutationMap,
Optional<SSAValue *> paddingValue) {
Optional<Value *> paddingValue) {
result->addOperands(srcMemRef);
result->addOperands(srcIndices);
if (paddingValue) {
@ -100,21 +100,20 @@ VectorTransferReadOp::getIndices() const {
return {begin, end};
}
Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
Optional<Value *> VectorTransferReadOp::getPaddingValue() {
auto memRefRank = getMemRefType().getRank();
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
return None;
}
return Optional<SSAValue *>(
getOperand(Offsets::FirstIndexOffset + memRefRank));
return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank));
}
Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
Optional<const Value *> VectorTransferReadOp::getPaddingValue() const {
auto memRefRank = getMemRefType().getRank();
if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
return None;
}
return Optional<const SSAValue *>(
return Optional<const Value *>(
getOperand(Offsets::FirstIndexOffset + memRefRank));
}
@ -136,7 +135,7 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) const {
// Construct the FunctionType and print it.
llvm::SmallVector<Type, 8> inputs{getMemRefType()};
// Must have at least one actual index, see verify.
const SSAValue *firstIndex = *(getIndices().begin());
const Value *firstIndex = *(getIndices().begin());
Type indexType = firstIndex->getType();
inputs.append(getMemRefType().getRank(), indexType);
if (optionalPaddingValue) {
@ -295,8 +294,8 @@ bool VectorTransferReadOp::verify() const {
// VectorTransferWriteOp
//===----------------------------------------------------------------------===//
void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
SSAValue *srcVector, SSAValue *dstMemRef,
ArrayRef<SSAValue *> dstIndices,
Value *srcVector, Value *dstMemRef,
ArrayRef<Value *> dstIndices,
AffineMap permutationMap) {
result->addOperands({srcVector, dstMemRef});
result->addOperands(dstIndices);
@ -457,7 +456,7 @@ bool VectorTransferWriteOp::verify() const {
// VectorTypeCastOp
//===----------------------------------------------------------------------===//
void VectorTypeCastOp::build(Builder *builder, OperationState *result,
SSAValue *srcVector, Type dstType) {
Value *srcVector, Type dstType) {
result->addOperands(srcVector);
result->addTypes(dstType);
}

View File

@ -111,7 +111,7 @@ private:
/// descriptor and get the pointer to the element indexed by the linearized
/// subscript. Return nullptr on errors.
llvm::Value *emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op,
const Value *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices);
/// Emit LLVM IR corresponding to the given Alloc `op`. In particular, create
@ -136,12 +136,12 @@ private:
/// Create a single LLVM value of struct type that includes the list of
/// given MLIR values. The `values` list must contain at least 2 elements.
llvm::Value *packValues(ArrayRef<const SSAValue *> values);
llvm::Value *packValues(ArrayRef<const Value *> values);
/// Extract a list of `num` LLVM values from a `value` of struct type.
SmallVector<llvm::Value *, 4> unpackValues(llvm::Value *value, unsigned num);
llvm::DenseMap<const Function *, llvm::Function *> functionMapping;
llvm::DenseMap<const SSAValue *, llvm::Value *> valueMapping;
llvm::DenseMap<const Value *, llvm::Value *> valueMapping;
llvm::DenseMap<const BasicBlock *, llvm::BasicBlock *> blockMapping;
llvm::LLVMContext &llvmContext;
llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
@ -316,7 +316,7 @@ static bool checkSupportedMemRefType(MemRefType type, const Operation &op) {
}
llvm::Value *ModuleLowerer::emitMemRefElementAccess(
const SSAValue *memRef, const Operation &op,
const Value *memRef, const Operation &op,
llvm::iterator_range<Operation::const_operand_iterator> opIndices) {
auto type = memRef->getType().dyn_cast<MemRefType>();
assert(type && "expected memRef value to have a MemRef type");
@ -340,7 +340,7 @@ llvm::Value *ModuleLowerer::emitMemRefElementAccess(
// Obtain the list of access subscripts as values and linearize it given the
// list of sizes.
auto indices = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); },
[this](const Value *value) { return valueMapping.lookup(value); },
opIndices);
auto subscript = linearizeSubscripts(indices, sizes);
@ -460,11 +460,11 @@ llvm::Value *ModuleLowerer::emitConstantSplat(const ConstantOp &op) {
}
// Create an undef struct value and insert individual values into it.
llvm::Value *ModuleLowerer::packValues(ArrayRef<const SSAValue *> values) {
llvm::Value *ModuleLowerer::packValues(ArrayRef<const Value *> values) {
assert(values.size() > 1 && "cannot pack less than 2 values");
auto types =
functional::map([](const SSAValue *v) { return v->getType(); }, values);
functional::map([](const Value *v) { return v->getType(); }, values);
llvm::Type *packedType = getPackedResultType(types);
llvm::Value *packed = llvm::UndefValue::get(packedType);
@ -641,7 +641,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
return false;
}
if (auto dimOp = inst.dyn_cast<DimOp>()) {
const SSAValue *container = dimOp->getOperand();
const Value *container = dimOp->getOperand();
MemRefType type = container->getType().dyn_cast<MemRefType>();
if (!type)
return dimOp->emitError("only memref types are supported");
@ -672,7 +672,7 @@ bool ModuleLowerer::convertInstruction(const OperationInst &inst) {
if (auto callOp = inst.dyn_cast<CallOp>()) {
auto operands = functional::map(
[this](const SSAValue *value) { return valueMapping.lookup(value); },
[this](const Value *value) { return valueMapping.lookup(value); },
callOp->getOperands());
auto numResults = callOp->getNumResults();
llvm::Value *result =
@ -779,10 +779,9 @@ bool ModuleLowerer::convertBasicBlock(const BasicBlock &bb,
// Get the SSA value passed to the current block from the terminator instruction
// of its predecessor.
static const SSAValue *getPHISourceValue(const BasicBlock *current,
const BasicBlock *pred,
unsigned numArguments,
unsigned index) {
static const Value *getPHISourceValue(const BasicBlock *current,
const BasicBlock *pred,
unsigned numArguments, unsigned index) {
auto &terminator = *pred->getTerminator();
if (terminator.isa<BranchOp>()) {
return terminator.getOperand(index);

View File

@ -30,13 +30,12 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding.
SmallVector<SSAValue *, 8> existingConstants;
SmallVector<Value *, 8> existingConstants;
// Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
bool foldOperation(Operation *op, SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory);
void visitOperationStmt(OperationStmt *stmt);
void visitForStmt(ForStmt *stmt);
@ -54,9 +53,8 @@ char ConstantFold::passID = 0;
///
/// This returns false if the operation was successfully folded.
bool ConstantFold::foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory) {
// If this operation is already a constant, just remember it for cleanup
// later, and don't try to fold it.
if (auto constant = op->dyn_cast<ConstantOp>()) {
@ -114,7 +112,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
if (!inst)
continue;
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
builder.setInsertionPoint(inst);
return builder.create<ConstantOp>(inst->getLoc(), value, type);
};
@ -142,7 +140,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
};

View File

@ -50,28 +50,28 @@ public:
void visitOperationStmt(OperationStmt *opStmt);
private:
CFGValue *getConstantIndexValue(int64_t value);
Value *getConstantIndexValue(int64_t value);
void visitStmtBlock(StmtBlock *stmtBlock);
CFGValue *buildMinMaxReductionSeq(
Value *buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate,
llvm::iterator_range<Operation::result_iterator> values);
CFGFunction *cfgFunc;
CFGFuncBuilder builder;
// Mapping between original MLValues and lowered CFGValues.
llvm::DenseMap<const MLValue *, CFGValue *> valueRemapping;
// Mapping between original Values and lowered Values.
llvm::DenseMap<const Value *, Value *> valueRemapping;
};
} // end anonymous namespace
// Return a vector of OperationStmt's arguments as SSAValues. For each
// statement operand, represented as MLValue, look up its CFGValue counterpart in
// Return a vector of OperationStmt's arguments as Values. For each
// statement operand, represented as a Value, look up its counterpart in
// the valueRemapping table.
static llvm::SmallVector<SSAValue *, 4>
static llvm::SmallVector<mlir::Value *, 4>
operandsAs(Statement *opStmt,
const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) {
llvm::SmallVector<SSAValue *, 4> operands;
for (const MLValue *operand : opStmt->getOperands()) {
const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
llvm::SmallVector<Value *, 4> operands;
for (const Value *operand : opStmt->getOperands()) {
assert(valueRemapping.count(operand) != 0 && "operand is not defined");
operands.push_back(valueRemapping.lookup(operand));
}
@ -81,8 +81,8 @@ operandsAs(Statement *opStmt,
// Convert an operation statement into an operation instruction.
//
// The operation description (name, number and types of operands or results)
// remains the same but the values must be updated to be CFGValues. Update the
// mapping MLValue->CFGValue as the conversion is performed. The operation
// remains the same, but each operand value must be remapped to its
// counterpart in the function under construction. The operation
// instruction is appended to the current block (end of SESE region).
void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
// Set up basic operation state (context, name, operands).
@ -90,11 +90,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
opStmt->getName());
state.addOperands(operandsAs(opStmt, valueRemapping));
// Set up operation return types. The corresponding SSAValues will become
// Set up operation return types. The corresponding Values will become
// available after the operation is created.
state.addTypes(
functional::map([](SSAValue *result) { return result->getType(); },
opStmt->getResults()));
state.addTypes(functional::map(
[](Value *result) { return result->getType(); }, opStmt->getResults()));
// Copy attributes.
for (auto attr : opStmt->getAttrs()) {
@ -112,10 +111,10 @@ void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
}
}
// Create a CFGValue for the given integer constant of index type.
CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) {
// Create a Value for the given integer constant of index type.
Value *FunctionConverter::getConstantIndexValue(int64_t value) {
auto op = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value);
return cast<CFGValue>(op->getResult());
return op->getResult();
}
// Visit all statements in the given statement block.
@ -135,18 +134,18 @@ void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) {
// Multiple values are scanned in a linear sequence. This creates data
// dependences that wouldn't exist in a tree reduction, but is easier to
// recognize as a reduction by the subsequent passes.
CFGValue *FunctionConverter::buildMinMaxReductionSeq(
Value *FunctionConverter::buildMinMaxReductionSeq(
Location loc, CmpIPredicate predicate,
llvm::iterator_range<Operation::result_iterator> values) {
assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin();
CFGValue *value = cast<CFGValue>(*valueIt++);
Value *value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
auto selectOp =
builder.create<SelectOp>(loc, cmpOp->getResult(), value, *valueIt);
value = cast<CFGValue>(selectOp->getResult());
value = selectOp->getResult();
}
return value;
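For reference, one link of that chain in isolation, i.e. a single two-value maximum built with the same ops (buildMax is a hypothetical helper; the operands are assumed to be of integer or index type so CmpIOp applies):
// Hypothetical sketch: max(a, b) as one cmp/select pair, the step that
// buildMinMaxReductionSeq repeats along its linear chain.
static Value *buildMax(CFGFuncBuilder &builder, Location loc, Value *a,
                       Value *b) {
  auto cmp = builder.create<CmpIOp>(loc, CmpIPredicate::SGT, a, b);
  return builder.create<SelectOp>(loc, cmp->getResult(), a, b)->getResult();
}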
@ -231,9 +230,9 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// The loop condition block has an argument for the loop induction variable.
// Create it upfront and make the loop induction variable -> basic block
// argument remapping available to the following instructions. ForStatement
// is-a MLValue corresponding to the loop induction variable.
// is-a Value corresponding to the loop induction variable.
builder.setInsertionPoint(loopConditionBlock);
CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType());
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
valueRemapping.insert(std::make_pair(forStmt, iv));
// Recursively construct loop body region.
@ -251,7 +250,7 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp =
builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult(0));
Value *nextIvValue = stepOp->getResult(0);
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
nextIvValue);
@ -260,20 +259,19 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPoint(loopInitBlock);
// Compute loop bounds using affine_apply after remapping its operands.
auto remapOperands = [this](const SSAValue *value) -> SSAValue * {
const MLValue *mlValue = dyn_cast<MLValue>(value);
return valueRemapping.lookup(mlValue);
auto remapOperands = [this](const Value *value) -> Value * {
return valueRemapping.lookup(value);
};
auto operands =
functional::map(remapOperands, forStmt->getLowerBoundOperands());
auto lbAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
CFGValue *lowerBound = buildMinMaxReductionSeq(
Value *lowerBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
auto ubAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
CFGValue *upperBound = buildMinMaxReductionSeq(
Value *upperBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound);
@ -281,10 +279,10 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPoint(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = cast<CFGValue>(comparisonOp->getResult());
auto comparisonResult = comparisonOp->getResult();
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, ArrayRef<SSAValue *>(),
postLoopBlock, ArrayRef<SSAValue *>());
loopBodyFirstBlock, ArrayRef<Value *>(),
postLoopBlock, ArrayRef<Value *>());
// Finally, make sure building can continue by setting the post-loop block
// (end of loop SESE region) as the insertion point.
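As an illustration of the semantics the generated blocks implement (plain C++ over integers, not the actual lowering; forStmtSemantics is hypothetical and assumes non-empty bound lists): the lower bound is the maximum over the lower-bound map results, the upper bound is the minimum over the upper-bound map results, and the loop runs while iv < ub.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical model of the lowered loop: max-reduce the lower bounds,
// min-reduce the upper bounds, then iterate with the given step.
static void forStmtSemantics(const std::vector<int64_t> &lbs,
                             const std::vector<int64_t> &ubs, int64_t step,
                             const std::function<void(int64_t)> &body) {
  int64_t lb = *std::max_element(lbs.begin(), lbs.end());
  int64_t ub = *std::min_element(ubs.begin(), ubs.end());
  for (int64_t iv = lb; iv < ub; iv += step)
    body(iv);
}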
@ -401,7 +399,7 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// If the test succeeds, jump to the next block, which tests the next
// conjunct of the condition in a similar way. When all conjuncts have been
// handled, jump to the 'then' block instead.
SSAValue *zeroConstant = getConstantIndexValue(0);
Value *zeroConstant = getConstantIndexValue(0);
ifConditionExtraBlocks.push_back(thenBlock);
for (auto tuple :
llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(),
@ -416,16 +414,16 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
integerSet.getNumSymbols(), constraintExpr, {});
auto affineApplyOp = builder.create<AffineApplyOp>(
ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
SSAValue *affResult = affineApplyOp->getResult(0);
Value *affResult = affineApplyOp->getResult(0);
// Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>(
ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
affResult, zeroConstant);
builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
nextBlock, /*trueArgs*/ ArrayRef<SSAValue *>(),
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock,
/*falseArgs*/ ArrayRef<SSAValue *>());
/*falseArgs*/ ArrayRef<Value *>());
builder.setInsertionPoint(nextBlock);
}
ifConditionExtraBlocks.pop_back();
@ -468,10 +466,10 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// of the current region. The SESE invariant allows us to easily handle nested
// structures of arbitrary complexity.
//
// During the conversion, we maintain a mapping between the MLValues present in
// the original function and their CFGValue images in the function under
// construction. When an MLValue is used, it gets replaced with the
// corresponding CFGValue that has been defined previously. The value flow
// During the conversion, we maintain a mapping between the Values present in
// the original function and their Value images in the function under
// construction. When a Value is used, it gets replaced with the
// corresponding Value that has been defined previously. The value flow
// starts with function arguments converted to basic block arguments.
CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
auto outerBlock = builder.createBlock();
@ -482,8 +480,8 @@ CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
outerBlock->addArguments(mlFunc->getType().getInputs());
assert(mlFunc->getNumArguments() == outerBlock->getNumArguments());
for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) {
const MLValue *mlArgument = mlFunc->getArgument(i);
CFGValue *cfgArgument = outerBlock->getArgument(i);
const Value *mlArgument = mlFunc->getArgument(i);
Value *cfgArgument = outerBlock->getArgument(i);
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
}

View File

@ -76,7 +76,7 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Map from original memref's to the DMA buffers that their accesses are
// replaced with.
DenseMap<SSAValue *, SSAValue *> fastBufferMap;
DenseMap<Value *, Value *> fastBufferMap;
// Slow memory space associated with DMAs.
const unsigned slowMemorySpace;
@ -195,11 +195,11 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// Indices to use for the DmaStart op.
// Indices for the original memref being DMAed from/to.
SmallVector<SSAValue *, 4> memIndices;
SmallVector<Value *, 4> memIndices;
// Indices for the faster buffer being DMAed into/from.
SmallVector<SSAValue *, 4> bufIndices;
SmallVector<Value *, 4> bufIndices;
SSAValue *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
Value *zeroIndex = top.create<ConstantIndexOp>(loc, 0);
unsigned rank = memRefType.getRank();
SmallVector<int, 4> fastBufferShape;
@ -226,10 +226,10 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// DMA generation is being done.
const FlatAffineConstraints *cst = region.getConstraints();
auto ids = cst->getIds();
SmallVector<SSAValue *, 8> outerIVs;
SmallVector<Value *, 8> outerIVs;
for (unsigned i = rank, e = ids.size(); i < e; i++) {
auto id = cst->getIds()[i];
assert(id.hasValue() && "MLValue id expected");
assert(id.hasValue() && "Value id expected");
outerIVs.push_back(id.getValue());
}
@ -253,15 +253,15 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
// Set DMA start location for this dimension in the lower memory space
// memref.
if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
memIndices.push_back(cast<MLValue>(
top.create<ConstantIndexOp>(loc, caf.getValue())->getResult()));
memIndices.push_back(
top.create<ConstantIndexOp>(loc, caf.getValue())->getResult());
} else {
// The coordinate for the start location is just the lower bound along the
// corresponding dimension on the memory region (stored in 'offset').
auto map = top.getAffineMap(
cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {});
memIndices.push_back(cast<MLValue>(
b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0)));
memIndices.push_back(
b->create<AffineApplyOp>(loc, map, outerIVs)->getResult(0));
}
// The fast buffer is DMAed into at location zero; addressing is relative.
bufIndices.push_back(zeroIndex);
@ -272,7 +272,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
}
// The faster memory space buffer.
SSAValue *fastMemRef;
Value *fastMemRef;
// Check if a buffer was already created.
// TODO(bondhugula): union across all memory ops per buffer. For now assuming
@ -321,8 +321,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
return false;
}
SSAValue *stride = nullptr;
SSAValue *numEltPerStride = nullptr;
Value *stride = nullptr;
Value *numEltPerStride = nullptr;
if (!strideInfos.empty()) {
stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
numEltPerStride =
@ -362,7 +362,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
// *Only* those uses within the body of 'forStmt' are replaced.
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*domStmtFilter=*/&*forStmt->getBody()->begin());
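A hedged sketch of the buffer-reuse path hinted at by the 'Check if a buffer was already created' comment above; the lookup shape is an assumption, and only 'fastBufferMap', 'memref' and 'fastMemRef' are names taken from this diff:

auto existing = fastBufferMap.find(memref);
if (existing != fastBufferMap.end()) {
  // Reuse the fast-space buffer already created for this memref.
  fastMemRef = existing->second;
} else {
  // ... allocate the fast-space buffer, generate the DMA, then record it ...
  fastBufferMap[memref] = fastMemRef;
}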

View File

@ -83,22 +83,22 @@ FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
static void getSingleMemRefAccess(OperationStmt *loadOrStoreOpStmt,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->memref = cast<MLValue>(loadOp->getMemRef());
access->memref = loadOp->getMemRef();
access->opStmt = loadOrStoreOpStmt;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index));
access->indices.push_back(index);
}
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->opStmt = loadOrStoreOpStmt;
access->memref = cast<MLValue>(storeOp->getMemRef());
access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
for (auto *index : storeOp->getIndices()) {
access->indices.push_back(cast<MLValue>(index));
access->indices.push_back(index);
}
}
}
@ -178,20 +178,20 @@ public:
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
// Returns the load op count for 'memref'.
unsigned getLoadOpCount(MLValue *memref) {
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
for (auto *loadOpStmt : loads) {
if (memref == cast<MLValue>(loadOpStmt->cast<LoadOp>()->getMemRef()))
if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
++loadOpCount;
}
return loadOpCount;
}
// Returns the store op count for 'memref'.
unsigned getStoreOpCount(MLValue *memref) {
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
for (auto *storeOpStmt : stores) {
if (memref == cast<MLValue>(storeOpStmt->cast<StoreOp>()->getMemRef()))
if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
++storeOpCount;
}
return storeOpCount;
@ -203,7 +203,7 @@ public:
// The id of the node at the other end of the edge.
unsigned id;
// The memref on which this edge represents a dependence.
MLValue *memref;
Value *memref;
};
// Map from node id to Node.
@ -227,13 +227,13 @@ public:
}
// Adds an edge from node 'srcId' to node 'dstId' for 'memref'.
void addEdge(unsigned srcId, unsigned dstId, MLValue *memref) {
void addEdge(unsigned srcId, unsigned dstId, Value *memref) {
outEdges[srcId].push_back({dstId, memref});
inEdges[dstId].push_back({srcId, memref});
}
// Removes an edge from node 'srcId' to node 'dstId' for 'memref'.
void removeEdge(unsigned srcId, unsigned dstId, MLValue *memref) {
void removeEdge(unsigned srcId, unsigned dstId, Value *memref) {
assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0);
// Remove 'srcId' from 'inEdges[dstId]'.
@ -253,7 +253,7 @@ public:
}
// Returns the input edge count for node 'id' and 'memref'.
unsigned getInEdgeCount(unsigned id, MLValue *memref) {
unsigned getInEdgeCount(unsigned id, Value *memref) {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id])
@ -263,7 +263,7 @@ public:
}
// Returns the output edge count for node 'id' and 'memref'.
unsigned getOutEdgeCount(unsigned id, MLValue *memref) {
unsigned getOutEdgeCount(unsigned id, Value *memref) {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
@ -347,7 +347,7 @@ public:
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(MLFunction *f) {
unsigned id = 0;
DenseMap<MLValue *, SetVector<unsigned>> memrefAccesses;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record
@ -360,12 +360,12 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
Node node(id++, &stmt);
for (auto *opStmt : collector.loadOpStmts) {
node.loads.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef());
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opStmt : collector.storeOpStmts) {
node.stores.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef());
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
nodes.insert({node.id, node});
@ -375,7 +375,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
// Create graph node for top-level load op.
Node node(id++, &stmt);
node.loads.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<LoadOp>()->getMemRef());
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
@ -383,7 +383,7 @@ bool MemRefDependenceGraph::init(MLFunction *f) {
// Create graph node for top-level store op.
Node node(id++, &stmt);
node.stores.push_back(opStmt);
auto *memref = cast<MLValue>(opStmt->cast<StoreOp>()->getMemRef());
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
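A rough sketch of how 'memrefAccesses' can be turned into dependence edges once all nodes exist; this is an assumption about the elided remainder of init(), not code from this hunk, and a real implementation would restrict edges to pairs involving at least one store:

for (auto &memrefAndNodeIds : memrefAccesses) {
  Value *memref = memrefAndNodeIds.first;
  auto &nodeIds = memrefAndNodeIds.second;
  for (unsigned i = 0, e = nodeIds.size(); i < e; ++i)
    for (unsigned j = i + 1; j < e; ++j)
      addEdge(/*srcId=*/nodeIds[i], /*dstId=*/nodeIds[j], memref);
}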
@ -477,8 +477,7 @@ public:
SmallVector<OperationStmt *, 4> loads = dstNode->loads;
while (!loads.empty()) {
auto *dstLoadOpStmt = loads.pop_back_val();
auto *memref =
cast<MLValue>(dstLoadOpStmt->cast<LoadOp>()->getMemRef());
auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;

View File

@ -85,10 +85,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
for (unsigned i = 0; i < width; i++) {
auto lbOperands = origLoops[i]->getLowerBoundOperands();
auto ubOperands = origLoops[i]->getUpperBoundOperands();
SmallVector<MLValue *, 4> newLbOperands(lbOperands.begin(),
lbOperands.end());
SmallVector<MLValue *, 4> newUbOperands(ubOperands.begin(),
ubOperands.end());
SmallVector<Value *, 4> newLbOperands(lbOperands.begin(), lbOperands.end());
SmallVector<Value *, 4> newUbOperands(ubOperands.begin(), ubOperands.end());
newLoops[i]->setLowerBound(newLbOperands, origLoops[i]->getLowerBoundMap());
newLoops[i]->setUpperBound(newUbOperands, origLoops[i]->getUpperBoundMap());
newLoops[i]->setStep(tileSizes[i]);
@ -112,8 +110,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
// Construct the upper bound map; the operands are the original operands
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
SmallVector<MLValue *, 4> ubOperands(
origLoops[i]->getUpperBoundOperands());
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
ubOperands.push_back(newLoops[i]);
auto origUbMap = origLoops[i]->getUpperBoundMap();
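The construction sketched below simply restates the comment above in code; the builder name 'b', the 'context' variable, the absence of symbolic bound operands, and indexing the intra-tile loop as newLoops[width + i] are all assumptions. The appended dim is the tile-space IV pushed into 'ubOperands':

SmallVector<AffineExpr, 4> ubExprs(origUbMap.getResults().begin(),
                                   origUbMap.getResults().end());
// Append 'i + tileSize' where 'i' is the new trailing dim (the tile-space IV).
ubExprs.push_back(getAffineDimExpr(origUbMap.getNumDims(), context) +
                  tileSizes[i]);
auto tiledUbMap =
    b.getAffineMap(origUbMap.getNumDims() + 1, /*symbolCount=*/0, ubExprs, {});
newLoops[width + i]->setUpperBound(ubOperands, tiledUbMap);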
@ -191,8 +188,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
SmallVector<MLValue *, 6> origLoopIVs(band.begin(), band.end());
SmallVector<Optional<MLValue *>, 6> ids(band.begin(), band.end());
SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
FlatAffineConstraints cst;
getIndexSet(band, &cst);

View File

@ -191,7 +191,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// unrollJamFactor.
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
DenseMap<const Value *, Value *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
MLFuncBuilder builder(forStmt->getBlock(),
std::next(StmtBlock::iterator(forStmt)));
@ -219,7 +219,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Unroll and jam (appends unrollJamFactor-1 additional copies).
for (unsigned i = 1; i < unrollJamFactor; i++) {
DenseMap<const MLValue *, MLValue *> operandMapping;
DenseMap<const Value *, Value *> operandMapping;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
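A sketch of the 'bumpMap' used just below; its construction lies outside this hunk and the exact expression is an assumption: copy i of the unroll-jammed body reads the IV bumped by i times the original per-copy stride, called 'step' here.

auto d0 = getAffineDimExpr(0, builder.getContext());
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});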
@ -230,7 +230,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0);
operandMapping[forStmt] = cast<MLValue>(ivUnroll);
operandMapping[forStmt] = ivUnroll;
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {

View File

@ -29,17 +29,14 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/MLPatternLoweringPass.h"
#include "mlir/Transforms/Passes.h"
@ -62,26 +59,26 @@ using namespace mlir;
#define DEBUG_TYPE "lower-vector-transfers"
/// Creates the SSAValue for the sum of `a` and `b` without building a
/// Creates the Value for the sum of `a` and `b` without building a
/// full-fledged AffineMap for all indices.
///
/// Prerequisites:
/// `a` and `b` must be of IndexType.
static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) {
static mlir::Value *add(MLFuncBuilder *b, Location loc, Value *v, Value *w) {
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
auto *context = b->getContext();
auto d0 = getAffineDimExpr(0, context);
auto d1 = getAffineDimExpr(1, context);
auto map = AffineMap::get(2, 0, {d0 + d1}, {});
return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w})
return b->create<AffineApplyOp>(loc, map, ArrayRef<mlir::Value *>{v, w})
->getResult(0);
}
namespace {
struct LowerVectorTransfersState : public MLFuncGlobalLoweringState {
// Top of the function constant zero index.
SSAValue *zero;
Value *zero;
};
} // namespace
@ -131,7 +128,8 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// case of GPUs.
if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
vecView->getResult(), ArrayRef<SSAValue *>{state->zero});
vecView->getResult(),
ArrayRef<mlir::Value *>{state->zero});
}
// 3. Emit the loop-nest.
@ -140,7 +138,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// TODO(ntv): Handle broadcast / slice properly.
auto permutationMap = transfer->getPermutationMap();
SetVector<ForStmt *> loops;
SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices());
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
auto composed = composeWithUnboundedMap(
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
@ -168,17 +166,16 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// b. write scalar to local.
auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
transfer->getMemRef(), accessIndices);
b.create<StoreOp>(
transfer->getLoc(), scalarLoad->getResult(),
tmpScalarAlloc->getResult(),
functional::map([](SSAValue *val) { return val; }, loops));
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
tmpScalarAlloc->getResult(),
functional::map([](Value *val) { return val; }, loops));
} else {
// VectorTransferWriteOp.
// a. read scalar from local;
// b. write scalar to remote.
auto scalarLoad = b.create<LoadOp>(
transfer->getLoc(), tmpScalarAlloc->getResult(),
functional::map([](SSAValue *val) { return val; }, loops));
functional::map([](Value *val) { return val; }, loops));
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
transfer->getMemRef(), accessIndices);
}
@ -186,11 +183,11 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// 5. Read the vector from local storage in case of a vector_transfer_read.
// TODO(ntv): This vector_load operation should be further lowered in the
// case of GPUs.
llvm::SmallVector<SSAValue *, 1> newResults = {};
llvm::SmallVector<Value *, 1> newResults = {};
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(),
ArrayRef<SSAValue *>{state->zero})
ArrayRef<Value *>{state->zero})
->getResult();
newResults.push_back(vector);
}

View File

@ -32,9 +32,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
@ -192,7 +190,7 @@ struct MaterializationState {
VectorType superVectorType;
VectorType hwVectorType;
SmallVector<unsigned, 8> hwVectorInstance;
DenseMap<const MLValue *, MLValue *> *substitutionsMap;
DenseMap<const Value *, Value *> *substitutionsMap;
};
struct MaterializeVectorsPass : public FunctionPass {
@ -250,9 +248,9 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap);
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all SSAValue belong to a program slice scoped within the immediately
/// Not all Values belong to a program slice scoped within the immediately
/// enclosing loop.
/// One simple example is constants defined outside the innermost loop scope.
/// For such cases the substitutionsMap has no entry and we allow an additional
@ -261,17 +259,16 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
/// indices and will need to be extended in the future.
///
/// If substitution fails, returns nullptr.
static MLValue *
substitute(SSAValue *v, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
auto it = substitutionsMap->find(cast<MLValue>(v));
static Value *substitute(Value *v, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
auto it = substitutionsMap->find(v);
if (it == substitutionsMap->end()) {
auto *opStmt = cast<OperationStmt>(v->getDefiningOperation());
if (opStmt->isa<ConstantOp>()) {
MLFuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
auto res = substitutionsMap->insert(
std::make_pair(cast<MLValue>(v), cast<MLValue>(inst->getResult(0))));
auto res =
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
assert(res.second && "Insertion failed");
return res.first->second;
}
@ -336,10 +333,10 @@ substitute(SSAValue *v, VectorType hwVectorType,
/// TODO(ntv): support a concrete AffineMap and compose with it.
/// TODO(ntv): these implementation details should be captured in a
/// vectorization trait at the op level directly.
static SmallVector<SSAValue *, 8>
static SmallVector<mlir::Value *, 8>
reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
ArrayRef<SSAValue *> memrefIndices) {
ArrayRef<Value *> memrefIndices) {
auto vectorShape = hwVectorType.getShape();
assert(hwVectorInstance.size() >= vectorShape.size());
@ -380,7 +377,7 @@ reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
// TODO(ntv): support a concrete map and composition.
auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(),
affineMap, memrefIndices);
return SmallVector<SSAValue *, 8>{app->getResults()};
return SmallVector<mlir::Value *, 8>{app->getResults()};
}
/// Returns attributes with the following substitutions applied:
@ -402,21 +399,21 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
/// Creates an instantiated version of `opStmt`.
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
/// affine reindexing. Just substitute their SSAValue* operands and be done. For
/// this case the actual instance is irrelevant. Just use the SSA values in
/// affine reindexing. Just substitute their Value operands and be done. For
/// this case the actual instance is irrelevant. Just use the values in
/// substitutionsMap.
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
bool fail = false;
auto operands = map(
[hwVectorType, substitutionsMap, &fail](SSAValue *v) -> SSAValue * {
[hwVectorType, substitutionsMap, &fail](Value *v) -> Value * {
auto *res =
fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
fail |= !res;
@ -481,9 +478,9 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
SmallVector<SSAValue *, 8> indices =
map(makePtrDynCaster<SSAValue>(), read->getIndices());
DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
map(makePtrDynCaster<Value>(), read->getIndices());
auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto cloned = b->create<VectorTransferReadOp>(
@ -501,9 +498,9 @@ instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
static OperationStmt *
instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
SmallVector<SSAValue *, 8> indices =
map(makePtrDynCaster<SSAValue>(), write->getIndices());
DenseMap<const Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
map(makePtrDynCaster<Value>(), write->getIndices());
auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto cloned = b->create<VectorTransferWriteOp>(
@ -555,8 +552,8 @@ static bool instantiateMaterialization(Statement *stmt,
} else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
state->substitutionsMap->insert(std::make_pair(
cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0))));
state->substitutionsMap->insert(
std::make_pair(read->getResult(), clone->getResult(0)));
return false;
}
// The only op with 0 results reaching this point must, by construction, be
@ -571,8 +568,8 @@ static bool instantiateMaterialization(Statement *stmt,
if (!clone) {
return true;
}
state->substitutionsMap->insert(std::make_pair(
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
state->substitutionsMap->insert(
std::make_pair(opStmt->getResult(0), clone->getResult(0)));
return false;
}
@ -610,7 +607,7 @@ static bool emitSlice(MaterializationState *state,
// Fresh RAII instanceIndices and substitutionsMap.
MaterializationState scopedState = *state;
scopedState.hwVectorInstance = delinearize(idx, *ratio);
DenseMap<const MLValue *, MLValue *> substitutionMap;
DenseMap<const Value *, Value *> substitutionMap;
scopedState.substitutionsMap = &substitutionMap;
// Statements in the slice are topologically sorted, so we can just clone them
// in order.
for (auto *stmt : *slice) {

View File

@ -32,7 +32,6 @@
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "pipeline-data-transfer"
using namespace mlir;
@ -80,7 +79,7 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto *forBody = forStmt->getBody();
MLFuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
@ -103,7 +102,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// Put together alloc operands for the dynamic dimensions of the memref.
MLFuncBuilder bOuter(forStmt);
SmallVector<SSAValue *, 4> allocOperands;
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
@ -114,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// Create and place the alloc right before the 'for' statement.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing.
SSAValue *newMemRef =
Value *newMemRef =
bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
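A minimal sketch of the 'iv mod 2' value mentioned above; the builder calls mirror patterns visible elsewhere in this diff, but this exact map construction is an assumption rather than the commit's code:

auto d0 = getAffineDimExpr(0, bInner.getContext());
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                     {d0 % 2}, {});
auto ivModTwoOp =
    bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);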
@ -126,8 +125,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
// replaceAllMemRefUsesWith will always succeed unless the forStmt body has
// non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
ivModTwoOp->getResult(0), AffineMap::Null(), {},
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
AffineMap::Null(), {},
&*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
@ -225,8 +224,7 @@ static void findMatchingStartFinishStmts(
continue;
// We only double buffer if the buffer is not live out of loop.
const MLValue *memref =
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
@ -280,8 +278,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// dimension.
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
Value *oldMemRef = dmaStartStmt->getOperand(
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forStmt)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
@ -302,8 +300,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
auto *dmaFinishStmt = pair.second;
MLValue *oldTagMemRef = cast<MLValue>(
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
Value *oldTagMemRef =
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
@ -332,7 +330,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// If a slice wasn't created, the reachable affine_apply ops from its
// operands are the ones that go with it.
SmallVector<OperationStmt *, 4> affineApplyStmts;
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
stmtShiftMap[stmt] = 0;

View File

@ -217,7 +217,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// If we already have a canonicalized version of this constant, just
// reuse it. Otherwise create a new one.
SSAValue *cstValue;
Value *cstValue;
auto it = uniquedConstants.find({resultConstants[i], res->getType()});
if (it != uniquedConstants.end())
cstValue = it->second->getResult(0);

View File

@ -31,7 +31,6 @@
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "LoopUtils"
using namespace mlir;
@ -108,8 +107,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
forStmt->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forStmt->getLowerBound();
SmallVector<SSAValue *, 4> lbOperands(lb.operand_begin(),
lb.operand_end());
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
MLFuncBuilder builder(forStmt->getBlock(), StmtBlock::iterator(forStmt));
auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands);
@ -149,8 +147,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
SmallVector<MLValue *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<MLValue *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
@ -176,7 +174,7 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
srcForStmt->getStep() * shift)),
loopChunk)
->getResult(0);
operandMap[srcForStmt] = cast<MLValue>(ivRemap);
operandMap[srcForStmt] = ivRemap;
} else {
operandMap[srcForStmt] = loopChunk;
}
@ -380,7 +378,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
DenseMap<const MLValue *, MLValue *> operandMap;
DenseMap<const Value *, Value *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
@ -414,7 +412,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
DenseMap<const MLValue *, MLValue *> operandMap;
DenseMap<const Value *, Value *> operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
@ -425,7 +423,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0);
operandMap[forStmt] = cast<MLValue>(ivUnroll);
operandMap[forStmt] = ivUnroll;
}
// Clone the original body of 'forStmt'.

View File

@ -32,17 +32,17 @@ using namespace mlir;
namespace {
// Visit affine expressions recursively and build the sequence of instructions
// that correspond to it. Visitation functions return an SSAValue of the
// that correspond to it. Visitation functions return a Value of the
// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
: public AffineExprVisitor<AffineApplyExpander, SSAValue *> {
: public AffineExprVisitor<AffineApplyExpander, Value *> {
public:
// This internal class expects arguments to be non-null; checks must be
// performed at the call site.
AffineApplyExpander(FuncBuilder *builder, AffineApplyOp *op)
: builder(*builder), applyOp(*op), loc(op->getLoc()) {}
template <typename OpTy> SSAValue *buildBinaryExpr(AffineBinaryOpExpr expr) {
template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
@ -51,33 +51,33 @@ public:
return op->getResult();
}
SSAValue *visitAddExpr(AffineBinaryOpExpr expr) {
Value *visitAddExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<AddIOp>(expr);
}
SSAValue *visitMulExpr(AffineBinaryOpExpr expr) {
Value *visitMulExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<MulIOp>(expr);
}
// TODO(zinenko): implement when the standard operators are made available.
SSAValue *visitModExpr(AffineBinaryOpExpr) {
Value *visitModExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc, "unsupported binary operator: mod");
return nullptr;
}
SSAValue *visitFloorDivExpr(AffineBinaryOpExpr) {
Value *visitFloorDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc,
"unsupported binary operator: floor_div");
return nullptr;
}
SSAValue *visitCeilDivExpr(AffineBinaryOpExpr) {
Value *visitCeilDivExpr(AffineBinaryOpExpr) {
builder.getContext()->emitError(loc,
"unsupported binary operator: ceil_div");
return nullptr;
}
SSAValue *visitConstantExpr(AffineConstantExpr expr) {
Value *visitConstantExpr(AffineConstantExpr expr) {
auto valueAttr =
builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
auto op =
@ -85,7 +85,7 @@ public:
return op->getResult();
}
SSAValue *visitDimExpr(AffineDimExpr expr) {
Value *visitDimExpr(AffineDimExpr expr) {
assert(expr.getPosition() < applyOp.getNumOperands() &&
"affine dim position out of range");
// FIXME: this assumes a certain order of AffineApplyOp operands, the
@ -93,7 +93,7 @@ public:
return applyOp.getOperand(expr.getPosition());
}
SSAValue *visitSymbolExpr(AffineSymbolExpr expr) {
Value *visitSymbolExpr(AffineSymbolExpr expr) {
// FIXME: this assumes a certain order of AffineApplyOp operands, the
// cleaner interface would be to separate them at the op level.
assert(expr.getPosition() + applyOp.getAffineMap().getNumDims() <
@ -114,8 +114,8 @@ private:
// Given an affine expression `expr` extracted from `op`, build the sequence of
// primitive instructions that correspond to the affine expression in the
// `builder`.
static SSAValue *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
AffineApplyOp *op) {
static mlir::Value *expandAffineExpr(FuncBuilder *builder, AffineExpr expr,
AffineApplyOp *op) {
auto expander = AffineApplyExpander(builder, op);
return expander.visit(expr);
}
@ -127,7 +127,7 @@ bool mlir::expandAffineApply(AffineApplyOp *op) {
FuncBuilder builder(op->getOperation());
auto affineMap = op->getAffineMap();
for (auto numberedExpr : llvm::enumerate(affineMap.getResults())) {
SSAValue *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
Value *expanded = expandAffineExpr(&builder, numberedExpr.value(), op);
if (!expanded)
return true;
op->getResult(numberedExpr.index())->replaceAllUsesWith(expanded);
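As a concrete illustration (not taken from this commit) of what the expander produces with only Add/Mul/Constant supported: an affine_apply computing (d0, d1) -> (d0 + d1 * 4) would lower to roughly

  %c4  = constant 4 : index
  %p   = muli %d1, %c4 : index
  %sum = addi %d0, %p : index

with %sum replacing the corresponding affine_apply result; mod, floordiv and ceildiv still emit the errors shown above.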

View File

@ -31,7 +31,6 @@
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
/// Return true if this operation dereferences one or more memref's.
@ -61,13 +60,12 @@ static bool isMemRefDereferencingOp(const Operation &op) {
// extra operands, note that 'indexRemap' would just be applied to the existing
// indices (%i, %j).
//
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// TODO(mlir-team): extend this for Value / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
MLValue *newMemRef,
ArrayRef<SSAValue *> extraIndices,
bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<SSAValue *> extraOperands,
ArrayRef<Value *> extraOperands,
const Statement *domStmtFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
@ -128,16 +126,15 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
// operation.
assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((cast<MLValue>(extraIndex)->isValidDim() ||
cast<MLValue>(extraIndex)->isValidSymbol()) &&
assert((extraIndex->isValidDim() || extraIndex->isValidSymbol()) &&
"invalid memory op index");
state.operands.push_back(cast<MLValue>(extraIndex));
state.operands.push_back(extraIndex);
}
// Construct new indices as a remap of the old ones if a remapping has been
// provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
SmallVector<SSAValue *, 4> remapOperands;
SmallVector<Value *, 4> remapOperands;
remapOperands.reserve(oldMemRefRank + extraOperands.size());
remapOperands.insert(remapOperands.end(), extraOperands.begin(),
extraOperands.end());
@ -149,11 +146,11 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
remapOperands);
// Remapped indices.
for (auto *index : remapOp->getOperation()->getResults())
state.operands.push_back(cast<MLValue>(index));
state.operands.push_back(index);
} else {
// No remapping specified.
for (auto *index : remapOperands)
state.operands.push_back(cast<MLValue>(index));
state.operands.push_back(index);
}
// Insert the remaining operands unmodified.
@ -191,9 +188,9 @@ bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
// composed AffineApplyOp are returned in output parameter 'results'.
OperationStmt *
mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<MLValue *> operands,
ArrayRef<Value *> operands,
ArrayRef<OperationStmt *> affineApplyOps,
SmallVectorImpl<SSAValue *> *results) {
SmallVectorImpl<Value *> *results) {
// Create identity map with same number of dimensions as number of operands.
auto map = builder->getMultiDimIdentityMap(operands.size());
// Initialize AffineValueMap with identity map.
@ -208,7 +205,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
// Compose affine maps from all ancestor AffineApplyOps.
// Create new AffineApplyOp from 'valueMap'.
unsigned numOperands = valueMap.getNumOperands();
SmallVector<SSAValue *, 4> outOperands(numOperands);
SmallVector<Value *, 4> outOperands(numOperands);
for (unsigned i = 0; i < numOperands; ++i) {
outOperands[i] = valueMap.getOperand(i);
}
@ -252,7 +249,7 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
/// otherwise.
OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// Collect all operands that are results of affine apply ops.
SmallVector<MLValue *, 4> subOperands;
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opStmt->getNumOperands());
for (auto *operand : opStmt->getOperands()) {
auto *defStmt = operand->getDefiningStmt();
@ -285,7 +282,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
return nullptr;
FuncBuilder builder(opStmt);
SmallVector<SSAValue *, 4> results;
SmallVector<Value *, 4> results;
auto *affineApplyStmt = createComposedAffineApplyOp(
&builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
assert(results.size() == subOperands.size() &&
@ -295,7 +292,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// affine apply op above instead of existing ones (subOperands). So, they
// differ from opStmt's operands only for those operands in 'subOperands', for
// which they will be replaced by the corresponding one from 'results'.
SmallVector<MLValue *, 4> newOperands(opStmt->getOperands());
SmallVector<Value *, 4> newOperands(opStmt->getOperands());
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
// Replace the subOperands from among the new operands.
unsigned j, f;
@ -304,7 +301,7 @@ OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
break;
}
if (j < subOperands.size()) {
newOperands[i] = cast<MLValue>(results[j]);
newOperands[i] = results[j];
}
}
@ -326,7 +323,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// into any uses which are AffineApplyOps.
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
++resultIndex) {
const MLValue *result = opStmt->getResult(resultIndex);
const Value *result = opStmt->getResult(resultIndex);
for (auto it = result->use_begin(); it != result->use_end();) {
StmtOperand &use = *(it++);
auto *useStmt = use.getOwner();
@ -347,7 +344,7 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// Create new AffineApplyOp from 'valueMap'.
unsigned numOperands = valueMap.getNumOperands();
SmallVector<SSAValue *, 4> operands(numOperands);
SmallVector<Value *, 4> operands(numOperands);
for (unsigned i = 0; i < numOperands; ++i) {
operands[i] = valueMap.getOperand(i);
}

View File

@ -27,8 +27,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
@ -740,8 +738,8 @@ struct VectorizationState {
DenseSet<OperationStmt *> vectorizedSet;
// Map of old scalar OperationStmt to new vectorized OperationStmt.
DenseMap<OperationStmt *, OperationStmt *> vectorizationMap;
// Map of old scalar MLValue to new vectorized MLValue.
DenseMap<const MLValue *, MLValue *> replacementMap;
// Map of old scalar Value to new vectorized Value.
DenseMap<const Value *, Value *> replacementMap;
// The strategy drives which loop to vectorize by which amount.
const VectorizationStrategy *strategy;
// Use-def roots. These represent the starting points for the worklist in the
@ -761,7 +759,7 @@ struct VectorizationState {
void registerTerminator(OperationStmt *stmt);
private:
void registerReplacement(const SSAValue *key, SSAValue *value);
void registerReplacement(const Value *key, Value *value);
};
} // end namespace
@ -802,12 +800,9 @@ void VectorizationState::finishVectorizationPattern() {
}
}
void VectorizationState::registerReplacement(const SSAValue *key,
SSAValue *value) {
assert(replacementMap.count(cast<MLValue>(key)) == 0 &&
"replacement already registered");
replacementMap.insert(
std::make_pair(cast<MLValue>(key), cast<MLValue>(value)));
void VectorizationState::registerReplacement(const Value *key, Value *value) {
assert(replacementMap.count(key) == 0 && "replacement already registered");
replacementMap.insert(std::make_pair(key, value));
}
////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
@ -825,7 +820,7 @@ void VectorizationState::registerReplacement(const SSAValue *key,
/// Such special cases force us to delay the vectorization of the stores
/// until the last step. Here we merely register the store operation.
template <typename LoadOrStoreOpPointer>
static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
VectorizationState *state) {
auto memRefType =
memoryOp->getMemRef()->getType().template cast<MemRefType>();
@ -850,8 +845,7 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
MLFuncBuilder b(opStmt);
auto transfer = b.create<VectorTransferReadOp>(
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()),
permutationMap);
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
state->registerReplacement(opStmt,
cast<OperationStmt>(transfer->getOperation()));
} else {
@ -970,8 +964,8 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr.
static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
Type type) {
static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
Type type) {
if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) {
return nullptr;
@ -988,7 +982,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
{make_pair(Identifier::get("value", b.getContext()), attr)});
auto *splat = cast<OperationStmt>(b.createOperation(state));
return cast<MLValue>(splat->getResult(0));
return splat->getResult(0);
}
/// Returns a uniqued VectorType.
@ -996,7 +990,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
/// vectorizedSet, just returns the type of `v`.
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
/// and of elemental type the type of `v`.
static Type getVectorType(SSAValue *v, const VectorizationState &state) {
static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
@ -1028,23 +1022,23 @@ static Type getVectorType(SSAValue *v, const VectorizationState &state) {
/// vectorization is possible with the above logic. Returns nullptr otherwise.
///
/// TODO(ntv): handle more complex cases.
static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
VectorizationState *state) {
static Value *vectorizeOperand(Value *operand, Statement *stmt,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
auto *definingStatement = cast<OperationStmt>(operand->getDefiningStmt());
// 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingStatement) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return cast<MLValue>(operand);
return operand;
}
// 1.b. Delayed on-demand replacement of a use.
// Note that we cannot just call replaceAllUsesWith because it may result
// in ops with mixed types, for ops whose operands have not all yet
// been vectorized. This would be invalid IR.
auto it = state->replacementMap.find(cast<MLValue>(operand));
auto it = state->replacementMap.find(operand);
if (it != state->replacementMap.end()) {
auto *res = cast<MLValue>(it->second);
auto *res = it->second;
LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
LLVM_DEBUG(res->print(dbgs()));
return res;
@ -1089,7 +1083,7 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
auto *memRef = store->getMemRef();
auto *value = store->getValueToStore();
auto *vectorValue = vectorizeOperand(value, opStmt, state);
auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
MLFuncBuilder b(opStmt);
auto permutationMap =
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
@ -1104,14 +1098,14 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
return res;
}
auto types = map([state](SSAValue *v) { return getVectorType(v, *state); },
auto types = map([state](Value *v) { return getVectorType(v, *state); },
opStmt->getResults());
auto vectorizeOneOperand = [opStmt, state](SSAValue *op) -> SSAValue * {
auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
return vectorizeOperand(op, opStmt, state);
};
auto operands = map(vectorizeOneOperand, opStmt->getOperands());
// Check whether a single operand is null. If so, vectorization failed.
bool success = llvm::all_of(operands, [](SSAValue *op) { return op; });
bool success = llvm::all_of(operands, [](Value *op) { return op; });
if (!success) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
return nullptr;
@ -1207,7 +1201,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
continue;
}
MLFuncBuilder builder(loop); // builder to insert in place of loop
DenseMap<const MLValue *, MLValue *> nomap;
DenseMap<const Value *, Value *> nomap;
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
@ -1229,8 +1223,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
// Form the root operations that have been set in the replacementMap.
// For now, these roots are the loads for which vector_transfer_read
// operations have been inserted.
auto getDefiningOperation = [](const MLValue *val) {
return const_cast<MLValue *>(val)->getDefiningOperation();
auto getDefiningOperation = [](const Value *val) {
return const_cast<Value *>(val)->getDefiningOperation();
};
using ReferenceTy = decltype(*(state.replacementMap.begin()));
auto getKey = [](ReferenceTy it) { return it.first; };

View File

@ -288,10 +288,10 @@ void OpEmitter::emitAttrGetters() {
}
void OpEmitter::emitNamedOperands() {
const auto operandMethods = R"( SSAValue *{0}() {
const auto operandMethods = R"( Value *{0}() {
return this->getOperation()->getOperand({1});
}
const SSAValue *{0}() const {
const Value *{0}() const {
return this->getOperation()->getOperand({1});
}
)";
@ -329,7 +329,7 @@ void OpEmitter::emitBuilder() {
// Emit parameters for all operands
for (const auto &pair : operands)
os << ", SSAValue* " << pair.first;
os << ", Value* " << pair.first;
// Emit parameters for all attributes
// TODO(antiagainst): Support default initializer for attributes
@ -369,7 +369,7 @@ void OpEmitter::emitBuilder() {
// Signature
os << " static void build(Builder* builder, OperationState* result, "
<< "ArrayRef<Type> resultTypes, ArrayRef<SSAValue*> args, "
<< "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
"ArrayRef<NamedAttribute> attributes) {\n";
// Result types