https://github.com/lucidrains/enformer-pytorch

모델 구조

https://raw.githubusercontent.com/lucidrains/enformer-pytorch/0201fd8ba37116ec9a1caf1131d8af918f084ff9/enformer.png

프로젝트 Scaffold 만들기

https://github.com/lucidrains/enformer-pytorch/commit/0201fd8ba37116ec9a1caf1131d8af918f084ff9

첫 class 구현

https://github.com/lucidrains/enformer-pytorch/commit/0c74ba5f5c49ea83a1ff6b1acd63443ccedf2ae2

AttentionPool 구현

class AttentionPool(nn.Module):
    """First-draft attention pooling along the last (sequence) axis.

    Splits the sequence into groups of ``pool_size`` consecutive positions and
    returns a softmax-weighted sum over each group. The logits come from a
    learned ``dim x dim`` channel-mixing matrix initialised to the identity.

    Args:
        dim: number of input channels (the ``d`` axis).
        pool_size: number of consecutive positions merged into one. The input
            length must be divisible by it (this draft does no padding).
    """

    def __init__(self, dim, pool_size = 2):
        super().__init__()
        self.pool_size = pool_size
        # b = batch size, d = channels, n = L / pool_size, p = pool_size.
        # Previously p was hard-coded to 2, silently ignoring pool_size.
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        # Mix channels: (b, d, n) x (d, e) -> (b, e, n); e == d since the matrix is square.
        attn_logits = einsum('b d n, d e -> b e n', x, self.to_attn_logits)
        x = self.pool_fn(x)
        # Softmax over each group of pool_size positions.
        attn = self.pool_fn(attn_logits).softmax(dim = -1)
        return (x * attn).sum(dim = -1)

Untitled

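최종적으로는 AttentionPool이 다음과 같이 구현된다. 첫 구현과 비교하면 `pool_size`를 실제로 사용하고, 시퀀스 길이가 `pool_size`로 나누어떨어지지 않을 때 padding과 mask로 처리하며, attention logit을 `nn.Parameter` 행렬 대신 1x1 `nn.Conv2d`로 계산한다.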

class AttentionPool(nn.Module):
    """Attention-weighted pooling along the last (sequence) axis.

    Groups every ``pool_size`` consecutive positions of ``x`` and replaces each
    group with a softmax-weighted sum of its members. The weights are computed
    per channel by a learned 1x1 convolution, i.e. a ``dim -> dim`` linear map
    over channels. Sequences whose length is not a multiple of ``pool_size``
    are right-padded. The padded positions are masked out of the softmax.

    Args:
        dim: number of input channels (the ``d`` axis).
        pool_size: number of consecutive positions merged into one.
    """

    def __init__(self, dim, pool_size = 2):
        super().__init__()
        self.pool_size = pool_size
        # (b, d, n * p) -> (b, d, n, p): split the sequence into groups of p.
        self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
        # in_channels=dim, out_channels=dim, kernel_size=1, no bias: a 1x1 conv
        # is a per-position linear map over channels (dim -> dim).
        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
        # nn.Conv2d does NOT start as identity (PyTorch's default init is
        # kaiming-uniform). Initialise to identity explicitly, matching the
        # torch.eye starting point of the first draft.
        nn.init.dirac_(self.to_attn_logits.weight)

    def forward(self, x):
        """Pool ``x`` of shape (b, dim, n) down to (b, dim, ceil(n / pool_size))."""
        b, _, n = x.shape
        # Right-padding needed to make n a multiple of pool_size.
        # (Padding by n % pool_size is only correct when pool_size == 2.)
        padding = (-n) % self.pool_size
        needs_padding = padding > 0

        if needs_padding:
            x = F.pad(x, (0, padding), value = 0)
            # True marks padded positions so they receive no attention weight.
            mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
            mask = F.pad(mask, (0, padding), value = True)

        x = self.pool_fn(x)
        logits = self.to_attn_logits(x)

        if needs_padding:
            # Most negative finite value: exp() underflows to 0 without NaNs.
            # Every group keeps at least one real position because padding < pool_size.
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.pool_fn(mask), mask_value)

        attn = logits.softmax(dim = -1)

        return (x * attn).sum(dim = -1)

Residual class 구현

class Residual(nn.Module):
    """Skip connection around ``fn``: returns ``fn(x, **kwargs) + x``.

    Args:
        fn: the wrapped callable (typically a sub-module). Any keyword
            arguments given to ``forward`` are passed straight through to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x