| import math |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class AttentionScore(nn.Module):
    r"""
    Compute masked, scaled dot-product attention scores.

    There are no learnable parameters in this module.
    It computes the alignment score between a query and a set of keys,
    applies the mask, and returns only the attention score (logits);
    no softmax is applied here.

    The default operation is

    .. math::
        \pmb{u} = \mathrm{Attention}(q,\pmb{k}, \mathrm{mask})

    where for each key :math:`k_j`, we have

    .. math::
        u_j =
        \begin{cases}
            &\frac{q^Tk_j}{\sqrt{\smash{d_q}}} & \text{ if } j \notin \mathrm{mask}\\
            &-\infty & \text{ otherwise. }
        \end{cases}

    If ``use_tanh`` is ``True``, apply clipping on the logits :math:`u_j` before masking:

    .. math::
        u_j =
        \begin{cases}
            &C\mathrm{tanh}\left(\frac{q^Tk_j}{\sqrt{\smash{d_q}}}\right) & \text{ if } j \notin \mathrm{mask}\\
            &-\infty & \text{ otherwise. }
        \end{cases}

    Args:
        use_tanh: if True, use clipping on the logits
        C: the range of the clipping [-C,C]
    Inputs: query, key, mask
        * **query** : [..., n_query, h_dim]
        * **key**: [..., graph_size, h_dim]
        * **mask**: optional boolean tensor that can be expanded (``expand_as``)
          to the shape of ``logits``, e.g. [..., 1, graph_size] or a scalar.
          ``logits[..., j] == -inf`` wherever the expanded mask is ``True``.
          ``None`` (the default) disables masking.
    Outputs: logits
        * **logits**: [..., n_query, graph_size] The attention score for each key.
    """

    def __init__(self, use_tanh=False, C=10):
        super().__init__()
        self.use_tanh = use_tanh
        self.C = C

    def forward(self, query, key, mask=None):
        # Scaled dot product; the scale is sqrt of the query's last dimension.
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if self.use_tanh:
            # Clip the logits to [-C, C] before the mask is applied.
            logits = torch.tanh(logits) * self.C

        if mask is not None:
            # `logits` is a fresh tensor produced above, so writing into it
            # in place does not touch any caller-owned data.
            logits[mask.expand_as(logits)] = float("-inf")
        return logits
|
|
|
|
class MultiHeadAttention(nn.Module):
    r"""
    Compute the multi-head attention.

    .. math::
        q^\prime = \mathrm{MultiHeadAttention}(q,\pmb{k},\pmb{v},\mathrm{mask})

    The following is computed:

    .. math::
        \begin{aligned}
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}

    Args:
        embedding_dim: dimension of the query, keys, values;
            must be divisible by ``n_heads``
        n_heads: number of heads
    Inputs: query, key, value, mask
        * **query** : [batch, n_querys, embedding_dim]
        * **key**: [batch, n_keys, embedding_dim]
        * **value**: [batch, n_keys, embedding_dim]
        * **mask**: optional boolean tensor, forwarded unchanged to the
          attention score module. It is applied to per-head logits of shape
          [n_heads, batch, n_querys, n_keys], so e.g. [batch, 1, n_keys] works:
          key ``j`` is ignored for ``batch`` if ``mask[batch, 0, j]==True``.
          If omitted, no key is masked.
    Outputs: out
        * **out**: [batch, n_querys, embedding_dim] The output of the multi-head attention

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``embedding_dim``.

    .. note::
        If every key is masked for some query, the softmax over a row of
        ``-inf`` yields ``NaN`` for that query.
    """

    def __init__(self, embedding_dim, n_heads=8):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error
        # inside `_make_heads` on the first forward pass.
        if n_heads <= 0 or embedding_dim % n_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"a positive n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.attentionScore = AttentionScore()
        self.project_out = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, query, key, value, mask=None):
        query_heads = self._make_heads(query)
        key_heads = self._make_heads(key)
        value_heads = self._make_heads(value)

        # Per-head logits: [n_heads, batch, n_querys, n_keys].
        if mask is None:
            # Fall back on the score module's own "no masking" default.
            compatibility = self.attentionScore(query_heads, key_heads)
        else:
            compatibility = self.attentionScore(query_heads, key_heads, mask)

        # Attention-weighted sum of values: [n_heads, batch, n_querys, head_dim].
        out_heads = torch.matmul(torch.softmax(compatibility, dim=-1), value_heads)

        # Concatenate the heads and apply the output projection W^O.
        return self.project_out(self._unmake_heads(out_heads))

    def _make_heads(self, v):
        """Split [batch, n, h_dim] into heads: [n_heads, batch, n, h_dim // n_heads]."""
        batch_size, n_items, h_dim = v.shape
        head_dim = h_dim // self.n_heads
        return v.reshape(batch_size, n_items, self.n_heads, head_dim).movedim(-2, 0)

    def _unmake_heads(self, v):
        """Inverse of ``_make_heads``: [n_heads, batch, n, head_dim] -> [batch, n, h_dim]."""
        return v.movedim(0, -2).flatten(-2)
|
|
|
|
class MultiHeadAttentionProj(nn.Module):
    r"""
    Compute the multi-head attention with projection.
    Different from :class:`.MultiHeadAttention` which accepts precomputed query, keys, and values,
    this module computes linear projections from the inputs to query, keys, and values.

    .. math::
        q^\prime = \mathrm{MultiHeadAttentionProj}(q_0,\pmb{h},\mathrm{mask})

    The following is computed:

    .. math::
        \begin{aligned}
        q, \pmb{k}, \pmb{v} &= W^Qq_0, W^K\pmb{h}, W^V\pmb{h}\\
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}

    If :math:`\pmb{h}` is not given, this module computes the self attention of :math:`q_0`.

    .. warning::
        The results of the in-projection of query, key, value are
        slightly different (order of ``1e-6``) from the original implementation.
        This is due to numerical accuracy.
        The two implementations differ in the way the matrices are multiplied.
        Thus, different internal implementation libraries of pytorch are called
        and the results are slightly different.
        See the pytorch docs on `numerical accuracy <https://pytorch.org/docs/stable/notes/numerical_accuracy.html>`_ for detail.

    Args:
        embedding_dim: dimension of the query, keys, values
        n_heads: number of heads
    Inputs: q, h, mask
        * **q** : [batch, n_querys, embedding_dim]
        * **h**: [batch, n_keys, embedding_dim]; defaults to ``q`` (self attention)
        * **mask**: optional boolean tensor, passed unchanged to
          :class:`.MultiHeadAttention`, where it is applied to per-head logits
          of shape [n_heads, batch, n_querys, n_keys]; e.g. [batch, 1, n_keys],
          with key ``j`` ignored for ``batch`` if ``mask[batch, 0, j]==True``.
          ``None`` (the default) masks nothing.
    Outputs: out
        * **out**: [batch, n_querys, embedding_dim] The output of the multi-head attention
    """

    def __init__(self, embedding_dim, n_heads=8):
        super().__init__()

        # Bias-free input projections W^Q, W^K, W^V.
        self.queryEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.keyEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.valueEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.MHA = MultiHeadAttention(embedding_dim, n_heads)

    def forward(self, q, h=None, mask=None):
        if h is None:
            # Self attention: keys and values are derived from the query input.
            h = q

        if mask is None:
            # Scalar all-False mask created per call on the input's device,
            # rather than a CPU tensor built once at definition time.
            mask = torch.zeros([], dtype=torch.bool, device=q.device)

        query = self.queryEncoder(q)
        key = self.keyEncoder(h)
        value = self.valueEncoder(h)

        return self.MHA(query, key, value, mask)
|
|