pytorch实现seq2seq时对loss进行mask的方式

  

在Pytorch实现seq2seq模型中,对于一个batch中的每个序列,其长度可能不一致。对于长度不一致的序列,需要进行pad操作,使其长度一致。但是,在计算loss的时候,pad部分的贡献必须要被剔除,否则会带来噪声。

为了解决这一问题,可以使用mask技术,即使用一个mask张量对loss进行掩码,将pad部分设置为0,只计算有效部分的loss。

下面是实现seq2seq时对loss进行mask的方式的完整攻略:

1.创建mask张量

通过给定的输入序列长度,创建一个bool掩码,其中有效部分为True,pad部分为False。

def create_mask(seq_len, pad_idx):
    mask = (torch.ones(seq_len) * pad_idx).unsqueeze(0) != torch.arange(seq_len).unsqueeze(1)
    return mask.to(device)

其中,seq_len为每个序列的长度,pad_idx为pad的token索引,此处默认使用0进行pad。

2.计算loss时掩码

在计算loss时,将mask张量与计算得到的loss张量相乘即可实现mask。

mask = create_mask(target_seq_len, pad_idx)  # 创建mask张量
loss = criterion(output, target_seqs)  # 计算loss
loss = (loss * mask.float()).sum() / mask.sum()  # mask掩码

3.示例说明

下面给出两个示例,更好地理解如何使用mask对seq2seq模型的loss进行掩码。

假设我们有如下两个序列:

  • 输入序列:['I', 'love', 'you']
  • 目标序列:['Ich', 'liebe', 'dich']

其中,我们使用3个token来表示输入和输出序列,对应的pad_idx为0。那么,我们需要将输入和输出序列转换为相同的长度,这里设定为5。那么,经过pad之后,就可以得到如下矩阵:

# input_seq:['I', 'love', 'you']
input_seqs = [[1, 3, 2, 0, 0]]  # 0表示pad

# target_seq:['Ich', 'liebe', 'dich']
target_seqs = [[4, 5, 6, 2, 0]]  # 0表示pad

其中,1/3/2对应的是输入序列中的'I'/'love'/'you',4/5/6对应的是目标序列中的'Ich'/'liebe'/'dich'。

接下来,我们需要创建掩码张量,对于pad部分置为False,其他部分置为True。

pad_idx = 0
input_seq_len = 3  # 输入序列长度
target_seq_len = 3  # 目标序列长度
input_mask = create_mask(input_seq_len, pad_idx)
# input_mask: [[ True,  True,  True, False, False]]
target_mask = create_mask(target_seq_len, pad_idx) 
# target_mask: [[ True,  True,  True, False, False]]

最后,计算loss时,使用mask张量掩码:

output = model(input_seqs, input_mask, target_seqs[:, :-1], target_mask[:, :-1])
loss = criterion(output, target_seqs[:, 1:]) 
# 对验证集batch中每个序列的loss进行求和并求平均
loss = (loss * target_mask[:, 1:].float()).sum() / target_mask[:, 1:].sum()

这里,我们首先使用model计算模型输出,然后计算loss,最后使用target_mask掩码。需要注意的是,这里的target_seqs需要去掉最后的一个token,也就是'pad',以保证input_seqs和target_seqs的长度相同。

相关文章