Re-land "[mlir] Add integer range inference analysis""
This reverts commit4e5ce2056e
. This relands commit1350c9887d
. Reinstates the range analysis with the build issue fixed. Differential Revision: https://reviews.llvm.org/D126926
This commit is contained in:
parent
7e48dae5a1
commit
95aff23e29
|
@ -0,0 +1,41 @@
|
|||
//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares the dataflow analysis class for integer range inference
|
||||
// so that it can be used in transformations over the `arith` dialect such as
|
||||
// branch elimination or signed->unsigned rewriting
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
|
||||
#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
class IntRangeAnalysisImpl;
|
||||
} // end namespace detail
|
||||
|
||||
class IntRangeAnalysis {
|
||||
public:
|
||||
/// Analyze all operations rooted under (but not including)
|
||||
/// `topLevelOperation`.
|
||||
IntRangeAnalysis(Operation *topLevelOperation);
|
||||
IntRangeAnalysis(IntRangeAnalysis &&other);
|
||||
~IntRangeAnalysis();
|
||||
|
||||
/// Get inferred range for value `v` if one exists.
|
||||
Optional<ConstantIntRanges> getResult(Value v);
|
||||
|
||||
private:
|
||||
std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
|
@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
|
|||
add_mlir_interface(ControlFlowInterfaces)
|
||||
add_mlir_interface(CopyOpInterface)
|
||||
add_mlir_interface(DerivedAttributeOpInterface)
|
||||
add_mlir_interface(InferIntRangeInterface)
|
||||
add_mlir_interface(InferTypeOpInterface)
|
||||
add_mlir_interface(LoopLikeInterface)
|
||||
add_mlir_interface(SideEffectInterfaces)
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains definitions of the integer range inference interface
|
||||
// defined in `InferIntRange.td`
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
|
||||
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
/// A set of arbitrary-precision integers representing bounds on a given integer
|
||||
/// value. These bounds are inclusive on both ends, so
|
||||
/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
|
||||
/// the unsigned and signed interpretations of values in order to enable more
|
||||
/// precice inference of the interplay between operations with signed and
|
||||
/// unsigned semantics.
|
||||
class ConstantIntRanges {
|
||||
public:
|
||||
/// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
|
||||
/// Non-integer values should be bounded by APInts of bitwidth 0.
|
||||
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
|
||||
const APInt &smax)
|
||||
: uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
|
||||
assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
|
||||
umaxVal.getBitWidth() == sminVal.getBitWidth() &&
|
||||
sminVal.getBitWidth() == smaxVal.getBitWidth() &&
|
||||
"All bounds in the ranges must have the same bitwidth");
|
||||
}
|
||||
|
||||
bool operator==(const ConstantIntRanges &other) const;
|
||||
|
||||
/// The minimum value of an integer when it is interpreted as unsigned.
|
||||
const APInt &umin() const;
|
||||
|
||||
/// The maximum value of an integer when it is interpreted as unsigned.
|
||||
const APInt &umax() const;
|
||||
|
||||
/// The minimum value of an integer when it is interpreted as signed.
|
||||
const APInt &smin() const;
|
||||
|
||||
/// The maximum value of an integer when it is interpreted as signed.
|
||||
const APInt &smax() const;
|
||||
|
||||
/// Return the bitwidth that should be used for integer ranges describing
|
||||
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
|
||||
/// this is the internal storage bitwidth of `index` attributes, and for
|
||||
/// non-integer types this is 0.
|
||||
static unsigned getStorageBitwidth(Type type);
|
||||
|
||||
/// Create an `IntRangeAttrs` where `min` is both the signed and unsigned
|
||||
/// minimum and `max` is both the signed and unsigned maximum.
|
||||
static ConstantIntRanges range(const APInt &min, const APInt &max);
|
||||
|
||||
/// Create an `IntRangeAttrs` with the signed minimum and maximum equal
|
||||
/// to `smin` and `smax`, where the unsigned bounds are constructed from the
|
||||
/// signed ones if they correspond to a contigious range of bit patterns when
|
||||
/// viewed as unsigned values and are left at [0, int_max()] otherwise.
|
||||
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax);
|
||||
|
||||
/// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal
|
||||
/// to `umin` and `umax` and the signed part equal to `umin` and `umax`
|
||||
/// unless the sign bit changes between the minimum and maximum.
|
||||
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax);
|
||||
|
||||
/// Returns the union (computed separately for signed and unsigned bounds)
|
||||
/// of `a` and `b`.
|
||||
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const;
|
||||
|
||||
/// If either the signed or unsigned interpretations of the range
|
||||
/// indicate that the value it bounds is a constant, return that constant
|
||||
/// value.
|
||||
Optional<APInt> getConstantValue() const;
|
||||
|
||||
friend raw_ostream &operator<<(raw_ostream &os,
|
||||
const ConstantIntRanges &range);
|
||||
|
||||
private:
|
||||
APInt uminVal, umaxVal, sminVal, smaxVal;
|
||||
};
|
||||
|
||||
/// The type of the `setResultRanges` callback provided to ops implementing
|
||||
/// InferIntRangeInterface. It should be called once for each integer result
|
||||
/// value and be passed the ConstantIntRanges corresponding to that value.
|
||||
using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>;
|
||||
} // end namespace mlir
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h.inc"
|
||||
|
||||
#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
|
|
@ -0,0 +1,52 @@
|
|||
//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===-----------------------------------------------------===//
|
||||
//
|
||||
// Defines the interface for range analysis on scalar integers
|
||||
//
|
||||
//===-----------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE
|
||||
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
|
||||
let description = [{
|
||||
Allows operations to participate in range analysis for scalar integer values by
|
||||
providing a methods that allows them to specify lower and upper bounds on their
|
||||
result(s) given lower and upper bounds on their input(s) if known.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Infer the bounds on the results of this op given the bounds on its arguments.
|
||||
For each result value or block argument (that isn't a branch argument,
|
||||
since the dataflow analysis handles those case), the method should call
|
||||
`setValueRange` with that `Value` as an argument. When `setValueRange`
|
||||
is not called for some value, it will recieve a default value of the mimimum
|
||||
and maximum values forits type (the unbounded range).
|
||||
|
||||
When called on an op that also implements the RegionBranchOpInterface
|
||||
or BranchOpInterface, this method should not attempt to infer the values
|
||||
of the branch results, as this will be handled by the analyses that use
|
||||
this interface.
|
||||
|
||||
This function will only be called when at least one result of the op is a
|
||||
scalar integer value or the op has a region.
|
||||
|
||||
`argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS
|
||||
order. Non-integer arguments will have the an unbounded range of width-0
|
||||
APInts in their `argRanges` element.
|
||||
}],
|
||||
"void", "inferResultRanges", (ins
|
||||
"::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges,
|
||||
"::mlir::SetIntRangeFn":$setResultRanges)
|
||||
>];
|
||||
}
|
||||
#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE
|
|
@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
|
|||
CallGraph.cpp
|
||||
DataFlowAnalysis.cpp
|
||||
DataLayoutAnalysis.cpp
|
||||
IntRangeAnalysis.cpp
|
||||
Liveness.cpp
|
||||
SliceAnalysis.cpp
|
||||
|
||||
|
@ -16,6 +17,7 @@ add_mlir_library(MLIRAnalysis
|
|||
CallGraph.cpp
|
||||
DataFlowAnalysis.cpp
|
||||
DataLayoutAnalysis.cpp
|
||||
IntRangeAnalysis.cpp
|
||||
Liveness.cpp
|
||||
SliceAnalysis.cpp
|
||||
|
||||
|
@ -31,7 +33,9 @@ add_mlir_library(MLIRAnalysis
|
|||
MLIRCallInterfaces
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRDataLayoutInterfaces
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRLoopLikeInterface
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRViewLikeInterface
|
||||
)
|
||||
|
|
|
@ -359,11 +359,20 @@ void ForwardDataFlowSolver::visitOperation(Operation *op) {
|
|||
if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
|
||||
return visitRegionBranchOperation(branch, operandLattices);
|
||||
|
||||
// If we can't, conservatively mark all regions as executable.
|
||||
// TODO: Let the `visitOperation` method decide how to propagate
|
||||
// information to the block arguments.
|
||||
for (Region ®ion : op->getRegions())
|
||||
markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true);
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
analysis.visitNonControlFlowArguments(op, RegionSuccessor(®ion),
|
||||
operandLattices);
|
||||
// `visitNonControlFlowArguments` is required to define all of the region
|
||||
// argument lattices.
|
||||
assert(llvm::none_of(
|
||||
region.getArguments(),
|
||||
[&](Value value) {
|
||||
return analysis.getLatticeElement(value).isUninitialized();
|
||||
}) &&
|
||||
"expected `visitNonControlFlowArguments` to define all argument "
|
||||
"lattices");
|
||||
markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
// If this op produces no results, it can't produce any constants.
|
||||
|
@ -567,12 +576,45 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
|
|||
if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
|
||||
return;
|
||||
|
||||
// If the branch is a RegionBranchTerminatorOpInterface,
|
||||
// construct the set of operand lattices as the set of non control-flow
|
||||
// arguments of the parent and the values this op returns. This allows
|
||||
// for the correct lattices to be passed to getSuccessorsForOperands()
|
||||
// in cases such as scf.while.
|
||||
ArrayRef<AbstractLatticeElement *> branchOpLattices = operandLattices;
|
||||
SmallVector<AbstractLatticeElement *, 0> parentLattices;
|
||||
if (auto regionTerminator =
|
||||
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
|
||||
parentLattices.reserve(regionInterface->getNumOperands());
|
||||
for (Value parentOperand : regionInterface->getOperands()) {
|
||||
AbstractLatticeElement *operandLattice =
|
||||
analysis.lookupLatticeElement(parentOperand);
|
||||
if (!operandLattice || operandLattice->isUninitialized())
|
||||
return;
|
||||
parentLattices.push_back(operandLattice);
|
||||
}
|
||||
unsigned regionNumber = parentRegion->getRegionNumber();
|
||||
OperandRange iterArgs =
|
||||
regionInterface.getSuccessorEntryOperands(regionNumber);
|
||||
OperandRange terminatorArgs =
|
||||
regionTerminator.getSuccessorOperands(regionNumber);
|
||||
assert(iterArgs.size() == terminatorArgs.size() &&
|
||||
"Number of iteration arguments for region should equal number of "
|
||||
"those arguments defined by terminator");
|
||||
if (!iterArgs.empty()) {
|
||||
unsigned iterStart = iterArgs.getBeginOperandIndex();
|
||||
unsigned terminatorStart = terminatorArgs.getBeginOperandIndex();
|
||||
for (unsigned i = 0, e = iterArgs.size(); i < e; ++i)
|
||||
parentLattices[iterStart + i] = operandLattices[terminatorStart + i];
|
||||
}
|
||||
branchOpLattices = parentLattices;
|
||||
}
|
||||
// Query the set of successors of the current region using the current
|
||||
// optimistic lattice state.
|
||||
SmallVector<RegionSuccessor, 1> regionSuccessors;
|
||||
analysis.getSuccessorsForOperands(regionInterface,
|
||||
parentRegion->getRegionNumber(),
|
||||
operandLattices, regionSuccessors);
|
||||
branchOpLattices, regionSuccessors);
|
||||
if (regionSuccessors.empty())
|
||||
return;
|
||||
|
||||
|
@ -584,7 +626,7 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
|
|||
// region index (if any).
|
||||
return *getRegionBranchSuccessorOperands(op, regionIndex);
|
||||
};
|
||||
return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
|
||||
return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices,
|
||||
getOperands);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,325 @@
|
|||
//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the dataflow analysis class for integer range inference
|
||||
// which is used in transformations over the `arith` dialect such as
|
||||
// branch elimination or signed->unsigned rewriting
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/IntRangeAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "int-range-analysis"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A wrapper around ConstantIntRanges that provides the lattice functions
|
||||
/// expected by dataflow analysis.
|
||||
struct IntRangeLattice {
|
||||
IntRangeLattice(const ConstantIntRanges &value) : value(value){};
|
||||
IntRangeLattice(ConstantIntRanges &&value) : value(value){};
|
||||
|
||||
bool operator==(const IntRangeLattice &other) const {
|
||||
return value == other.value;
|
||||
}
|
||||
|
||||
/// wrapper around rangeUnion()
|
||||
static IntRangeLattice join(const IntRangeLattice &a,
|
||||
const IntRangeLattice &b) {
|
||||
return a.value.rangeUnion(b.value);
|
||||
}
|
||||
|
||||
/// Creates a range with bitwidth 0 to represent that we don't know if the
|
||||
/// value being marked overdefined is even an integer.
|
||||
static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
|
||||
APInt noIntValue = APInt::getZeroWidth();
|
||||
return ConstantIntRanges::range(noIntValue, noIntValue);
|
||||
}
|
||||
|
||||
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
|
||||
/// range that is used to mark the value v as unable to be analyzed further,
|
||||
/// where t is the type of v.
|
||||
static IntRangeLattice getPessimisticValueState(Value v) {
|
||||
unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
|
||||
APInt umin = APInt::getMinValue(width);
|
||||
APInt umax = APInt::getMaxValue(width);
|
||||
APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
|
||||
APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
|
||||
return ConstantIntRanges{umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
ConstantIntRanges value;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
|
||||
using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
|
||||
|
||||
public:
|
||||
/// Define bounds on the results or block arguments of the operation
|
||||
/// based on the bounds on the arguments given in `operands`
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
|
||||
|
||||
/// Skip regions of branch ops when we can statically infer constant
|
||||
/// values for operands to the branch op and said op tells us it's safe to do
|
||||
/// so.
|
||||
LogicalResult
|
||||
getSuccessorsForOperands(BranchOpInterface branch,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
|
||||
SmallVectorImpl<Block *> &successors) final;
|
||||
|
||||
/// Skip regions of branch or loop ops when we can statically infer constant
|
||||
/// values for operands to the branch op and said op tells us it's safe to do
|
||||
/// so.
|
||||
void
|
||||
getSuccessorsForOperands(RegionBranchOpInterface branch,
|
||||
Optional<unsigned> sourceIndex,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
|
||||
SmallVectorImpl<RegionSuccessor> &successors) final;
|
||||
|
||||
/// Call the InferIntRangeInterface implementation for region-using ops
|
||||
/// that implement it, and infer the bounds of loop induction variables
|
||||
/// for ops that implement LoopLikeOPInterface.
|
||||
ChangeResult visitNonControlFlowArguments(
|
||||
Operation *op, const RegionSuccessor ®ion,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
|
||||
/// Given the results of getConstant{Lower,Upper}Bound()
|
||||
/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
|
||||
/// that result if possible.
|
||||
static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
|
||||
Type boundType,
|
||||
detail::IntRangeAnalysisImpl &analysis,
|
||||
bool getUpper) {
|
||||
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
|
||||
if (loopBound.hasValue()) {
|
||||
if (loopBound->is<Attribute>()) {
|
||||
if (auto bound =
|
||||
loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
|
||||
return bound.getValue();
|
||||
} else if (loopBound->is<Value>()) {
|
||||
LatticeElement<IntRangeLattice> *lattice =
|
||||
analysis.lookupLatticeElement(loopBound->get<Value>());
|
||||
if (lattice != nullptr)
|
||||
return getUpper ? lattice->getValue().value.smax()
|
||||
: lattice->getValue().value.smin();
|
||||
}
|
||||
}
|
||||
return getUpper ? APInt::getSignedMaxValue(width)
|
||||
: APInt::getSignedMinValue(width);
|
||||
}
|
||||
|
||||
ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
|
||||
Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
// Ignore non-integer outputs - return early if the op has no scalar
|
||||
// integer results
|
||||
bool hasIntegerResult = false;
|
||||
for (Value v : op->getResults()) {
|
||||
if (v.getType().isIntOrIndex())
|
||||
hasIntegerResult = true;
|
||||
else
|
||||
result |= markAllPessimisticFixpoint(v);
|
||||
}
|
||||
if (!hasIntegerResult)
|
||||
return result;
|
||||
|
||||
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
|
||||
LLVM_DEBUG(inferrable->print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n");
|
||||
SmallVector<ConstantIntRanges> argRanges(
|
||||
llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
|
||||
return val->getValue().value;
|
||||
}));
|
||||
|
||||
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
|
||||
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
|
||||
Optional<IntRangeLattice> oldRange;
|
||||
if (!lattice.isUninitialized())
|
||||
oldRange = lattice.getValue();
|
||||
result |= lattice.join(IntRangeLattice(attrs));
|
||||
|
||||
// Catch loop results with loop variant bounds and conservatively make
|
||||
// them [-inf, inf] so we don't circle around infinitely often (because
|
||||
// the dataflow analysis in MLIR doesn't attempt to work out trip counts
|
||||
// and often can't).
|
||||
bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
|
||||
return op->hasTrait<OpTrait::IsTerminator>();
|
||||
});
|
||||
if (isYieldedResult && oldRange.hasValue() &&
|
||||
!(lattice.getValue() == *oldRange)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
|
||||
result |= lattice.markPessimisticFixpoint();
|
||||
}
|
||||
};
|
||||
|
||||
inferrable.inferResultRanges(argRanges, joinCallback);
|
||||
for (Value opResult : op->getResults()) {
|
||||
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
|
||||
// setResultRange() not called, make pessimistic.
|
||||
if (lattice.isUninitialized())
|
||||
result |= lattice.markPessimisticFixpoint();
|
||||
}
|
||||
} else if (op->getNumRegions() == 0) {
|
||||
// No regions + no result inference method -> unbounded results (ex. memory
|
||||
// ops)
|
||||
result |= markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
|
||||
BranchOpInterface branch,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
|
||||
SmallVectorImpl<Block *> &successors) {
|
||||
auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
|
||||
Optional<APInt> maybeConstValue =
|
||||
enumPair.value()->getValue().value.getConstantValue();
|
||||
|
||||
if (maybeConstValue) {
|
||||
return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
|
||||
*maybeConstValue);
|
||||
}
|
||||
return {};
|
||||
};
|
||||
SmallVector<Attribute> inferredConsts(
|
||||
llvm::map_range(llvm::enumerate(operands), toConstantAttr));
|
||||
if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
|
||||
successors.push_back(singleSucc);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
|
||||
RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands,
|
||||
SmallVectorImpl<RegionSuccessor> &successors) {
|
||||
auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
|
||||
Optional<APInt> maybeConstValue =
|
||||
enumPair.value()->getValue().value.getConstantValue();
|
||||
|
||||
if (maybeConstValue) {
|
||||
return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
|
||||
*maybeConstValue);
|
||||
}
|
||||
return {};
|
||||
};
|
||||
SmallVector<Attribute> inferredConsts(
|
||||
llvm::map_range(llvm::enumerate(operands), toConstantAttr));
|
||||
branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
|
||||
}
|
||||
|
||||
ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
|
||||
Operation *op, const RegionSuccessor ®ion,
|
||||
ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
|
||||
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
|
||||
LLVM_DEBUG(inferrable->print(llvm::dbgs()));
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n");
|
||||
SmallVector<ConstantIntRanges> argRanges(
|
||||
llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
|
||||
return val->getValue().value;
|
||||
}));
|
||||
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
|
||||
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
|
||||
Optional<IntRangeLattice> oldRange;
|
||||
if (!lattice.isUninitialized())
|
||||
oldRange = lattice.getValue();
|
||||
result |= lattice.join(IntRangeLattice(attrs));
|
||||
|
||||
// Catch loop results with loop variant bounds and conservatively make
|
||||
// them [-inf, inf] so we don't circle around infinitely often (because
|
||||
// the dataflow analysis in MLIR doesn't attempt to work out trip counts
|
||||
// and often can't).
|
||||
bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
|
||||
return op->hasTrait<OpTrait::IsTerminator>();
|
||||
});
|
||||
if (isYieldedValue && oldRange.hasValue() &&
|
||||
!(lattice.getValue() == *oldRange)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
|
||||
result |= lattice.markPessimisticFixpoint();
|
||||
}
|
||||
};
|
||||
|
||||
inferrable.inferResultRanges(argRanges, joinCallback);
|
||||
for (Value regionArg : region.getSuccessor()->getArguments()) {
|
||||
LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
|
||||
// setResultRange() not called, make pessimistic.
|
||||
if (lattice.isUninitialized())
|
||||
result |= lattice.markPessimisticFixpoint();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Infer bounds for loop arguments that have static bounds
|
||||
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
|
||||
Optional<Value> iv = loop.getSingleInductionVar();
|
||||
if (!iv.hasValue()) {
|
||||
return ForwardDataFlowAnalysis<
|
||||
IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
|
||||
}
|
||||
Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
|
||||
Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
|
||||
Optional<OpFoldResult> step = loop.getSingleStep();
|
||||
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
|
||||
/*getUpper=*/false);
|
||||
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
|
||||
/*getUpper=*/true);
|
||||
// Assume positivity for uniscoverable steps by way of getUpper = true.
|
||||
APInt stepVal =
|
||||
getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
|
||||
|
||||
if (stepVal.isNegative()) {
|
||||
std::swap(min, max);
|
||||
} else {
|
||||
// Correct the upper bound by subtracting 1 so that it becomes a <= bound,
|
||||
// because loops do not generally include their upper bound.
|
||||
max -= 1;
|
||||
}
|
||||
|
||||
LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
|
||||
return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
|
||||
}
|
||||
return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
|
||||
op, region, operands);
|
||||
}
|
||||
|
||||
IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
|
||||
impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
|
||||
topLevelOperation->getContext());
|
||||
impl->run(topLevelOperation);
|
||||
}
|
||||
|
||||
IntRangeAnalysis::~IntRangeAnalysis() = default;
|
||||
IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
|
||||
|
||||
Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
|
||||
LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
|
||||
if (result == nullptr || result->isUninitialized())
|
||||
return llvm::None;
|
||||
return result->getValue().value;
|
||||
}
|
|
@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
|
|||
CopyOpInterface.cpp
|
||||
DataLayoutInterfaces.cpp
|
||||
DerivedAttributeOpInterface.cpp
|
||||
InferIntRangeInterface.cpp
|
||||
InferTypeOpInterface.cpp
|
||||
LoopLikeInterface.cpp
|
||||
SideEffectInterfaces.cpp
|
||||
|
@ -35,6 +36,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
|
|||
add_mlir_interface_library(CopyOpInterface)
|
||||
add_mlir_interface_library(DataLayoutInterfaces)
|
||||
add_mlir_interface_library(DerivedAttributeOpInterface)
|
||||
add_mlir_interface_library(InferIntRangeInterface)
|
||||
add_mlir_interface_library(InferTypeOpInterface)
|
||||
add_mlir_interface_library(SideEffectInterfaces)
|
||||
add_mlir_interface_library(TilingInterface)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
|
||||
return umin().getBitWidth() == other.umin().getBitWidth() &&
|
||||
umin() == other.umin() && umax() == other.umax() &&
|
||||
smin() == other.smin() && smax() == other.smax();
|
||||
}
|
||||
|
||||
const APInt &ConstantIntRanges::umin() const { return uminVal; }
|
||||
|
||||
const APInt &ConstantIntRanges::umax() const { return umaxVal; }
|
||||
|
||||
const APInt &ConstantIntRanges::smin() const { return sminVal; }
|
||||
|
||||
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
|
||||
|
||||
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
|
||||
if (type.isIndex())
|
||||
return IndexType::kInternalStorageBitWidth;
|
||||
if (auto integerType = type.dyn_cast<IntegerType>())
|
||||
return integerType.getWidth();
|
||||
// Non-integer types have their bounds stored in width 0 `APInt`s.
|
||||
return 0;
|
||||
}
|
||||
|
||||
ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
|
||||
return {min, max, min, max};
|
||||
}
|
||||
|
||||
ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
|
||||
const APInt &smax) {
|
||||
unsigned int width = smin.getBitWidth();
|
||||
APInt umin, umax;
|
||||
if (smin.isNonNegative() == smax.isNonNegative()) {
|
||||
umin = smin.ult(smax) ? smin : smax;
|
||||
umax = smin.ugt(smax) ? smin : smax;
|
||||
} else {
|
||||
umin = APInt::getMinValue(width);
|
||||
umax = APInt::getMaxValue(width);
|
||||
}
|
||||
return {umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin,
|
||||
const APInt &umax) {
|
||||
unsigned int width = umin.getBitWidth();
|
||||
APInt smin, smax;
|
||||
if (umin.isNonNegative() == umax.isNonNegative()) {
|
||||
smin = umin.slt(umax) ? umin : umax;
|
||||
smax = umin.sgt(umax) ? umin : umax;
|
||||
} else {
|
||||
smin = APInt::getSignedMinValue(width);
|
||||
smax = APInt::getSignedMaxValue(width);
|
||||
}
|
||||
return {umin, umax, smin, smax};
|
||||
}
|
||||
|
||||
ConstantIntRanges
|
||||
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
|
||||
// "Not an integer" poisons everything and also cannot be fed to comparison
|
||||
// operators.
|
||||
if (umin().getBitWidth() == 0)
|
||||
return *this;
|
||||
if (other.umin().getBitWidth() == 0)
|
||||
return other;
|
||||
|
||||
const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
|
||||
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
|
||||
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
|
||||
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
|
||||
|
||||
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
|
||||
}
|
||||
|
||||
Optional<APInt> ConstantIntRanges::getConstantValue() const {
|
||||
// Note: we need to exclude the trivially-equal width 0 values here.
|
||||
if (umin() == umax() && umin().getBitWidth() != 0)
|
||||
return umin();
|
||||
if (smin() == smax() && smin().getBitWidth() != 0)
|
||||
return smin();
|
||||
return None;
|
||||
}
|
||||
|
||||
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
|
||||
return os << "unsigned : [" << range.umin() << ", " << range.umax()
|
||||
<< "] signed : [" << range.smin() << ", " << range.smax() << "]";
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @constant
|
||||
// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
|
||||
// CHECK: return %[[cst]]
|
||||
func.func @constant() -> index {
|
||||
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
|
||||
smin = 3 : index, smax = 3 : index}
|
||||
func.return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @increment
|
||||
// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
|
||||
// CHECK: return %[[cst]]
|
||||
func.func @increment() -> index {
|
||||
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
|
||||
%1 = test.increment %0
|
||||
func.return %1 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maybe_increment
|
||||
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
|
||||
func.func @maybe_increment(%arg0 : i1) -> index {
|
||||
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
|
||||
smin = 3 : index, smax = 3 : index}
|
||||
%1 = scf.if %arg0 -> index {
|
||||
scf.yield %0 : index
|
||||
} else {
|
||||
%2 = test.increment %0
|
||||
scf.yield %2 : index
|
||||
}
|
||||
%3 = test.reflect_bounds %1
|
||||
func.return %3 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maybe_increment_br
|
||||
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
|
||||
func.func @maybe_increment_br(%arg0 : i1) -> index {
|
||||
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
|
||||
smin = 3 : index, smax = 3 : index}
|
||||
cf.cond_br %arg0, ^bb0, ^bb1
|
||||
^bb0:
|
||||
%1 = test.increment %0
|
||||
cf.br ^bb2(%1 : index)
|
||||
^bb1:
|
||||
cf.br ^bb2(%0 : index)
|
||||
^bb2(%2 : index):
|
||||
%3 = test.reflect_bounds %2
|
||||
func.return %3 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @for_bounds
|
||||
// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index}
|
||||
func.func @for_bounds() -> index {
|
||||
%c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
|
||||
smin = 0 : index, smax = 0 : index}
|
||||
%c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
|
||||
smin = 1 : index, smax = 1 : index}
|
||||
%c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
|
||||
smin = 2 : index, smax = 2 : index}
|
||||
|
||||
%0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
|
||||
scf.yield %arg0 : index
|
||||
}
|
||||
%1 = test.reflect_bounds %0
|
||||
func.return %1 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @no_analysis_of_loop_variants
|
||||
// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
|
||||
func.func @no_analysis_of_loop_variants() -> index {
|
||||
%c0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
|
||||
smin = 0 : index, smax = 0 : index}
|
||||
%c1 = test.with_bounds { umin = 1 : index, umax = 1 : index,
|
||||
smin = 1 : index, smax = 1 : index}
|
||||
%c2 = test.with_bounds { umin = 2 : index, umax = 2 : index,
|
||||
smin = 2 : index, smax = 2 : index}
|
||||
|
||||
%0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
|
||||
%1 = test.increment %arg2
|
||||
scf.yield %1 : index
|
||||
}
|
||||
%2 = test.reflect_bounds %0
|
||||
func.return %2 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @region_args
|
||||
// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index}
|
||||
func.func @region_args() {
|
||||
test.with_bounds_region { umin = 3 : index, umax = 4 : index,
|
||||
smin = 3 : index, smax = 4 : index } %arg0 {
|
||||
%0 = test.reflect_bounds %arg0
|
||||
}
|
||||
func.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func_args_unbound
|
||||
// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index}
|
||||
func.func @func_args_unbound(%arg0 : index) -> index {
|
||||
%0 = test.reflect_bounds %arg0
|
||||
func.return %0 : index
|
||||
}
|
|
@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
|
|||
MLIRFunc
|
||||
MLIRFuncTransforms
|
||||
MLIRIR
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRLinalg
|
||||
MLIRLinalgTransforms
|
||||
|
|
|
@ -14,15 +14,21 @@
|
|||
#include "mlir/Dialect/DLTI/DLTI.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/ExtensibleDialect.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Reducer/ReductionPatternInterface.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
||||
|
@ -1396,6 +1402,67 @@ LogicalResult TestVerifiersOp::verifyRegions() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test InferIntRangeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRanges) {
|
||||
setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
|
||||
}
|
||||
|
||||
ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
// Parse the input argument
|
||||
OpAsmParser::Argument argInfo;
|
||||
argInfo.type = parser.getBuilder().getIndexType();
|
||||
if (failed(parser.parseArgument(argInfo)))
|
||||
return failure();
|
||||
|
||||
// Parse the body region, and reuse the operand info as the argument info.
|
||||
Region *body = result.addRegion();
|
||||
return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
|
||||
}
|
||||
|
||||
void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
|
||||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
p << ' ';
|
||||
p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
|
||||
/*omitType=*/true);
|
||||
p << ' ';
|
||||
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
void TestWithBoundsRegionOp::inferResultRanges(
|
||||
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
|
||||
Value arg = getRegion().getArgument(0);
|
||||
setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
|
||||
}
|
||||
|
||||
void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRanges) {
|
||||
const ConstantIntRanges &range = argRanges[0];
|
||||
APInt one(range.umin().getBitWidth(), 1);
|
||||
setResultRanges(getResult(),
|
||||
{range.umin().uadd_sat(one), range.umax().uadd_sat(one),
|
||||
range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
|
||||
}
|
||||
|
||||
void TestReflectBoundsOp::inferResultRanges(
|
||||
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
|
||||
const ConstantIntRanges &range = argRanges[0];
|
||||
MLIRContext *ctx = getContext();
|
||||
Builder b(ctx);
|
||||
setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
|
||||
setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
|
||||
setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
|
||||
setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
|
||||
setResultRanges(getResult(), range);
|
||||
}
|
||||
|
||||
#include "TestOpEnums.cpp.inc"
|
||||
#include "TestOpInterfaces.cpp.inc"
|
||||
#include "TestOpStructs.cpp.inc"
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/CopyOpInterface.h"
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
|
|||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
include "mlir/Interfaces/InferIntRangeInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
@ -789,7 +790,7 @@ def StringAttrPrettyNameOp
|
|||
def CustomResultsNameOp
|
||||
: TEST_Op<"custom_result_name",
|
||||
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
let arguments = (ins
|
||||
let arguments = (ins
|
||||
Variadic<AnyInteger>:$optional,
|
||||
StrArrayAttr:$names
|
||||
);
|
||||
|
@ -2885,4 +2886,51 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test InferIntRangeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
def TestWithBoundsOp : TEST_Op<"with_bounds",
|
||||
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
|
||||
NoSideEffect]> {
|
||||
let arguments = (ins IndexAttr:$umin,
|
||||
IndexAttr:$umax,
|
||||
IndexAttr:$smin,
|
||||
IndexAttr:$smax);
|
||||
let results = (outs Index:$fakeVal);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region",
|
||||
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
|
||||
SingleBlock, NoTerminator]> {
|
||||
let arguments = (ins IndexAttr:$umin,
|
||||
IndexAttr:$umax,
|
||||
IndexAttr:$smin,
|
||||
IndexAttr:$smax);
|
||||
// The region has one argument of index type
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def TestIncrementOp : TEST_Op<"increment",
|
||||
[DeclareOpInterfaceMethods<InferIntRangeInterface>,
|
||||
NoSideEffect]> {
|
||||
let arguments = (ins Index:$value);
|
||||
let results = (outs Index:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict $value";
|
||||
}
|
||||
|
||||
def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
|
||||
[DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
|
||||
let arguments = (ins Index:$value,
|
||||
OptionalAttr<IndexAttr>:$umin,
|
||||
OptionalAttr<IndexAttr>:$umax,
|
||||
OptionalAttr<IndexAttr>:$smin,
|
||||
OptionalAttr<IndexAttr>:$smax);
|
||||
let results = (outs Index:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict $value";
|
||||
}
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_library(MLIRTestTransforms
|
|||
TestConstantFold.cpp
|
||||
TestControlFlowSink.cpp
|
||||
TestInlining.cpp
|
||||
TestIntRangeInference.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
@ -10,6 +11,8 @@ add_mlir_library(MLIRTestTransforms
|
|||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRTestDialect
|
||||
MLIRTransforms
|
||||
)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: This pass is needed to test integer range inference until that
|
||||
// functionality has been integrated into SCCP.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/IntRangeAnalysis.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Patterned after SCCP
|
||||
static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
|
||||
OpBuilder &b, OperationFolder &folder,
|
||||
Value value) {
|
||||
Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
|
||||
if (!maybeInferredRange)
|
||||
return failure();
|
||||
const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
|
||||
Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
|
||||
if (!maybeConstValue.hasValue())
|
||||
return failure();
|
||||
|
||||
Operation *maybeDefiningOp = value.getDefiningOp();
|
||||
Dialect *valueDialect =
|
||||
maybeDefiningOp ? maybeDefiningOp->getDialect()
|
||||
: value.getParentRegion()->getParentOp()->getDialect();
|
||||
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
|
||||
Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
|
||||
value.getType(), value.getLoc());
|
||||
if (!constant)
|
||||
return failure();
|
||||
|
||||
value.replaceAllUsesWith(constant);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
|
||||
MutableArrayRef<Region> initialRegions) {
|
||||
SmallVector<Block *> worklist;
|
||||
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
|
||||
for (Region ®ion : regions)
|
||||
for (Block &block : llvm::reverse(region))
|
||||
worklist.push_back(&block);
|
||||
};
|
||||
|
||||
OpBuilder builder(context);
|
||||
OperationFolder folder(context);
|
||||
|
||||
addToWorklist(initialRegions);
|
||||
while (!worklist.empty()) {
|
||||
Block *block = worklist.pop_back_val();
|
||||
|
||||
for (Operation &op : llvm::make_early_inc_range(*block)) {
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
// Replace any result with constants.
|
||||
bool replacedAll = op.getNumResults() != 0;
|
||||
for (Value res : op.getResults())
|
||||
replacedAll &=
|
||||
succeeded(replaceWithConstant(analysis, builder, folder, res));
|
||||
|
||||
// If all of the results of the operation were replaced, try to erase
|
||||
// the operation completely.
|
||||
if (replacedAll && wouldOpBeTriviallyDead(&op)) {
|
||||
assert(op.use_empty() && "expected all uses to be replaced");
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add any the regions of this operation to the worklist.
|
||||
addToWorklist(op.getRegions());
|
||||
}
|
||||
|
||||
// Replace any block arguments with constants.
|
||||
builder.setInsertionPointToStart(block);
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
(void)replaceWithConstant(analysis, builder, folder, arg);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestIntRangeInference
|
||||
: PassWrapper<TestIntRangeInference, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
|
||||
|
||||
StringRef getArgument() const final { return "test-int-range-inference"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test integer range inference analysis";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
IntRangeAnalysis analysis(op);
|
||||
rewrite(analysis, op->getContext(), op->getRegions());
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestIntRangeInference() {
|
||||
PassRegistration<TestIntRangeInference>();
|
||||
}
|
||||
} // end namespace test
|
||||
} // end namespace mlir
|
|
@ -79,6 +79,7 @@ void registerTestDynamicPipelinePass();
|
|||
void registerTestExpandMathPass();
|
||||
void registerTestComposeSubView();
|
||||
void registerTestMultiBuffering();
|
||||
void registerTestIntRangeInference();
|
||||
void registerTestIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsPass();
|
||||
void registerTestGenericIRVisitorsInterruptPass();
|
||||
|
@ -175,6 +176,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestExpandMathPass();
|
||||
mlir::test::registerTestComposeSubView();
|
||||
mlir::test::registerTestMultiBuffering();
|
||||
mlir::test::registerTestIntRangeInference();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
mlir::test::registerTestGenericIRVisitorsPass();
|
||||
mlir::test::registerTestInterfaces();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_unittest(MLIRInterfacesTests
|
||||
ControlFlowInterfacesTest.cpp
|
||||
DataLayoutInterfacesTest.cpp
|
||||
InferIntRangeInterfaceTest.cpp
|
||||
InferTypeOpInterfaceTest.cpp
|
||||
)
|
||||
|
||||
|
@ -10,6 +11,7 @@ target_link_libraries(MLIRInterfacesTests
|
|||
MLIRDataLayoutInterfaces
|
||||
MLIRDLTI
|
||||
MLIRFunc
|
||||
MLIRInferIntRangeInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRParser
|
||||
)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include <limits>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
TEST(IntRangeAttrs, BasicConstructors) {
|
||||
APInt zero = APInt::getZero(64);
|
||||
APInt two(64, 2);
|
||||
APInt three(64, 3);
|
||||
ConstantIntRanges boundedAbove(zero, two, zero, three);
|
||||
EXPECT_EQ(boundedAbove.umin(), zero);
|
||||
EXPECT_EQ(boundedAbove.umax(), two);
|
||||
EXPECT_EQ(boundedAbove.smin(), zero);
|
||||
EXPECT_EQ(boundedAbove.smax(), three);
|
||||
}
|
||||
|
||||
TEST(IntRangeAttrs, FromUnsigned) {
|
||||
APInt zero = APInt::getZero(64);
|
||||
APInt maxInt = APInt::getSignedMaxValue(64);
|
||||
APInt minInt = APInt::getSignedMinValue(64);
|
||||
APInt minIntPlusOne = minInt + 1;
|
||||
|
||||
ConstantIntRanges canPortToSigned =
|
||||
ConstantIntRanges::fromUnsigned(zero, maxInt);
|
||||
EXPECT_EQ(canPortToSigned.smin(), zero);
|
||||
EXPECT_EQ(canPortToSigned.smax(), maxInt);
|
||||
|
||||
ConstantIntRanges cantPortToSigned =
|
||||
ConstantIntRanges::fromUnsigned(zero, minInt);
|
||||
EXPECT_EQ(cantPortToSigned.smin(), minInt);
|
||||
EXPECT_EQ(cantPortToSigned.smax(), maxInt);
|
||||
|
||||
ConstantIntRanges signedNegative =
|
||||
ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne);
|
||||
EXPECT_EQ(signedNegative.smin(), minInt);
|
||||
EXPECT_EQ(signedNegative.smax(), minIntPlusOne);
|
||||
}
|
||||
|
||||
TEST(IntRangeAttrs, FromSigned) {
|
||||
APInt zero = APInt::getZero(64);
|
||||
APInt one = zero + 1;
|
||||
APInt negOne = zero - 1;
|
||||
APInt intMax = APInt::getSignedMaxValue(64);
|
||||
APInt intMin = APInt::getSignedMinValue(64);
|
||||
APInt uintMax = APInt::getMaxValue(64);
|
||||
|
||||
ConstantIntRanges noUnsignedBound =
|
||||
ConstantIntRanges::fromSigned(negOne, one);
|
||||
EXPECT_EQ(noUnsignedBound.umin(), zero);
|
||||
EXPECT_EQ(noUnsignedBound.umax(), uintMax);
|
||||
|
||||
ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax);
|
||||
EXPECT_EQ(positive.umin(), one);
|
||||
EXPECT_EQ(positive.umax(), intMax);
|
||||
|
||||
ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne);
|
||||
EXPECT_EQ(negative.umin(), intMin);
|
||||
EXPECT_EQ(negative.umax(), negOne);
|
||||
|
||||
ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one);
|
||||
EXPECT_EQ(preserved.umin(), zero);
|
||||
EXPECT_EQ(preserved.umax(), one);
|
||||
}
|
||||
|
||||
TEST(IntRangeAttrs, Join) {
|
||||
APInt zero = APInt::getZero(64);
|
||||
APInt one = zero + 1;
|
||||
APInt two = zero + 2;
|
||||
APInt intMin = APInt::getSignedMinValue(64);
|
||||
APInt intMax = APInt::getSignedMaxValue(64);
|
||||
APInt uintMax = APInt::getMaxValue(64);
|
||||
|
||||
ConstantIntRanges maximal(zero, uintMax, intMin, intMax);
|
||||
ConstantIntRanges zeroOne(zero, one, zero, one);
|
||||
|
||||
EXPECT_EQ(zeroOne.rangeUnion(maximal), maximal);
|
||||
EXPECT_EQ(maximal.rangeUnion(zeroOne), maximal);
|
||||
|
||||
EXPECT_EQ(zeroOne.rangeUnion(zeroOne), zeroOne);
|
||||
|
||||
ConstantIntRanges oneTwo(one, two, one, two);
|
||||
ConstantIntRanges zeroTwo(zero, two, zero, two);
|
||||
EXPECT_EQ(zeroOne.rangeUnion(oneTwo), zeroTwo);
|
||||
|
||||
ConstantIntRanges zeroOneUnsignedOnly(zero, one, intMin, intMax);
|
||||
ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
|
||||
EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
|
||||
}
|
Loading…
Reference in New Issue