MultiheadAttention
class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
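- Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.
- $\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$, where $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.
- Parameters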
- embed_dim – total dimension of the model. 
- num_heads – number of parallel attention heads. embed_dim is split across the heads, so it must be divisible by num_heads. 
- dropout – dropout probability applied to attn_output_weights. Default: 0.0. 
- bias – add bias as module parameter. Default: True. 
- add_bias_kv – add bias to the key and value sequences at dim=0. Default: False. 
- add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1. Default: False. 
- kdim – total number of features in key. Default: None. 
- vdim – total number of features in value. Default: None. 
 
- Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
- Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
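The example above leaves embed_dim, num_heads and the input tensors undefined. A minimal self-contained sketch follows, with arbitrary sizes chosen only for illustration (all concrete numbers here are assumptions, not values from the documentation). It also sets kdim and vdim explicitly to show that key and value may carry a different number of features than query.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): target length, source length, batch size.
L, S, N = 4, 6, 2
embed_dim, num_heads = 16, 4   # embed_dim must be divisible by num_heads
kdim, vdim = 12, 10            # feature sizes of key and value

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

query = torch.randn(L, N, embed_dim)  # (L, N, E)
key = torch.randn(S, N, kdim)         # (S, N, kdim)
value = torch.randn(S, N, vdim)       # (S, N, vdim)

attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)  # torch.Size([4, 2, 16])
```

When kdim and vdim are left as None, key and value are expected to have embed_dim features instead.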
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
- Parameters
- query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details. 
- key_padding_mask – if provided, the specified padding elements in the key are ignored by the attention. For a boolean mask, positions where the value is True are ignored. For a byte mask, positions where the value is non-zero are ignored. 
- need_weights – if True, also return attn_output_weights. Default: True. 
- attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows a different mask for each entry in the batch. 
 
 - Shapes for inputs:
- query: (L, N, E), where L is the target sequence length, N is the batch size, E is the embedding dimension. 
- key: (S, N, E), where S is the source sequence length, N is the batch size, E is the embedding dimension. 
- value: (S, N, E), where S is the source sequence length, N is the batch size, E is the embedding dimension. 
- key_padding_mask: (N, S), where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions are ignored while the zero positions are left unchanged. If a BoolTensor is provided, positions with the value True are ignored while positions with the value False are left unchanged.
- attn_mask: for a 2D mask, (L, S), where L is the target sequence length, S is the source sequence length. For a 3D mask, (N*num_heads, L, S), where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to be attended to while the zero positions are left unchanged. If a BoolTensor is provided, positions with True are not allowed to be attended to while positions with False are left unchanged. If a FloatTensor is provided, it is added to the attention weight.
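The two mask arguments can be sketched as follows. This is an illustrative example under assumed sizes, not a prescribed usage pattern: it builds a boolean key_padding_mask and a boolean 2D attn_mask, checks that masked positions receive zero attention weight, and then shows that a float mask holding -inf at the same positions gives the same weights as the boolean mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): target length, source length, batch, features, heads.
L, S, N, E, num_heads = 3, 5, 2, 8, 2
multihead_attn = nn.MultiheadAttention(E, num_heads)

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

# key_padding_mask has shape (N, S); True marks key positions to ignore.
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[1, 3:] = True   # last two keys of the second batch entry are padding

# 2D attn_mask has shape (L, S); True marks positions that may not be attended to.
bool_mask = torch.zeros(L, S, dtype=torch.bool)
bool_mask[:, 0] = True           # no query position may attend to key 0

_, weights = multihead_attn(
    query, key, value, key_padding_mask=key_padding_mask, attn_mask=bool_mask
)

# A float mask is added to the attention scores, so -inf blocks a position.
float_mask = torch.zeros(L, S)
float_mask[:, 0] = float("-inf")

_, weights_bool = multihead_attn(query, key, value, attn_mask=bool_mask)
_, weights_float = multihead_attn(query, key, value, attn_mask=float_mask)
```

Avoid masking every key for a given query position: the softmax over an all-masked row has no valid entries and typically produces NaN.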
 
- Shapes for outputs:
- attn_output: (L, N, E), where L is the target sequence length, N is the batch size, E is the embedding dimension. 
- attn_output_weights: (N, L, S), where N is the batch size, L is the target sequence length, S is the source sequence length.
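The output shapes can be checked with a short sketch (sizes are arbitrary assumptions). With the default dropout of 0.0, each row of attn_output_weights is a softmax distribution over the S source positions and therefore sums to 1. Passing need_weights=False returns None in place of the weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions): target length, source length, batch, features, heads.
L, S, N, E, num_heads = 3, 5, 2, 8, 2
multihead_attn = nn.MultiheadAttention(E, num_heads)  # dropout defaults to 0.0

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)          # torch.Size([3, 2, 8])
print(attn_output_weights.shape)  # torch.Size([2, 3, 5])

# Each (batch, target position) row of the weights sums to 1 over the source positions.
row_sums = attn_output_weights.sum(dim=-1)

# Skipping the weights returns None as the second element.
output_only, no_weights = multihead_attn(query, key, value, need_weights=False)
```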