1200 lines
35 KiB
C++
1200 lines
35 KiB
C++
//===- Merger.cpp - Implementation of iteration lattices ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
namespace mlir {
|
|
namespace sparse_tensor {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Constructors.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constructs a tensor expression node of kind `k`.
///
/// Leaf kinds store their payload directly: `kTensor` records tensor number
/// `x`, `kIndex` records loop index `x`, and `kInvariant` carries only the
/// value `v`. Every other kind is an interior node that records its operand
/// expressions `x` and `y` in `children` (`y == -1u` for unary kinds). The
/// per-kind assertions document which of `x`, `y`, `v`, and `o` must be set.
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaves store their payload and are done.
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    return;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    return;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    return;
  // Plain unary operations: a single operand, no value, no operation.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    break;
  // Casts: a single operand plus the cast result value, whose type is later
  // used as the destination type (see inferType()).
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    break;
  // Binary operations: two operands, no value, no operation.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
    assert(x != -1u && y != -1u && !v && !o);
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    break;
  }
  // Every non-leaf kind records its operand expressions.
  children.e0 = x;
  children.e1 = y;
}
|
|
|
|
/// Creates a lattice point over `n` condition bits in which only bit `b` is
/// set, associated with tensor expression `e`. The simplified condition is
/// default-constructed (empty) until optimizeSet() fills it in.
LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n), simple(), exp(e) {
  bits.set(b);
}

/// Creates a lattice point holding a copy of condition bits `b`, associated
/// with tensor expression `e`. The simplified condition is default-constructed.
LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), simple(), exp(e) {}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Lattice methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Appends a new tensor expression node built from the given kind, operand
/// expressions, value and operation, and returns its index in `tensorExps`.
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  const unsigned id = tensorExps.size();
  tensorExps.emplace_back(k, e0, e1, v, op);
  return id;
}
|
|
|
|
/// Appends a lattice point whose only condition bit is the one for tensor `t`
/// at loop index `i`, associated with expression `e`. Returns the index of the
/// new point in `latPoints`.
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  // Bits are laid out loop-major: bit = numTensors * loop + tensor.
  const unsigned numBits = numLoops * numTensors;
  const unsigned bit = numTensors * i + t;
  const unsigned id = latPoints.size();
  latPoints.emplace_back(numBits, e, bit);
  return id;
}
|
|
|
|
/// Appends a new, empty lattice set and returns its index in `latSets`.
unsigned Merger::addSet() {
  const unsigned id = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return id;
}
|
|
|
|
/// Creates the conjunction of lattice points `p0` and `p1`. The new point's
/// condition bits are the union of both points' bits. Its expression is
/// `kind` applied to both points' expressions, with `op` attached. Returns the
/// index of the new point in `latPoints`.
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  BitVector conjBits(latPoints[p0].bits);
  conjBits |= latPoints[p1].bits;
  const unsigned conjExp =
      addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  const unsigned id = latPoints.size();
  latPoints.push_back(LatPoint(conjBits, conjExp));
  return id;
}
|
|
|
|
/// Builds the conjunction of lattice sets `s0` and `s1`. The new set holds the
/// conjunction of every point of `s0` with every point of `s1`, with the outer
/// loop over `s0`. Returns the index of the new set.
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  const unsigned sNew = addSet();
  for (unsigned p0 : latSets[s0]) {
    for (unsigned p1 : latSets[s1]) {
      const unsigned conj = conjLatPoint(kind, p0, p1, op);
      latSets[sNew].push_back(conj);
    }
  }
  return sNew;
}
|
|
|
|
/// Builds the disjunction of lattice sets `s0` and `s1`. The new set holds
/// their conjunction, then all points of `s0`, then all points of `s1`. For
/// subtraction kinds, the `s1` points are first mapped through the matching
/// negation so that `0 - y` is computed as `-y`. Returns the index of the new
/// set.
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  const unsigned sNew = takeConj(kind, s0, s1, op);
  // Points where only the left operand is present.
  latSets[sNew].append(latSets[s0].begin(), latSets[s0].end());
  // Points where only the right operand is present. A missing left operand of
  // a subtraction turns binary 0-y into unary -y.
  // TODO: move this negation mapping into buildLattices
  unsigned rhs = s1;
  switch (kind) {
  case kSubF:
    rhs = mapSet(kNegF, s1);
    break;
  case kSubC:
    rhs = mapSet(kNegC, s1);
    break;
  case kSubI:
    rhs = mapSet(kNegI, s1);
    break;
  default:
    break;
  }
  latSets[sNew].append(latSets[rhs].begin(), latSets[rhs].end());
  return sNew;
}
|
|
|
|
/// Builds the lattice set of a custom binary operation. The new set holds the
/// conjunction of `s0` and `s1` (with `orig` attached). When `includeLeft` is
/// set, the points of `s0` follow; when `includeRight` is set, the points of
/// `s1` follow. A non-null `opleft`/`opright` first maps the corresponding
/// set through `ltrans`/`rtrans` with that operation attached. Returns the
/// index of the new set.
unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  const unsigned sNew = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    const unsigned left = opleft ? mapSet(ltrans, s0, Value(), opleft) : s0;
    latSets[sNew].append(latSets[left].begin(), latSets[left].end());
  }
  // Right Region.
  if (includeRight) {
    const unsigned right = opright ? mapSet(rtrans, s1, Value(), opright) : s1;
    latSets[sNew].append(latSets[right].begin(), latSets[right].end());
  }
  return sNew;
}
|
|
|
|
/// Maps every point of lattice set `s0` through the unary kind `kind`. Each
/// new point keeps the original condition bits. Its expression is `kind`
/// applied to the original point's expression, with `v` and `op` attached.
/// Returns the index of the new set.
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  const unsigned sNew = addSet();
  for (unsigned p : latSets[s0]) {
    const unsigned mappedExp = addExp(kind, latPoints[p].exp, v, op);
    const unsigned id = latPoints.size();
    latPoints.push_back(LatPoint(latPoints[p].bits, mappedExp));
    latSets[sNew].push_back(id);
  }
  return sNew;
}
|
|
|
|
unsigned Merger::optimizeSet(unsigned s0) {
|
|
unsigned s = addSet();
|
|
assert(!latSets[s0].empty());
|
|
unsigned p0 = latSets[s0][0];
|
|
for (unsigned p1 : latSets[s0]) {
|
|
bool add = true;
|
|
if (p0 != p1) {
|
|
// Is this a straightforward copy?
|
|
unsigned e = latPoints[p1].exp;
|
|
if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
|
|
continue;
|
|
// Conjunction already covered?
|
|
for (unsigned p2 : latSets[s]) {
|
|
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
|
if (onlyDenseDiff(p2, p1)) {
|
|
add = false;
|
|
break;
|
|
}
|
|
}
|
|
assert(!add || latGT(p0, p1));
|
|
}
|
|
if (add)
|
|
latSets[s].push_back(p1);
|
|
}
|
|
for (unsigned p : latSets[s])
|
|
latPoints[p].simple = simplifyCond(s, p);
|
|
return s;
|
|
}
|
|
|
|
/// Computes the simplified condition of lattice point `p0` within set `s0`.
///
/// Starting from the point's condition bits, the first non-sparse bit (as
/// judged by isDim) is kept and every later non-sparse bit is cleared.
/// If the point is a singleton and has at least one sparse bit, all
/// non-sparse bits are cleared, including the first. A singleton is a point
/// with no other point of `s0` strictly below it.
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // A singleton is the last point in a lattice: no other is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p1 != p0 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  BitVector simple = latPoints[p0].bits;
  // Once `dropDense` is set, every further non-sparse bit is cleared.
  bool dropDense = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b != be; ++b) {
    if (!simple[b] || isDim(b, kSparse))
      continue;
    if (dropDense)
      simple.reset(b);
    dropDense = true;
  }
  return simple;
}
|
|
|
|
bool Merger::latGT(unsigned i, unsigned j) const {
|
|
const BitVector &bitsi = latPoints[i].bits;
|
|
const BitVector &bitsj = latPoints[j].bits;
|
|
assert(bitsi.size() == bitsj.size());
|
|
if (bitsi.count() > bitsj.count()) {
|
|
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
|
if (bitsj[b] && !bitsi[b])
|
|
return false;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
|
|
BitVector tmp = latPoints[j].bits;
|
|
tmp ^= latPoints[i].bits;
|
|
return !hasAnyDimOf(tmp, kSparse);
|
|
}
|
|
|
|
bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
|
|
for (unsigned b = 0, be = bits.size(); b < be; b++)
|
|
if (bits[b] && isDim(b, d))
|
|
return true;
|
|
return false;
|
|
}
|
|
|
|
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|
switch (tensorExps[e].kind) {
|
|
// Leaf.
|
|
case kTensor:
|
|
return tensorExps[e].tensor == t;
|
|
case kInvariant:
|
|
case kIndex:
|
|
return false;
|
|
// Unary operations.
|
|
case kAbsF:
|
|
case kAbsC:
|
|
case kCeilF:
|
|
case kFloorF:
|
|
case kSqrtF:
|
|
case kSqrtC:
|
|
case kExpm1F:
|
|
case kExpm1C:
|
|
case kLog1pF:
|
|
case kLog1pC:
|
|
case kSinF:
|
|
case kSinC:
|
|
case kTanhF:
|
|
case kTanhC:
|
|
case kNegF:
|
|
case kNegC:
|
|
case kNegI:
|
|
case kTruncF:
|
|
case kExtF:
|
|
case kCastFS:
|
|
case kCastFU:
|
|
case kCastSF:
|
|
case kCastUF:
|
|
case kCastS:
|
|
case kCastU:
|
|
case kCastIdx:
|
|
case kTruncI:
|
|
case kCIm:
|
|
case kCRe:
|
|
case kBitCast:
|
|
return isSingleCondition(t, tensorExps[e].children.e0);
|
|
case kBinaryBranch:
|
|
case kUnary:
|
|
return false;
|
|
// Binary operations.
|
|
case kDivF: // note: x / c only
|
|
case kDivC:
|
|
case kDivS:
|
|
case kDivU:
|
|
assert(!maybeZero(tensorExps[e].children.e1));
|
|
return isSingleCondition(t, tensorExps[e].children.e0);
|
|
case kShrS: // note: x >> inv only
|
|
case kShrU:
|
|
case kShlI:
|
|
assert(isInvariant(tensorExps[e].children.e1));
|
|
return isSingleCondition(t, tensorExps[e].children.e0);
|
|
case kMulF:
|
|
case kMulC:
|
|
case kMulI:
|
|
case kAndI:
|
|
if (isSingleCondition(t, tensorExps[e].children.e0))
|
|
return isSingleCondition(t, tensorExps[e].children.e1) ||
|
|
isInvariant(tensorExps[e].children.e1);
|
|
if (isSingleCondition(t, tensorExps[e].children.e1))
|
|
return isInvariant(tensorExps[e].children.e0);
|
|
return false;
|
|
case kAddF:
|
|
case kAddC:
|
|
case kAddI:
|
|
return isSingleCondition(t, tensorExps[e].children.e0) &&
|
|
isSingleCondition(t, tensorExps[e].children.e1);
|
|
case kSubF:
|
|
case kSubC:
|
|
case kSubI:
|
|
case kOrI:
|
|
case kXorI:
|
|
case kBinary:
|
|
return false;
|
|
}
|
|
llvm_unreachable("unexpected kind");
|
|
}
|
|
|
|
#ifndef NDEBUG
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Print methods (for debugging).
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the symbol used to print an expression of the given kind in the
/// debug dumps below.
///
/// All cast kinds (truncations, extensions, int/fp conversions, index casts
/// and bit casts) share the "cast" symbol. Previously the cast kinds fell
/// through into `kCIm` and were misprinted as "complex.im".
static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  // Leaf.
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  // Unary operations.
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}
|
|
|
|
void Merger::dumpExp(unsigned e) const {
|
|
switch (tensorExps[e].kind) {
|
|
// Leaf.
|
|
case kTensor:
|
|
if (tensorExps[e].tensor == syntheticTensor)
|
|
llvm::dbgs() << "synthetic_";
|
|
else if (tensorExps[e].tensor == outTensor)
|
|
llvm::dbgs() << "output_";
|
|
llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
|
|
break;
|
|
case kInvariant:
|
|
llvm::dbgs() << "invariant";
|
|
break;
|
|
case kIndex:
|
|
llvm::dbgs() << "index_" << tensorExps[e].index;
|
|
break;
|
|
// Unary operations.
|
|
case kAbsF:
|
|
case kAbsC:
|
|
case kCeilF:
|
|
case kFloorF:
|
|
case kSqrtF:
|
|
case kSqrtC:
|
|
case kExpm1F:
|
|
case kExpm1C:
|
|
case kLog1pF:
|
|
case kLog1pC:
|
|
case kSinF:
|
|
case kSinC:
|
|
case kTanhF:
|
|
case kTanhC:
|
|
case kNegF:
|
|
case kNegC:
|
|
case kNegI:
|
|
case kTruncF:
|
|
case kExtF:
|
|
case kCastFS:
|
|
case kCastFU:
|
|
case kCastSF:
|
|
case kCastUF:
|
|
case kCastS:
|
|
case kCastU:
|
|
case kCastIdx:
|
|
case kTruncI:
|
|
case kCIm:
|
|
case kCRe:
|
|
case kBitCast:
|
|
case kBinaryBranch:
|
|
case kUnary:
|
|
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
|
|
dumpExp(tensorExps[e].children.e0);
|
|
break;
|
|
// Binary operations.
|
|
case kMulF:
|
|
case kMulC:
|
|
case kMulI:
|
|
case kDivF:
|
|
case kDivC:
|
|
case kDivS:
|
|
case kDivU:
|
|
case kAddF:
|
|
case kAddC:
|
|
case kAddI:
|
|
case kSubF:
|
|
case kSubC:
|
|
case kSubI:
|
|
case kAndI:
|
|
case kOrI:
|
|
case kXorI:
|
|
case kShrS:
|
|
case kShrU:
|
|
case kShlI:
|
|
case kBinary:
|
|
llvm::dbgs() << "(";
|
|
dumpExp(tensorExps[e].children.e0);
|
|
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
|
|
dumpExp(tensorExps[e].children.e1);
|
|
llvm::dbgs() << ")";
|
|
}
|
|
}
|
|
|
|
void Merger::dumpLat(unsigned p) const {
|
|
llvm::dbgs() << "lat(";
|
|
dumpBits(latPoints[p].bits);
|
|
llvm::dbgs() << " :";
|
|
dumpBits(latPoints[p].simple);
|
|
llvm::dbgs() << " : ";
|
|
dumpExp(latPoints[p].exp);
|
|
llvm::dbgs() << " )\n";
|
|
}
|
|
|
|
void Merger::dumpSet(unsigned s) const {
|
|
llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
|
|
for (unsigned p : latSets[s]) {
|
|
llvm::dbgs() << " ";
|
|
dumpLat(p);
|
|
}
|
|
llvm::dbgs() << "}\n";
|
|
}
|
|
|
|
void Merger::dumpBits(const BitVector &bits) const {
|
|
for (unsigned b = 0, be = bits.size(); b < be; b++) {
|
|
if (bits[b]) {
|
|
unsigned t = tensor(b);
|
|
unsigned i = index(b);
|
|
llvm::dbgs() << " i_" << t << "_" << i << "_";
|
|
switch (dims[t][i]) {
|
|
case kSparse:
|
|
llvm::dbgs() << "S";
|
|
break;
|
|
case kDense:
|
|
llvm::dbgs() << "D";
|
|
break;
|
|
case kUndef:
|
|
llvm::dbgs() << "U";
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif // NDEBUG
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Builder methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|
Kind kind = tensorExps[e].kind;
|
|
switch (kind) {
|
|
// Leaf.
|
|
case kTensor:
|
|
case kInvariant:
|
|
case kIndex: {
|
|
// Either the index is really used in the tensor expression, or it is
|
|
// set to the undefined index in that dimension. An invariant expression,
|
|
// a proper index value, and a truly dynamic sparse output tensor are set
|
|
// to a synthetic tensor with undefined indices only to ensure the
|
|
// iteration space is not skipped as a result of their contents.
|
|
unsigned s = addSet();
|
|
unsigned t = syntheticTensor;
|
|
if (kind == kTensor) {
|
|
t = tensorExps[e].tensor;
|
|
if (hasSparseOut && t == outTensor)
|
|
t = syntheticTensor;
|
|
}
|
|
latSets[s].push_back(addLat(t, i, e));
|
|
return s;
|
|
}
|
|
// Unary operations.
|
|
case kAbsF:
|
|
case kAbsC:
|
|
case kCeilF:
|
|
case kFloorF:
|
|
case kSqrtF:
|
|
case kSqrtC:
|
|
case kExpm1F:
|
|
case kExpm1C:
|
|
case kLog1pF:
|
|
case kLog1pC:
|
|
case kSinF:
|
|
case kSinC:
|
|
case kTanhF:
|
|
case kTanhC:
|
|
case kNegF:
|
|
case kNegC:
|
|
case kNegI:
|
|
case kTruncF:
|
|
case kExtF:
|
|
case kCastFS:
|
|
case kCastFU:
|
|
case kCastSF:
|
|
case kCastUF:
|
|
case kCastS:
|
|
case kCastU:
|
|
case kCastIdx:
|
|
case kTruncI:
|
|
case kCIm:
|
|
case kCRe:
|
|
case kBitCast:
|
|
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
|
|
// lattice set of the operand through the operator into a new set.
|
|
//
|
|
// -y|!y | y |
|
|
// --+---+---+
|
|
// | 0 |-y |
|
|
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
|
|
tensorExps[e].val);
|
|
case kBinaryBranch:
|
|
// The left or right half of a binary operation which has already
|
|
// been split into separate operations for each region.
|
|
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
|
|
tensorExps[e].op);
|
|
case kUnary:
|
|
// A custom unary operation.
|
|
//
|
|
// op y| !y | y |
|
|
// ----+----------+------------+
|
|
// | absent() | present(y) |
|
|
{
|
|
unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
|
|
UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
|
|
Region &absentRegion = unop.getAbsentRegion();
|
|
|
|
if (absentRegion.empty()) {
|
|
// Simple mapping over existing values.
|
|
return mapSet(kind, child0, Value(), unop);
|
|
} // Use a disjunction with `unop` on the left and the absent value as an
|
|
// invariant on the right.
|
|
Block &absentBlock = absentRegion.front();
|
|
YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
|
|
Value absentVal = absentYield.getResult();
|
|
unsigned rhs = addExp(kInvariant, absentVal);
|
|
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
|
|
}
|
|
// Binary operations.
|
|
case kMulF:
|
|
case kMulC:
|
|
case kMulI:
|
|
case kAndI:
|
|
// A multiplicative operation only needs to be performed
|
|
// for the conjunction of sparse iteration spaces.
|
|
//
|
|
// x*y|!y | y |
|
|
// ---+---+---+
|
|
// !x | 0 | 0 |
|
|
// x | 0 |x*y|
|
|
//
|
|
// Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
|
|
return takeConj(kind, // take binary conjunction
|
|
buildLattices(tensorExps[e].children.e0, i),
|
|
buildLattices(tensorExps[e].children.e1, i));
|
|
case kDivF:
|
|
case kDivC:
|
|
case kDivS:
|
|
case kDivU:
|
|
// A division is tricky, since 0/0, 0/c, c/0 all have
|
|
// specific outcomes for floating-point and integers.
|
|
// Thus, we need to traverse the full iteration space.
|
|
//
|
|
// x/y|!y | y |
|
|
// ---+---+---+
|
|
// !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
|
|
// x |x/0|x/y| INT: x/0=exception for any x
|
|
//
|
|
// TODO: for now we "fixed" this by only accepting x/c cases
|
|
// during expression building, so that the conjunction
|
|
// rules applies (viz. x/c = x*(1/c) as far as lattice
|
|
// construction is concerned).
|
|
assert(!maybeZero(tensorExps[e].children.e1));
|
|
return takeConj(kind, // take binary conjunction
|
|
buildLattices(tensorExps[e].children.e0, i),
|
|
buildLattices(tensorExps[e].children.e1, i));
|
|
case kAddF:
|
|
case kAddC:
|
|
case kAddI:
|
|
case kSubF:
|
|
case kSubC:
|
|
case kSubI:
|
|
case kOrI:
|
|
case kXorI:
|
|
// An additive operation needs to be performed
|
|
// for the disjunction of sparse iteration spaces.
|
|
//
|
|
// x+y|!y | y | x-y|!y | y |
|
|
// ---+---+---+ ---+---+---+
|
|
// !x | 0 | y | !x | 0 |-y |
|
|
// x | x |x+y| x | x |x-y|
|
|
return takeDisj(kind, // take binary disjunction
|
|
buildLattices(tensorExps[e].children.e0, i),
|
|
buildLattices(tensorExps[e].children.e1, i));
|
|
case kShrS:
|
|
case kShrU:
|
|
case kShlI:
|
|
// A shift operation by an invariant amount (viz. tensor expressions
|
|
// can only occur at the left-hand-side of the operator) can be handled
|
|
// with the conjuction rule.
|
|
assert(isInvariant(tensorExps[e].children.e1));
|
|
return takeConj(kind, // take binary conjunction
|
|
buildLattices(tensorExps[e].children.e0, i),
|
|
buildLattices(tensorExps[e].children.e1, i));
|
|
case kBinary:
|
|
// A custom binary operation.
|
|
//
|
|
// x op y| !y | y |
|
|
// ------+---------+--------------+
|
|
// !x | empty | right(y) |
|
|
// x | left(x) | overlap(x,y) |
|
|
{
|
|
unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
|
|
unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
|
|
BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
|
|
Region &leftRegion = binop.getLeftRegion();
|
|
Region &rightRegion = binop.getRightRegion();
|
|
// Left Region.
|
|
Operation *leftYield = nullptr;
|
|
if (!leftRegion.empty()) {
|
|
Block &leftBlock = leftRegion.front();
|
|
leftYield = leftBlock.getTerminator();
|
|
}
|
|
// Right Region.
|
|
Operation *rightYield = nullptr;
|
|
if (!rightRegion.empty()) {
|
|
Block &rightBlock = rightRegion.front();
|
|
rightYield = rightBlock.getTerminator();
|
|
}
|
|
bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
|
|
bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
|
|
return takeCombi(kBinary, child0, child1, binop, includeLeft,
|
|
kBinaryBranch, leftYield, includeRight, kBinaryBranch,
|
|
rightYield);
|
|
}
|
|
}
|
|
llvm_unreachable("unexpected expression kind");
|
|
}
|
|
|
|
/// Builds the tensor expression for the value yielded by the body of the
/// linalg generic `op`. Returns None when that expression cannot be built.
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Block &body = op.getRegion().front();
  Operation *yield = body.getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  Value yielded = yield->getOperand(0);
  return buildTensorExp(op, yielded);
}
|
|
|
|
/// Only returns false if we are certain this is a nonzero.
|
|
bool Merger::maybeZero(unsigned e) const {
|
|
if (tensorExps[e].kind == kInvariant) {
|
|
if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
|
|
ArrayAttr arrayAttr = c.getValue();
|
|
return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
|
|
arrayAttr[1].cast<FloatAttr>().getValue().isZero();
|
|
}
|
|
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
|
|
return c.value() == 0;
|
|
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
|
|
return c.value().isZero();
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool Merger::isInvariant(unsigned e) const {
|
|
return tensorExps[e].kind == kInvariant;
|
|
}
|
|
|
|
/// Returns the destination type of cast expression `e`, taken from the value
/// stored on the cast node. When the source value `src` has a vector type,
/// the same vectorization is applied to the destination type: the same number
/// of elements and the same number of scalable dimensions.
Type Merger::inferType(unsigned e, Value src) {
  Type dtp = tensorExps[e].val.getType();
  auto vtp = src.getType().dyn_cast<VectorType>();
  if (!vtp)
    return dtp;
  return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
}
|
|
|
|
/// Ensures that sparse compiler can generate code for expression.
|
|
static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
|
|
// Arguments are always admissable.
|
|
if (auto arg = v.dyn_cast<BlockArgument>())
|
|
return true;
|
|
// Accept index anywhere.
|
|
Operation *def = v.getDefiningOp();
|
|
if (isa<linalg::IndexOp>(def))
|
|
return true;
|
|
// Operation defined outside branch.
|
|
if (def->getBlock() != block) {
|
|
return def->getBlock() != op->getBlock(); // invariant?
|
|
}
|
|
// Operation defined within branch. Anything is accepted,
|
|
// as long as all subexpressions are admissable.
|
|
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
|
|
if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
/// Ensures that sparse compiler can generate code for branch.
|
|
static bool isAdmissableBranch(Operation *op, Region ®ion) {
|
|
if (region.empty())
|
|
return true;
|
|
// Build the semi-ring branch semantics backward from yield.
|
|
Operation *yield = region.front().getTerminator();
|
|
assert(isa<YieldOp>(yield));
|
|
return isAdmissableBranchExp(op, ®ion.front(), yield->getOperand(0));
|
|
}
|
|
|
|
/// Recursively builds the tensor expression for value `v` in the body of the
/// linalg generic `op`. Returns the expression's index in `tensorExps`, or
/// None when `v` is computed by an operation that cannot be expressed.
///
/// - Non-scalar block arguments of `op` become tensor leaves.
/// - Other block arguments and values defined outside the body become
///   invariants.
/// - `linalg.index` becomes an index leaf.
/// - Supported unary and binary operations become interior nodes.
///
/// Divisions are only accepted when the divisor is known to be nonzero, and
/// shifts only by an invariant amount (see buildLattices()). Operands are
/// always built before the node itself, left operand first.
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    const unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *operand = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(operand))
        return addExp(kTensor, argN);
      v = operand->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.getRegion().front())
    return addExp(kInvariant, v);
  const unsigned numOperands = def->getNumOperands();
  // Construct index operations.
  if (numOperands == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.getDim());
  }
  // Construct unary operations if subexpression can be built.
  if (numOperands == 1) {
    if (Optional<unsigned> operandExp = buildTensorExp(op, def->getOperand(0))) {
      const unsigned e = *operandExp;
      if (isa<math::AbsFOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      // There is no integer negation op; kNegI only arises from 0-y in
      // takeDisj().
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e);
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      // Casts also record their result value `v`, whose type is the
      // destination type.
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        const bool admissable =
            isAdmissableBranch(unop, unop.getPresentRegion()) &&
            isAdmissableBranch(unop, unop.getAbsentRegion());
        if (admissable)
          return addExp(kUnary, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (numOperands == 2) {
    Optional<unsigned> lhsExp = buildTensorExp(op, def->getOperand(0));
    Optional<unsigned> rhsExp = buildTensorExp(op, def->getOperand(1));
    if (lhsExp && rhsExp) {
      const unsigned e0 = *lhsExp;
      const unsigned e1 = *rhsExp;
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        const bool admissable =
            isAdmissableBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissableBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissableBranch(binop, binop.getRightRegion()));
        if (admissable)
          return addExp(kBinary, e0, e1, Value(), def);
      }
    }
  }
  // Cannot build.
  return None;
}
|
|
|
|
/// Inlines a clone of `region` at the rewriter's current insertion point and
/// returns the value its yield produced. The first block of the region must
/// end in a sparse_tensor YieldOp. The cloned block's arguments are replaced
/// by `vals`. The original `region` is left untouched, so this serves both the
/// unary "present" and binary "overlap" branches.
///
/// (The scraped copy had `Region ®ion`, which is mangled `&region` and does
/// not compile.)
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of the region so the original stays intact.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value. The placeholder only marks
  // the insertion position and is erased afterwards.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&clonedBlock, placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}
|
|
|
|
/// Emits the "present" branch of the sparse_tensor unary operation `op` for
/// input value `v0`. Returns an empty Value, which is interpreted as missing
/// data in the output, when `v0` is empty or the present region is empty.
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  // Empty input value must be propagated.
  if (!v0)
    return Value();
  Region &presentRegion = cast<UnaryOp>(op).getPresentRegion();
  // Uninitialized Value() will be interpreted as missing data in the output.
  if (presentRegion.empty())
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}
|
|
|
|
/// Emits the "overlap" branch of the sparse_tensor binary operation `op` for
/// input values `v0` and `v1`. Returns an empty Value, which is interpreted as
/// missing data in the output, when either input is empty or the overlap
/// region is empty.
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  // Empty input values must be propagated.
  if (!v0 || !v1)
    return Value();
  Region &overlapRegion = cast<BinaryOp>(op).getOverlapRegion();
  // Uninitialized Value() will be interpreted as missing data in the output.
  if (overlapRegion.empty())
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
|
|
|
|
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|
Value v0, Value v1) {
|
|
switch (tensorExps[e].kind) {
|
|
// Leaf.
|
|
case kTensor:
|
|
case kInvariant:
|
|
case kIndex:
|
|
llvm_unreachable("unexpected non-op");
|
|
// Unary operations.
|
|
case kAbsF:
|
|
return rewriter.create<math::AbsFOp>(loc, v0);
|
|
case kAbsC: {
|
|
auto type = v0.getType().cast<ComplexType>();
|
|
auto eltType = type.getElementType().cast<FloatType>();
|
|
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
|
|
}
|
|
case kCeilF:
|
|
return rewriter.create<math::CeilOp>(loc, v0);
|
|
case kFloorF:
|
|
return rewriter.create<math::FloorOp>(loc, v0);
|
|
case kSqrtF:
|
|
return rewriter.create<math::SqrtOp>(loc, v0);
|
|
case kSqrtC:
|
|
return rewriter.create<complex::SqrtOp>(loc, v0);
|
|
case kExpm1F:
|
|
return rewriter.create<math::ExpM1Op>(loc, v0);
|
|
case kExpm1C:
|
|
return rewriter.create<complex::Expm1Op>(loc, v0);
|
|
case kLog1pF:
|
|
return rewriter.create<math::Log1pOp>(loc, v0);
|
|
case kLog1pC:
|
|
return rewriter.create<complex::Log1pOp>(loc, v0);
|
|
case kSinF:
|
|
return rewriter.create<math::SinOp>(loc, v0);
|
|
case kSinC:
|
|
return rewriter.create<complex::SinOp>(loc, v0);
|
|
case kTanhF:
|
|
return rewriter.create<math::TanhOp>(loc, v0);
|
|
case kTanhC:
|
|
return rewriter.create<complex::TanhOp>(loc, v0);
|
|
case kNegF:
|
|
return rewriter.create<arith::NegFOp>(loc, v0);
|
|
case kNegC:
|
|
return rewriter.create<complex::NegOp>(loc, v0);
|
|
case kNegI: // no negi in std
|
|
return rewriter.create<arith::SubIOp>(
|
|
loc,
|
|
rewriter.create<arith::ConstantOp>(loc, v0.getType(),
|
|
rewriter.getZeroAttr(v0.getType())),
|
|
v0);
|
|
case kTruncF:
|
|
return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
|
|
case kExtF:
|
|
return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
|
|
case kCastFS:
|
|
return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
|
|
case kCastFU:
|
|
return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
|
|
case kCastSF:
|
|
return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
|
|
case kCastUF:
|
|
return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
|
|
case kCastS:
|
|
return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
|
|
case kCastU:
|
|
return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
|
|
case kCastIdx:
|
|
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
|
|
case kTruncI:
|
|
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
|
|
case kCIm: {
|
|
auto type = v0.getType().cast<ComplexType>();
|
|
auto eltType = type.getElementType().cast<FloatType>();
|
|
return rewriter.create<complex::ImOp>(loc, eltType, v0);
|
|
}
|
|
case kCRe: {
|
|
auto type = v0.getType().cast<ComplexType>();
|
|
auto eltType = type.getElementType().cast<FloatType>();
|
|
return rewriter.create<complex::ReOp>(loc, eltType, v0);
|
|
}
|
|
case kBitCast:
|
|
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
|
|
// Binary operations.
|
|
case kMulF:
|
|
return rewriter.create<arith::MulFOp>(loc, v0, v1);
|
|
case kMulC:
|
|
return rewriter.create<complex::MulOp>(loc, v0, v1);
|
|
case kMulI:
|
|
return rewriter.create<arith::MulIOp>(loc, v0, v1);
|
|
case kDivF:
|
|
return rewriter.create<arith::DivFOp>(loc, v0, v1);
|
|
case kDivC:
|
|
return rewriter.create<complex::DivOp>(loc, v0, v1);
|
|
case kDivS:
|
|
return rewriter.create<arith::DivSIOp>(loc, v0, v1);
|
|
case kDivU:
|
|
return rewriter.create<arith::DivUIOp>(loc, v0, v1);
|
|
case kAddF:
|
|
return rewriter.create<arith::AddFOp>(loc, v0, v1);
|
|
case kAddC:
|
|
return rewriter.create<complex::AddOp>(loc, v0, v1);
|
|
case kAddI:
|
|
return rewriter.create<arith::AddIOp>(loc, v0, v1);
|
|
case kSubF:
|
|
return rewriter.create<arith::SubFOp>(loc, v0, v1);
|
|
case kSubC:
|
|
return rewriter.create<complex::SubOp>(loc, v0, v1);
|
|
case kSubI:
|
|
return rewriter.create<arith::SubIOp>(loc, v0, v1);
|
|
case kAndI:
|
|
return rewriter.create<arith::AndIOp>(loc, v0, v1);
|
|
case kOrI:
|
|
return rewriter.create<arith::OrIOp>(loc, v0, v1);
|
|
case kXorI:
|
|
return rewriter.create<arith::XOrIOp>(loc, v0, v1);
|
|
case kShrS:
|
|
return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
|
|
case kShrU:
|
|
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
|
|
case kShlI:
|
|
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
|
|
case kBinaryBranch: // semi-ring ops with custom logic.
|
|
return insertYieldOp(rewriter, loc,
|
|
*tensorExps[e].op->getBlock()->getParent(), {v0});
|
|
case kUnary:
|
|
return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
|
|
case kBinary:
|
|
return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
|
|
}
|
|
llvm_unreachable("unexpected expression kind in build");
|
|
}
|
|
|
|
} // namespace sparse_tensor
|
|
} // namespace mlir
|