약간 디테일한 부분은 본 페이지의 코드 내에 주석 참고 요망.
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0)]
return self.dropout(x)
모든 트랜스포머 인코더에서 마스킹 제거.
트랜스포머 인코더 구현
class EncoderMel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
#임베딩
self.enc_emb = nn.Linear(128, self.config.d_hidn)
#포지셔널
self.pos_emb = PositionalEncoding(self.config.d_hidn, 0)
self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer)])
def forward(self, inputs):
# 참고한 코드에서 math.sqrt(128)를 곱해주어 별 의미없이 따라해줌.
x = self.enc_emb(inputs) * math.sqrt(128)
# 임베딩한 값 + 포지셔널한 값
x = x + self.pos_emb(x)
....
class EncoderKeysTar(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# 트랜스포머 출력이 input/output이 동일한데, 앞쪽 인코더에서 각각의 값을 이어붙여 2배의 값이 인풋으로 들어와서 그냥 Linear 임베딩으로 반으로 줄여줌.
self.enc_emb = nn.Linear(self.config.d_hidn*2, self.config.d_hidn)
self.pos_emb = PositionalEncoding(self.config.d_hidn, 0)
self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer_multimodal)])
def forward(self, inputs):
x = self.enc_emb(inputs) * math.sqrt(128)
x = x + self.pos_emb(x)
- Transformer 전체 구조
```python
class Transformer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.encoder_mel = EncoderMel(self.config)
self.encoder_keys = EncoderKeys(self.config)
self.encoder_keys_tar = EncoderKeysTar(self.config)
def forward(self, enc_mel_inputs, enc_key_inputs):
enc_mel_outputs, enc_mel_self_attn_probs = self.encoder_mel(enc_mel_inputs)
enc_keys_outputs, enc_keys_self_attn_probs = self.encoder_keys(enc_key_inputs)
# dim = 0이라 batch_size가 합쳐졌는데, dim=1은 시퀀스니 dim=2로 합쳐줌.
enc_mel_keys = torch.cat([enc_mel_outputs, enc_keys_outputs], dim=2)
enc_keys_tar_outputs, enc_keys_tar_self_attn_probs = self.encoder_keys_tar(enc_mel_keys)
return enc_keys_tar_outputs, enc_keys_tar_self_attn_probs
class KeypointsPred(nn.Module):
self.config = config
self.transformer = Transformer(self.config)
self.projection = nn.Linear(self.config.d_hidn, self.config.n_enc_keys_vocab, bias=False)
def forward(self, enc_inputs, dec_inputs):
dec_outputs, enc_self_attn_probs = self.transformer(enc_inputs, dec_inputs)
# 주석처리 dec_outputs, _ = torch.max(dec_outputs, dim=1)
logits = self.projection(dec_outputs)
return logits, enc_self_attn_probs