Step 1. Build the dominator tree

Dominator tree definition

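For a node A, its immediate dominator S is simply the LCA (lowest common ancestor), in the dominator tree, of all of A's input nodes. Note that TVM's fusion pass actually builds a post-dominator tree. It visits nodes in reverse topological order, so the "inputs" in this sense are A's consumers (output nodes) in the original dataflow graph.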
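To make the definition concrete, here is a brute-force check written directly from the textbook definition: S dominates A if every path from the entry to A passes through S. Post-dominance is the same check on the reversed graph. This is only an illustration. The adjacency-list representation and the name `Dominates` are hypothetical and not part of TVM.

```cpp
#include <cassert>
#include <vector>

// succ[v] lists the nodes that consume v's output (edges v -> w).
// Returns true if node s dominates node a, i.e. every path from `entry`
// to `a` passes through `s`. Checked by brute force: delete s and test
// whether a is still reachable from entry.
bool Dominates(const std::vector<std::vector<int>>& succ, int entry, int s, int a) {
  assert(entry >= 0 && entry < static_cast<int>(succ.size()));
  if (s == a || s == entry) return true;  // a node dominates itself; the entry dominates every reachable node
  std::vector<bool> seen(succ.size(), false);
  std::vector<int> stack{entry};
  seen[entry] = true;
  while (!stack.empty()) {
    int v = stack.back();
    stack.pop_back();
    if (v == a) return false;  // found a path to a that avoids s
    for (int w : succ[v]) {
      if (w != s && !seen[w]) {
        seen[w] = true;
        stack.push_back(w);
      }
    }
  }
  return true;  // a is unreachable once s is removed
}
```

Consider the diamond `0 -> {1, 2} -> 3 -> 4`. Here `Dominates(succ, 0, 1, 3)` and `Dominates(succ, 0, 2, 3)` are both false, while `Dominates(succ, 0, 0, 3)` is true. Node 3's immediate dominator is therefore 0, which is the LCA of its inputs 1 and 2.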

Computing the dominator

The approach TVM uses:

$$ \mathrm{LCA}(a, b, c, \ldots) = \mathrm{LCA}(\mathrm{LCA}(a, b), c, \ldots) $$

More efficient algorithms do exist for this step.
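The fold above can be sketched as a single pass over the nodes in topological order. Each node's parent in the tree is the pairwise LCA of its inputs, folded left to right, and its depth is the parent's depth plus one.

This is a forward-dominator sketch on a toy DAG. It assumes node 0 is the only entry and that every input index is smaller than the node's own index. `DomTree` and `BuildDomTree` are hypothetical names. TVM's pass does the mirror image: it visits nodes in reverse topological order and folds over each node's outputs, which yields a post-dominator tree.

```cpp
#include <cassert>
#include <vector>

// Toy dominator tree over a DAG whose nodes are numbered in topological
// order. Node 0 is assumed to be the single entry and becomes the root.
struct DomTree {
  std::vector<int> parent;  // immediate dominator; -1 for the root
  std::vector<int> depth;   // depth in the tree; 0 for the root

  // Naive pairwise LCA: move the deeper node up, or both when depths match.
  int Lca(int a, int b) const {
    while (a != b) {
      if (depth[a] > depth[b]) {
        a = parent[a];
      } else if (depth[b] > depth[a]) {
        b = parent[b];
      } else {
        a = parent[a];
        b = parent[b];
      }
    }
    return a;
  }
};

// inputs[v] lists v's input (predecessor) nodes.
DomTree BuildDomTree(const std::vector<std::vector<int>>& inputs) {
  const int n = static_cast<int>(inputs.size());
  DomTree t{std::vector<int>(n, -1), std::vector<int>(n, 0)};
  for (int v = 1; v < n; ++v) {
    assert(!inputs[v].empty());
    // LCA(a, b, c, ...) = LCA(LCA(a, b), c, ...): fold left to right.
    int idom = inputs[v][0];
    for (int u : inputs[v]) {
      assert(u >= 0 && u < v);  // inputs must already be in the tree
      idom = t.Lca(idom, u);
    }
    t.parent[v] = idom;
    t.depth[v] = t.depth[idom] + 1;
  }
  return t;
}
```

For the diamond `0 -> {1, 2} -> 3 -> 4`, the inputs are `{{}, {0}, {0}, {1, 2}, {3}}`. The resulting parents are 0 for nodes 1, 2 and 3, and 3 for node 4.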

LCA(a,b)

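TVM again uses the most brute-force method here. The two nodes climb toward the root until they meet: the deeper one moves up first, and when their depths are equal both move up together. Better algorithms exist here too.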
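One standard faster option, which TVM does not use, is binary lifting. Each node stores its 2^k-th ancestors, so a query takes O(log n) jumps instead of walking the full depth. Dominator-tree nodes are created in topological order, so a node's jump table can be filled in as soon as the node is added. The class and method names below are hypothetical.

One caveat applies. TVM's climb also combines the op pattern of every tree edge it crosses. A drop-in replacement would therefore also have to aggregate that value per jump, for example as a maximum.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Binary lifting: up_[k][v] is v's 2^k-th ancestor (the root maps to itself).
// Nodes must be added after their parent, e.g. in topological order.
class LiftingTree {
 public:
  explicit LiftingTree(int max_nodes) {
    levels_ = 1;
    while ((1 << levels_) < max_nodes) ++levels_;  // 2^levels_ >= max_nodes
    up_.assign(levels_, std::vector<int>(max_nodes, 0));
    depth_.assign(max_nodes, 0);
  }

  // Adds node v under `parent`; the root is added with parent == v.
  void Add(int v, int parent) {
    assert(v >= 0 && v < static_cast<int>(depth_.size()));
    depth_[v] = (parent == v) ? 0 : depth_[parent] + 1;
    up_[0][v] = parent;
    for (int k = 1; k < levels_; ++k) up_[k][v] = up_[k - 1][up_[k - 1][v]];
  }

  int Lca(int a, int b) const {
    if (depth_[a] < depth_[b]) std::swap(a, b);
    // Lift a to b's depth, one jump per set bit of the depth gap.
    int diff = depth_[a] - depth_[b];
    for (int k = 0; diff > 0; ++k, diff >>= 1) {
      if (diff & 1) a = up_[k][a];
    }
    if (a == b) return a;
    // Jump both as far as possible while their ancestors still differ.
    for (int k = levels_ - 1; k >= 0; --k) {
      if (up_[k][a] != up_[k][b]) {
        a = up_[k][a];
        b = up_[k][b];
      }
    }
    return up_[0][a];
  }

 private:
  int levels_;
  std::vector<std::vector<int>> up_;
  std::vector<int> depth_;
};
```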

Step 2. Fuse

Op classification

```cpp
enum OpPatternKind {
  // Elementwise operation
  kElemWise = 0,
  // Broadcasting operator, can always map output axis to the input in order.
  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
  // Note that the axis need to be in order so transpose is not a bcast operator.
  kBroadcast = 1,
  // Injective operator, can always injectively map output axis to a single input axis.
  // All injective operator can still be safely fused to injective and reduction.
  kInjective = 2,
  // Communicative reduction operator.
  kCommReduce = 3,
  // Complex operation, can still fuse elemwise operations into its output.
  // but cannot chain another complex op
  kOutEWiseFusable = 4,
  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
  // but treated specially
  kTuple = 7,
  // Opaque operation, cannot fuse anything.
  kOpaque = 8
};
```
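The numeric order of these values matters. When TVM's fusion pass walks a path, it merges the patterns it encounters by keeping the larger, i.e. more restrictive, one; this is what its `CombinePattern` helper does.

Below is a minimal sketch of that rule. `CombinePatternSketch` and `PathPattern` are hypothetical names, and the enum values are copied from above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Same values as the enum above; 5 and 6 are unused there as well.
enum OpPatternKind {
  kElemWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kCommReduce = 3,
  kOutEWiseFusable = 4,
  kTuple = 7,
  kOpaque = 8
};

// Larger value = more restrictive pattern, so merging keeps the larger one.
OpPatternKind CombinePatternSketch(OpPatternKind lhs, OpPatternKind rhs) {
  return std::max(lhs, rhs);
}

// Hypothetical helper: the pattern of a whole path is the fold over its edges.
OpPatternKind PathPattern(const std::vector<OpPatternKind>& edges) {
  assert(!edges.empty());
  OpPatternKind p = edges.front();
  for (OpPatternKind e : edges) p = CombinePatternSketch(p, e);
  return p;
}
```

For example, a path `elemwise -> broadcast -> elemwise` is summarized as `kBroadcast`.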

Some example ops for each category:

kElemWise

```cpp
RELAY_REGISTER_UNARY_OP("log")
    .describe(R"code(Returns the log input array, computed element-wise.

)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));

RELAY_REGISTER_UNARY_OP("log2")
    .describe(R"code(Returns the log to base 2 of input array, computed element-wise.

)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));

RELAY_REGISTER_UNARY_OP("log10")
    .describe(R"code(Returns the log to base 10 of input array, computed element-wise.

)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));

RELAY_REGISTER_UNARY_OP("tan")
    .describe(R"code(Returns the tan of input array, computed element-wise.

)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
```
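What kElemWise means for fusion can be sketched with a plain loop. Output element `i` depends only on input element `i`. Two chained elementwise ops such as `log` followed by `tan` can therefore run as one loop over the composed scalar function, with no intermediate buffer. `ElemWise` and `FusedLogTan` are hypothetical helpers, not TOPI code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// kElemWise: out[i] = f(in[i]); shapes match and element i of the output
// depends only on element i of the input.
template <typename F>
std::vector<float> ElemWise(const std::vector<float>& in, F f) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = f(in[i]);
  return out;
}

// log followed by tan, fused: one loop over the composed function.
std::vector<float> FusedLogTan(const std::vector<float>& in) {
  return ElemWise(in, [](float x) {
    assert(x > 0.0f);  // log is only defined for positive inputs
    return std::tan(std::log(x));
  });
}
```

The fused version computes the same values as running `log` and `tan` as two separate passes, but it never materializes the intermediate `log` result.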

kBroadcast

```cpp
// Addition
RELAY_REGISTER_BINARY_OP("add")
    .describe("Elementwise add with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));

// Subtraction
RELAY_REGISTER_BINARY_OP("subtract")
    .describe("Elementwise substract with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));

// Right shift
RELAY_REGISTER_BINARY_OP("right_shift")
    .describe("Elementwise right shift with broadcasting")
    .set_support_level(4)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
```

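In fact, the kBroadcast ops above are also elementwise. Each output element is computed from exactly one element of each input; broadcasting only determines which input element that is.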
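A minimal sketch of that point uses a hypothetical row-major `Tensor2D` type. The op `out[i, j] = a[i, j] + b[j]` reads exactly one element of each input per output element. The output axis `j` maps in order onto `b`'s only axis, which is the kind of mapping the kBroadcast comment in the enum describes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major 2-D tensor, a hypothetical helper type for this sketch.
struct Tensor2D {
  std::size_t rows, cols;
  std::vector<float> data;
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// kBroadcast: out[i, j] = a[i, j] + b[j]. Per output element this is a plain
// elementwise add; broadcasting only changes which element of b is read.
Tensor2D BroadcastAdd(const Tensor2D& a, const std::vector<float>& b) {
  assert(b.size() == a.cols);  // b's only axis lines up with a's last axis
  Tensor2D out{a.rows, a.cols, std::vector<float>(a.rows * a.cols)};
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < a.cols; ++j)
      out.data[i * a.cols + j] = a.at(i, j) + b[j];  // output axis j -> b's axis 0
  return out;
}
```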

kInjective