약간 디테일한 부분은 본 페이지의 코드 내에 주석 참고 요망.

class PositionalEncoding(nn.Module):
	def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)

    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    self.register_buffer('pe', pe)

	def forward(self, x):
	  x = x + self.pe[:x.size(0)]
	  return self.dropout(x)
class EncoderMel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
				#임베딩
        self.enc_emb = nn.Linear(128, self.config.d_hidn)
				#포지셔널
        self.pos_emb = PositionalEncoding(self.config.d_hidn, 0)
        self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer)])
    
    def forward(self, inputs):
				# 참고한 코드에서 math.sqrt(128)를 곱해주어 별 의미없이 따라해줌.
        x = self.enc_emb(inputs) * math.sqrt(128)
				# 임베딩한 값 + 포지셔널한 값
        x = x + self.pos_emb(x)
				....
class EncoderKeysTar(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config

	# 트랜스포머 출력이 input/output이 동일한데, 앞쪽 인코더에서 각각의 값을 이어붙여 2배의 값이 인풋으로 들어와서 그냥 Linear 임베딩으로 반으로 줄여줌.
    self.enc_emb = nn.Linear(self.config.d_hidn*2, self.config.d_hidn)
    self.pos_emb = PositionalEncoding(self.config.d_hidn, 0)
    self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer_multimodal)])

def forward(self, inputs):

    x = self.enc_emb(inputs) * math.sqrt(128)
    x = x + self.pos_emb(x)

- Transformer 전체 구조

```python
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder_mel = EncoderMel(self.config)
        self.encoder_keys = EncoderKeys(self.config)
        self.encoder_keys_tar = EncoderKeysTar(self.config)
    
    def forward(self, enc_mel_inputs, enc_key_inputs):
        enc_mel_outputs, enc_mel_self_attn_probs = self.encoder_mel(enc_mel_inputs)
        enc_keys_outputs, enc_keys_self_attn_probs = self.encoder_keys(enc_key_inputs)
        # dim = 0이라 batch_size가 합쳐졌는데, dim=1은 시퀀스니 dim=2로 합쳐줌.
        enc_mel_keys = torch.cat([enc_mel_outputs, enc_keys_outputs], dim=2)
        enc_keys_tar_outputs, enc_keys_tar_self_attn_probs = self.encoder_keys_tar(enc_mel_keys)
    
        return enc_keys_tar_outputs, enc_keys_tar_self_attn_probs
class KeypointsPred(nn.Module):
	self.config = config
	self.transformer = Transformer(self.config)
	self.projection = nn.Linear(self.config.d_hidn, self.config.n_enc_keys_vocab, bias=False)
    
	def forward(self, enc_inputs, dec_inputs):
    dec_outputs, enc_self_attn_probs = self.transformer(enc_inputs, dec_inputs)
# 주석처리 dec_outputs, _ = torch.max(dec_outputs, dim=1)
    logits = self.projection(dec_outputs)
	  return logits, enc_self_attn_probs