bio-attention.attention Reference
This module contains the building blocks used to construct transformers.
- class bio_attention.attention.Attention(dropout: float = 0.0, enable_math: bool = True, enable_flash: bool = True, enable_mem_efficient: bool = True, **kwargs)
Scaled dot-product attention operator. For more information on the kernels, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- Parameters:
dropout (float, optional) – dropout rate in self attention matrix, by default 0.0
enable_math (bool, optional) – allow PyTorch C++ implementation, by default True
enable_flash (bool, optional) – allow FlashAttention implementation, by default True
enable_mem_efficient (bool, optional) – allow Memory-Efficient implementation, by default True
- forward(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) -> Tensor
Forward pass
- Parameters:
q (torch.Tensor) – (B, *, L1, NH, H)
k (torch.Tensor) – (B, *, L2, NH, H)
v (torch.Tensor) – (B, *, L2, NH, H)
mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None
causal (bool, optional) – perform causal attention. Unlike the default PyTorch implementation, a mask and causal can be used together, by default False
- Returns:
(B, *, L1, NH, H)
- Return type:
torch.Tensor
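To illustrate how a mask and causal can be combined, the following is a minimal pure-Python sketch for a single head and a single sequence. It is not the library's implementation: the function name `attention_1head` is hypothetical, a boolean mask with False meaning "blocked" is assumed, and batching, heads, and dropout are left out.

```python
import math


def attention_1head(q, k, v, mask=None, causal=False):
    """Scaled dot-product attention for one head (illustrative only).

    q: L1 x H, k and v: L2 x H, given as lists of lists of floats.
    mask: optional L1 x L2 lists of bools; False blocks a query-key pair.
    causal: if True, query i may only attend to keys j <= i.
    Assumes every query keeps at least one allowed key.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            # A pair is allowed only if both the mask and the causal rule permit it.
            allowed = (mask is None or mask[i][j]) and (not causal or j <= i)
            dot = sum(a * b for a, b in zip(qi, kj))
            scores.append(dot * scale if allowed else float("-inf"))
        # Softmax over keys; blocked pairs receive weight exp(-inf) = 0.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * vj[d] for w, vj in zip(weights, v))
                    for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
key_mask = [[True, False, True] for _ in range(3)]  # key 1 is blocked for all queries
result = attention_1head(q, q, v, mask=key_mask, causal=True)
```

With both restrictions active, the second query can only see the first key, so its output equals the first value vector.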
- class bio_attention.attention.RandomAttention(n_random_keys: int = 64, dropout: float = 0.0, materialize_full: bool = False, **kwargs)
Scaled dot-product attention operator in which every query attends to only a random subset of keys. Two versions are supported. The first materializes the full attention matrix and hence scales quadratically with sequence length; in essence, it is default attention with random masks. The second is memory-efficient and scales linearly with sequence length, but has a lower base efficiency because of the extra steps taken.
- Parameters:
n_random_keys (int, optional) – number of keys every query should attend to, by default 64
dropout (float, optional) – dropout rate in self attention matrix, by default 0.0
materialize_full (bool, optional) – whether to materialize full attention matrix, by default False
- forward_indexed(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) -> Tensor
Memory-efficient forward pass. NOTE: for the moment, causal attention with input masks is not implemented.
- Returns:
(B, *, L1, NH, H)
- Return type:
torch.Tensor
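The difference between the two versions can be sketched in pure Python. This is an illustration under stated assumptions, not the library's code: the helper names `sample_key_indices` and `indices_to_mask` are hypothetical, and sampling without replacement is assumed.

```python
import random


def sample_key_indices(n_queries, n_keys, n_random_keys=64, seed=0):
    """For each query, draw the indices of the keys it attends to.

    Storage is n_queries x n_random_keys, i.e. linear in sequence length.
    """
    rng = random.Random(seed)
    n = min(n_random_keys, n_keys)
    return [sorted(rng.sample(range(n_keys), n)) for _ in range(n_queries)]


def indices_to_mask(indices, n_keys):
    """Materialized variant: an n_queries x n_keys boolean mask (quadratic)."""
    mask = []
    for row in indices:
        chosen = set(row)
        mask.append([j in chosen for j in range(n_keys)])
    return mask


indices = sample_key_indices(n_queries=4, n_keys=10, n_random_keys=3)
mask = indices_to_mask(indices, n_keys=10)
```

The indexed form only ever touches the selected keys, whereas the mask form computes all query-key scores and discards most of them.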
- class bio_attention.attention.WindowAttention(window: int = 15, dropout: float = 0.0, materialize_full: bool = False, **kwargs)
Scaled dot-product attention operator in which every query attends to only a local window of keys. Two versions are supported. The first materializes the full attention matrix and hence scales quadratically with sequence length; in essence, it is default attention with a mask. The second is memory-efficient and scales linearly with sequence length, but has a lower base efficiency because of the extra steps taken.
- Parameters:
window (int, optional) – window size, analogous to the kernel size in convolutions; should be odd, by default 15
dropout (float, optional) – dropout rate in self attention matrix, by default 0.0
materialize_full (bool, optional) – whether to materialize full attention matrix, by default False
- forward_sliced(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) -> Tensor
Memory-efficient forward pass. NOTE: q and k need to have the same sequence length (L1 = L2). NOTE: incompatible with user-defined masks for the time being.
- Returns:
(B, *, L1, NH, H)
- Return type:
torch.Tensor
- forward_naive(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) -> Tensor
Naive (masking) forward pass. NOTE: q and k need to have the same sequence length (L1 = L2). NOTE: incompatible with user-defined masks for the time being.
- Returns:
(B, *, L1, NH, H)
- Return type:
torch.Tensor
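The mask used by the naive version can be pictured with a short pure-Python sketch. It assumes the window is centred on the query, as a convolution kernel with "same" padding would be; the function name `window_mask` is hypothetical and the library's exact boundary handling may differ.

```python
def window_mask(seq_len, window=15):
    """L x L boolean mask: query i attends to keys within window // 2 positions."""
    if window % 2 == 0:
        raise ValueError("window should be odd")
    half = window // 2
    return [[abs(i - j) <= half for j in range(seq_len)]
            for i in range(seq_len)]


mask = window_mask(seq_len=6, window=3)
```

Each row contains at most `window` True entries, which is why a sliced implementation needs on the order of L x window scores instead of L x L.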
- class bio_attention.attention.Transformer(depth: int, dim: int, nh: int, attentiontype: Literal['vanilla', 'random', 'window'] = 'vanilla', attention_args: dict = {}, plugintype: Literal['none', 'sinusoidal', 'learned', 'learnedcont', 'rotary', 'ALiBi', 'DPB', 'XL'] | List[Literal['none', 'sinusoidal', 'learned', 'learnedcont', 'rotary', 'ALiBi', 'DPB', 'XL']] = 'none', plugin_args: dict | List[dict] = {}, only_apply_plugin_at_first: bool | List[bool] = False, dropout: float = 0.2, glu_ff: bool = True, activation: Literal['relu', 'gelu', 'swish'] = 'swish')
Transformer network chaining multiple transformer layers
- Parameters:
depth (int) – number of transformer blocks to use
dim (int) – input and output hidden dimension of x
nh (int) – number of heads, dim should be divisible by this number
attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”
attention_args (dict, optional) – args passed to the attention operator init, by default {}
plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”
plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}
only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False
dropout (float, optional) – dropout in feedforward layers, by default 0.2. Note that attention matrix dropout is controlled via attention_args.
glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True
activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”
- forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, causal: bool = False, **mod_kwargs) -> Tensor
Forward pass
- Parameters:
x (torch.Tensor) – (B, *, L, H)
pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, added at the beginning of its sequence. These tokens have no position, and positional biases make sure they do not take part in positional encoding.
mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
(1) (B, *, L) or (B, *, L-x). Expects a boolean mask. It is expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. CLS tokens are extrapolated so that they also do not attend to positions marked False.
(2) (B, *, L, L) or (B, *, L-x, L-x). Can be either a floating-point or a boolean mask. CLS tokens are extrapolated to attend to all tokens. The same mask is applied over all heads.
(3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x). Same as case (2), but different biases/masks can be applied per head.
causal (bool, optional) – Perform causal attention, by default False
- Returns:
(B, *, L, H)
- Return type:
torch.Tensor
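Mask case (1) can be sketched for a single sequence and a single head as follows. This is an illustration only: the function name `expand_key_mask` and its `n_prefix` argument are hypothetical, and it assumes that the extra leading tokens (such as CLS) are themselves attendable by every token. Broadcasting over B and NH is omitted.

```python
def expand_key_mask(key_mask, n_prefix=0):
    """Expand a per-position boolean mask of length L - x to an L x L mask.

    key_mask: list of bools; False marks positions that must not be attended to.
    n_prefix: number of extra leading tokens (x) that the mask does not cover.
    """
    # Assumption: prefix tokens are always valid keys.
    full = [True] * n_prefix + list(key_mask)
    # Every query row, including the prefix tokens, shares the same key mask.
    return [list(full) for _ in full]


expanded = expand_key_mask([True, True, False], n_prefix=1)
```

Here the last position is padding, so no query, including the prefixed token, attends to it.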
- class bio_attention.attention.TransformerEncoder(*args, **kwargs)
Transformer encoder network chaining multiple transformer layers. Identical to Transformer, except that causal = False is set automatically in the forward pass.
- Parameters:
depth (int) – number of transformer blocks to use
dim (int) – input and output hidden dimension of x
nh (int) – number of heads, dim should be divisible by this number
attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”
attention_args (dict, optional) – args passed to the attention operator init, by default {}
plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”
plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}
only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False
dropout (float, optional) – dropout in feedforward layers, by default 0.2. Note that attention matrix dropout is controlled via attention_args.
glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True
activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”
- forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, **mod_kwargs)
Forward pass
- Parameters:
x (torch.Tensor) – (B, *, L, H)
pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, added at the beginning of its sequence. These tokens have no position, and positional biases make sure they do not take part in positional encoding.
mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
(1) (B, *, L) or (B, *, L-x). Expects a boolean mask. It is expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. CLS tokens are extrapolated so that they also do not attend to positions marked False.
(2) (B, *, L, L) or (B, *, L-x, L-x). Can be either a floating-point or a boolean mask. CLS tokens are extrapolated to attend to all tokens. The same mask is applied over all heads.
(3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x). Same as case (2), but different biases/masks can be applied per head.
- Returns:
(B, *, L, H)
- Return type:
torch.Tensor
- class bio_attention.attention.TransformerDecoder(*args, **kwargs)
Transformer decoder network chaining multiple transformer layers. Identical to Transformer, except that causal = True is set automatically in the forward pass.
- Parameters:
depth (int) – number of transformer blocks to use
dim (int) – input and output hidden dimension of x
nh (int) – number of heads, dim should be divisible by this number
attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”
attention_args (dict, optional) – args passed to the attention operator init, by default {}
plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”
plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}
only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False
dropout (float, optional) – dropout in feedforward layers, by default 0.2. Note that attention matrix dropout is controlled via attention_args.
glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True
activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”
- forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, **mod_kwargs)
Forward pass
- Parameters:
x (torch.Tensor) – (B, *, L, H)
pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, added at the beginning of its sequence. These tokens have no position, and positional biases make sure they do not take part in positional encoding.
mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
(1) (B, *, L) or (B, *, L-x). Expects a boolean mask. It is expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. CLS tokens are extrapolated so that they also do not attend to positions marked False.
(2) (B, *, L, L) or (B, *, L-x, L-x). Can be either a floating-point or a boolean mask. CLS tokens are extrapolated to attend to all tokens. The same mask is applied over all heads.
(3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x). Same as case (2), but different biases/masks can be applied per head.
- Returns:
(B, *, L, H)
- Return type:
torch.Tensor
- class bio_attention.attention.Aggregator(method: Literal['mean', 'max', 'cls'] = 'max')
Aggregator module. Can be used to get a single vector from a sequence-valued input.
- Parameters:
method (Literal["mean", "max", "cls"], optional) – aggregation method, by default “max”
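What the three methods compute can be sketched in pure Python for a single L x H sequence. This is an assumption-laden illustration, not the module itself: the function name `aggregate` is hypothetical, "mean" and "max" are assumed to reduce over the sequence dimension, and "cls" is assumed to return the first token, in line with CLS tokens being added at the beginning of the sequence.

```python
def aggregate(seq, method="max"):
    """Reduce an L x H sequence (list of lists) to a single length-H vector."""
    if method == "mean":
        return [sum(col) / len(seq) for col in zip(*seq)]
    if method == "max":
        return [max(col) for col in zip(*seq)]
    if method == "cls":
        # Assumed: the CLS token sits at the first sequence position.
        return list(seq[0])
    raise ValueError(f"unknown method: {method}")


seq = [[1.0, 4.0], [3.0, 2.0]]
pooled = aggregate(seq, method="max")
```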