728x90
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

 

create_mask 함수는 기계 번역 모델에서 사용되는 마스크를 생성하는 함수이다. 이 함수는 소스(입력) 시퀀스와 타겟(출력) 시퀀스에 대한 다양한 마스크를 생성하여 모델이 올바른 예측을 수행할 수 있도록 도와준다.

 

create_mask 함수의 기능은 다음과 같다.

 

1. 소스 시퀀스와 타겟 시퀀스의 길이를 확인한다.

2. generate_square_subsequent_mask 함수를 사용하여 자기 회귀적인 마스크를 생성한다. 이 마스크는 타겟 시퀀스의 각 위치 이후의 토큰들을 가린다. 이는 모델이 미래의 정보를 사용하지 않도록 제한하는 역할을 한다.

*자기 회귀(autoregressive): 변수의 과거 값에 선형적으로 의존하는 것. 시퀀스나 시계열 데이터를 처리할 때 이전 시점의 출력을 현재 시점의 입력으로 사용하는 모델 구조를 말한다.

 

3. 소스 시퀀스의 패딩 마스크와 타겟 시퀀스의 패딩 마스크를 생성한다. 패딩 마스크는 각 시퀀스에서 실제 토큰과 패딩 토큰을 구분하는 역할을 한다. 모델은 패딩된 부분을 무시하고 실제 토큰에만 집중하여 예측을 수행한다.

4. 생성된 마스크들을 반환한다. 이러한 마스크들은 모델의 학습과 추론 과정에서 사용되어 올바른 예측과 정확한 어텐션을 가능하게 한다.

 

마스크의 생성과 사용은 자연어 처리에서 모델이 효과적으로 작동할 수 있도록 돕는 중요한 기술이다. 모델이 입력과 출력 시퀀스를 올바르게 처리할 수 있도록 필요한 마스크를 생성하는 역할을 한다.

728x90

+ Recent posts