bio-attention.attention Reference

Modules that serve as the building blocks of transformer models.

class bio_attention.attention.Attention(dropout: float = 0.0, enable_math: bool = True, enable_flash: bool = True, enable_mem_efficient: bool = True, **kwargs)

Scaled dot-product attention operator. For more information on the available kernels, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

Parameters:
  • dropout (float, optional) – dropout rate applied to the attention matrix, by default 0.0

  • enable_math (bool, optional) – allow the PyTorch C++ (math) implementation, by default True

  • enable_flash (bool, optional) – allow the FlashAttention implementation, by default True

  • enable_mem_efficient (bool, optional) – allow the memory-efficient implementation, by default True

forward(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) → Tensor

Forward pass

Parameters:
  • q (torch.Tensor) – (B, *, L1, NH, H)

  • k (torch.Tensor) – (B, *, L2, NH, H)

  • v (torch.Tensor) – (B, *, L2, NH, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None

  • causal (bool, optional) – Perform causal attention. Unlike the default PyTorch implementation, mask and causal can be used jointly, by default False

Returns:

(B, *, L1, NH, H)

Return type:

torch.Tensor
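The forward pass above can be sketched on top of torch.nn.functional.scaled_dot_product_attention. That function expects the head axis before the sequence axis and raises an error when both attn_mask and is_causal are set, so this sketch transposes the inputs and folds the causal constraint into the mask instead. The helper name sdpa_bio_layout is hypothetical, and this is not bio_attention's actual implementation; it assumes boolean masks use True for "may attend" and float masks are additive biases.

```python
import torch
import torch.nn.functional as F


def sdpa_bio_layout(q, k, v, mask=None, causal=False):
    """Hypothetical sketch: attention on (B, *, L, NH, H) tensors.

    mask: optional boolean (True = may attend) or float (additive bias)
          tensor shaped (B, *, NH, L1, L2) or broadcastable to it.
    """
    # PyTorch expects (..., NH, L, H): swap the sequence and head axes
    q, k, v = (t.transpose(-3, -2) for t in (q, k, v))
    if causal:
        L1, L2 = q.shape[-2], k.shape[-2]
        causal_mask = torch.ones(L1, L2, dtype=torch.bool, device=q.device).tril()
        if mask is None:
            mask = causal_mask
        elif mask.dtype == torch.bool:
            mask = mask & causal_mask  # both constraints must allow attention
        else:
            # Float bias: forbid future keys by adding -inf
            mask = mask.masked_fill(~causal_mask, float("-inf"))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(-3, -2)  # back to (B, *, L1, NH, H)
```

Combining the two constraints with a logical AND is what lets a user mask and causal attention be used jointly.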

class bio_attention.attention.RandomAttention(n_random_keys: int = 64, dropout: float = 0.0, materialize_full: bool = False, **kwargs)

Scaled dot-product attention operator in which each query attends to only a random subset of keys. Two versions are supported. The first materializes the full attention matrix and hence scales quadratically with sequence length; in essence, it is default attention with random masks. The second is memory-efficient, scaling linearly with sequence length, but has a lower base efficiency because of the extra steps it takes.

Parameters:
  • n_random_keys (int, optional) – number of keys every query should attend to, by default 64

  • dropout (float, optional) – dropout rate applied to the attention matrix, by default 0.0

  • materialize_full (bool, optional) – whether to materialize the full attention matrix (quadratic memory) instead of using the memory-efficient version, by default False

forward_indexed(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) → Tensor

Memory-efficient forward pass. NOTE: causal attention combined with input masks is not implemented yet.

Parameters:
  • q (torch.Tensor) – (B, *, L1, NH, H)

  • k (torch.Tensor) – (B, *, L2, NH, H)

  • v (torch.Tensor) – (B, *, L2, NH, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None

  • causal (bool, optional) – Perform causal attention, by default False

Returns:

(B, *, L1, NH, H)

Return type:

torch.Tensor
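The idea behind an indexed, memory-efficient version can be sketched by gathering only the selected keys and values for each query, so memory grows with L1 × n_random_keys rather than L1 × L2. Both helper names are hypothetical, the key indices are drawn with replacement for simplicity, and masks and causal attention are omitted; this is an illustration of the technique, not bio_attention's forward_indexed.

```python
import torch


def sample_key_indices(B, L1, NH, L2, n_random_keys=64, generator=None):
    # Hypothetical helper: n_random_keys key indices per (batch, query, head),
    # drawn with replacement so no (L1, L2) tensor is ever allocated
    return torch.randint(L2, (B, L1, NH, n_random_keys), generator=generator)


def random_attention_indexed(q, k, v, idx):
    """Attend each query only to the keys listed in idx.

    q: (B, L1, NH, H); k, v: (B, L2, NH, H); idx: (B, L1, NH, n). Returns (B, L1, NH, H).
    """
    B, L1, NH, H = q.shape
    b = torch.arange(B, device=q.device)[:, None, None, None]
    h = torch.arange(NH, device=q.device)[None, None, :, None]
    # Advanced indexing gathers the selected keys/values: (B, L1, NH, n, H)
    k_sel, v_sel = k[b, idx, h], v[b, idx, h]
    scores = torch.einsum("bihd,bihnd->bihn", q, k_sel) / H ** 0.5
    attn = scores.softmax(dim=-1)
    return torch.einsum("bihn,bihnd->bihd", attn, v_sel)
```

When idx lists every key exactly once, this reduces to ordinary full attention, which is a convenient sanity check.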

forward_naive(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) → Tensor

Naive (random masking) forward pass. NOTE: this method is currently incompatible with input masks.

Parameters:
  • q (torch.Tensor) – (B, *, L1, NH, H)

  • k (torch.Tensor) – (B, *, L2, NH, H)

  • v (torch.Tensor) – (B, *, L2, NH, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None

  • causal (bool, optional) – Perform causal attention, by default False

Returns:

(B, *, L1, NH, H)

Return type:

torch.Tensor
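The naive approach amounts to building a random boolean mask over the full L1 × L2 matrix and passing it to standard attention. The sketch below picks min(n_random_keys, L2) distinct keys per query by taking the top entries of uniform noise; the helper names are hypothetical and the (B, L, NH, H) layout without extra * dimensions is an assumption.

```python
import torch
import torch.nn.functional as F


def random_key_mask(B, NH, L1, L2, n_random_keys=64, generator=None):
    # Hypothetical helper: (B, NH, L1, L2) boolean mask, True = may attend,
    # with exactly min(n_random_keys, L2) True entries per query row
    n = min(n_random_keys, L2)
    # Top-n positions of i.i.d. uniform noise form a uniformly random key subset
    idx = torch.rand(B, NH, L1, L2, generator=generator).topk(n, dim=-1).indices
    mask = torch.zeros(B, NH, L1, L2, dtype=torch.bool)
    return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))


def random_attention_naive(q, k, v, n_random_keys=64, generator=None):
    # q: (B, L1, NH, H); k, v: (B, L2, NH, H); memory is O(L1 * L2)
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    B, NH, L1, _ = qt.shape
    mask = random_key_mask(B, NH, L1, kt.shape[-2], n_random_keys, generator)
    out = F.scaled_dot_product_attention(qt, kt, vt, attn_mask=mask.to(q.device))
    return out.transpose(1, 2)
```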

class bio_attention.attention.WindowAttention(window: int = 15, dropout: float = 0.0, materialize_full: bool = False, **kwargs)

Scaled dot-product attention operator in which each query attends to only a local window of keys. Two versions are supported. The first materializes the full attention matrix and hence scales quadratically with sequence length; in essence, it is default attention with a mask. The second is memory-efficient, scaling linearly with sequence length, but has a lower base efficiency because of the extra steps it takes.

Parameters:
  • window (int, optional) – window size, analogous to the kernel size in convolutions; should be odd, by default 15

  • dropout (float, optional) – dropout rate applied to the attention matrix, by default 0.0

  • materialize_full (bool, optional) – whether to materialize the full attention matrix (quadratic memory) instead of using the memory-efficient version, by default False

forward_sliced(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) → Tensor

Memory-efficient forward pass. NOTE: q and k must have the same sequence length (L1 = L2). NOTE: this method is currently incompatible with user-defined masks.

Parameters:
  • q (torch.Tensor) – (B, *, L1, NH, H)

  • k (torch.Tensor) – (B, *, L2, NH, H)

  • v (torch.Tensor) – (B, *, L2, NH, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None

  • causal (bool, optional) – Perform causal attention, by default False

Returns:

(B, *, L1, NH, H)

Return type:

torch.Tensor
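A sliced, memory-efficient window attention can be sketched by padding the keys and values along the sequence axis and using Tensor.unfold to build a (window)-sized slice of keys for every query, so the score tensor is L × window instead of L × L. The function name is hypothetical and this is not bio_attention's forward_sliced; padded slots are masked out, and causal=True additionally drops keys after the query.

```python
import torch
import torch.nn.functional as F


def window_attention_sliced(q, k, v, window=15, causal=False):
    """Sketch: query i attends to keys i - window // 2 ... i + window // 2.

    q, k, v: (B, L, NH, H) with equal sequence lengths; memory is O(L * window).
    """
    assert window % 2 == 1, "window should be odd"
    B, L, NH, H = q.shape
    half = window // 2
    # Zero-pad the sequence axis (dim 1) by `half` on both sides
    k_pad = F.pad(k, (0, 0, 0, 0, half, half))
    v_pad = F.pad(v, (0, 0, 0, 0, half, half))
    # Sliding windows over the sequence axis: (B, L, NH, H, window)
    k_win = k_pad.unfold(1, window, 1)
    v_win = v_pad.unfold(1, window, 1)
    scores = torch.einsum("bihd,bihdw->bihw", q, k_win) / H ** 0.5
    # Original key position of each (query, window slot); padded slots are invalid
    slot = torch.arange(window, device=q.device)
    pos = torch.arange(L, device=q.device)[:, None] + slot - half
    valid = (pos >= 0) & (pos < L)
    if causal:
        valid &= slot <= half  # keep only keys at or before the query
    scores = scores.masked_fill(~valid[None, :, None, :], float("-inf"))
    return torch.einsum("bihw,bihdw->bihd", scores.softmax(dim=-1), v_win)
```

The center slot is always valid, so no query row ends up fully masked.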

forward_naive(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None, causal: bool = False) → Tensor

Naive (masking) forward pass. NOTE: q and k must have the same sequence length (L1 = L2). NOTE: this method is currently incompatible with user-defined masks.

Parameters:
  • q (torch.Tensor) – (B, *, L1, NH, H)

  • k (torch.Tensor) – (B, *, L2, NH, H)

  • v (torch.Tensor) – (B, *, L2, NH, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, NH, L1, L2), by default None

  • causal (bool, optional) – Perform causal attention, by default False

Returns:

(B, *, L1, NH, H)

Return type:

torch.Tensor
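The naive version is ordinary attention with a banded mask: query i may attend to key j only when |j - i| <= window // 2. A minimal sketch, with hypothetical helper names and the (B, L, NH, H) layout assumed:

```python
import torch
import torch.nn.functional as F


def window_mask(L, window=15, causal=False, device=None):
    # Boolean (L, L) band mask, True = may attend
    i = torch.arange(L, device=device)
    offset = i[None, :] - i[:, None]  # offset[query, key] = key - query
    mask = offset.abs() <= window // 2
    if causal:
        mask &= offset <= 0  # additionally forbid keys after the query
    return mask


def window_attention_naive(q, k, v, window=15, causal=False):
    # q, k, v: (B, L, NH, H); the full L x L mask makes memory O(L^2)
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    mask = window_mask(q.shape[1], window, causal, q.device)
    return F.scaled_dot_product_attention(qt, kt, vt, attn_mask=mask).transpose(1, 2)
```

For example, window_mask(5, window=3) allows 2, 3, 3, 3 and 2 keys for the five queries, because the band is truncated at both ends of the sequence.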

class bio_attention.attention.Transformer(depth: int, dim: int, nh: int, attentiontype: Literal['vanilla', 'random', 'window'] = 'vanilla', attention_args: dict = {}, plugintype: Literal['none', 'sinusoidal', 'learned', 'learnedcont', 'rotary', 'ALiBi', 'DPB', 'XL'] | List[Literal['none', 'sinusoidal', 'learned', 'learnedcont', 'rotary', 'ALiBi', 'DPB', 'XL']] = 'none', plugin_args: dict | List[dict] = {}, only_apply_plugin_at_first: bool | List[bool] = False, dropout: float = 0.2, glu_ff: bool = True, activation: Literal['relu', 'gelu', 'swish'] = 'swish')

Transformer network chaining multiple transformer layers

Parameters:
  • depth (int) – number of transformer blocks to use

  • dim (int) – input and output hidden dimension of x

  • nh (int) – number of heads, dim should be divisible by this number

  • attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”

  • attention_args (dict, optional) – args passed to the attention operator init, by default {}

  • plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”

  • plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}

  • only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False

  • dropout (float, optional) – dropout rate in the feedforward layers, by default 0.2. Note that attention-matrix dropout is controlled via attention_args.

  • glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True

  • activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”
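As an illustration of what glu_ff=True with activation="swish" can look like, here is a generic gated (GLU-style) feed-forward block in which a swish-activated gate multiplies a linear projection. The class name, the single fused input projection, and the hidden expansion factor of 4 are assumptions, not details taken from bio_attention.

```python
import torch
import torch.nn as nn


class GLUFeedForward(nn.Module):
    """Hypothetical sketch of a gated feed-forward block with swish gating."""

    def __init__(self, dim, expansion=4, dropout=0.2):
        super().__init__()
        hidden = dim * expansion  # assumed expansion factor
        self.proj_in = nn.Linear(dim, 2 * hidden)  # value and gate in one projection
        self.act = nn.SiLU()  # swish(x) = x * sigmoid(x)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(hidden, dim)

    def forward(self, x):
        # x: (B, *, L, dim) -> (B, *, L, dim)
        value, gate = self.proj_in(x).chunk(2, dim=-1)
        return self.proj_out(self.dropout(value * self.act(gate)))
```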

forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, causal: bool = False, **mod_kwargs) → Tensor

Forward pass

Parameters:
  • x (torch.Tensor) – (B, *, L, H)

  • pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, prepended to its sequence. These tokens have no position, and positional biases ensure they do not take part in positional encoding.

  • mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
    (1) (B, *, L) or (B, *, L-x): a boolean mask, expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. The mask is extended to prepended CLS tokens such that they also do not attend to positions marked False.
    (2) (B, *, L, L) or (B, *, L-x, L-x): a floating-point or boolean mask. The mask is extended to prepended CLS tokens such that they attend to all tokens. The same mask is applied to all heads.
    (3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x): as in (2), but different biases/masks can be applied per head.

  • causal (bool, optional) – Perform causal attention, by default False

Returns:

(B, *, L, H)

Return type:

torch.Tensor
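The three mask formats can be illustrated with a hypothetical helper, expand_mask, that brings each format to a per-head (B, NH, L, L) mask while accounting for prepended CLS tokens. It assumes inputs without extra * dimensions, so the format is inferred from the number of dimensions. For formats (2) and (3), the source only states that CLS tokens attend to all tokens; leaving the CLS columns visible to every query is an additional assumption of this sketch.

```python
import torch


def expand_mask(mask, L, nh):
    """Hypothetical sketch. L is the full length of x, including CLS tokens."""
    n_extra = L - mask.shape[-1]  # number of prepended CLS-like tokens
    if mask.dim() == 2:
        # (1) boolean key mask (B, L'): CLS keys stay visible (True), and every
        #     query, including CLS queries, ignores keys marked False
        B = mask.shape[0]
        full = torch.ones(B, L, dtype=torch.bool, device=mask.device)
        full[:, n_extra:] = mask
        return full[:, None, None, :].expand(B, nh, L, L)
    # (2) (B, L', L') shared over heads, or (3) (B, NH, L', L') per head
    fill = True if mask.dtype == torch.bool else 0.0
    full = torch.full((*mask.shape[:-2], L, L), fill, dtype=mask.dtype, device=mask.device)
    full[..., n_extra:, n_extra:] = mask
    # Assumption: CLS rows and columns keep `fill`, i.e. no restriction
    return full.unsqueeze(1).expand(-1, nh, -1, -1) if mask.dim() == 3 else full
```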

class bio_attention.attention.TransformerEncoder(*args, **kwargs)

Transformer encoder network that chains multiple transformer layers. Identical to Transformer, except that causal=False is always used in the forward pass.

Parameters:
  • depth (int) – number of transformer blocks to use

  • dim (int) – input and output hidden dimension of x

  • nh (int) – number of heads, dim should be divisible by this number

  • attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”

  • attention_args (dict, optional) – args passed to the attention operator init, by default {}

  • plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”

  • plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}

  • only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False

  • dropout (float, optional) – dropout rate in the feedforward layers, by default 0.2. Note that attention-matrix dropout is controlled via attention_args.

  • glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True

  • activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”

forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, **mod_kwargs)

Forward pass

Parameters:
  • x (torch.Tensor) – (B, *, L, H)

  • pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, prepended to its sequence. These tokens have no position, and positional biases ensure they do not take part in positional encoding.

  • mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
    (1) (B, *, L) or (B, *, L-x): a boolean mask, expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. The mask is extended to prepended CLS tokens such that they also do not attend to positions marked False.
    (2) (B, *, L, L) or (B, *, L-x, L-x): a floating-point or boolean mask. The mask is extended to prepended CLS tokens such that they attend to all tokens. The same mask is applied to all heads.
    (3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x): as in (2), but different biases/masks can be applied per head.

  • causal – not part of this method's signature; TransformerEncoder always performs non-causal attention (causal=False)

Returns:

(B, *, L, H)

Return type:

torch.Tensor

class bio_attention.attention.TransformerDecoder(*args, **kwargs)

Transformer decoder network that chains multiple transformer layers. Identical to Transformer, except that causal=True is always used in the forward pass.

Parameters:
  • depth (int) – number of transformer blocks to use

  • dim (int) – input and output hidden dimension of x

  • nh (int) – number of heads, dim should be divisible by this number

  • attentiontype (Literal["vanilla", "random", "window"], optional) – attention operator, by default “vanilla”

  • attention_args (dict, optional) – args passed to the attention operator init, by default {}

  • plugintype (Union[EncodingType, List[EncodingType]], optional) – positional bias plugin, by default “none”

  • plugin_args (Union[dict, List[dict]], optional) – arguments passed to positional bias init, by default {}

  • only_apply_plugin_at_first (Union[bool, List[bool]], optional) – only apply positional bias at the first layer, by default False

  • dropout (float, optional) – dropout rate in the feedforward layers, by default 0.2. Note that attention-matrix dropout is controlled via attention_args.

  • glu_ff (bool, optional) – whether to use gated linear feedforward network, by default True

  • activation (Literal["relu", "gelu", "swish"], optional) – activation, by default “swish”

forward(x: Tensor, pos: Tensor | None = None, mask: Tensor | None = None, **mod_kwargs)

Forward pass

Parameters:
  • x (torch.Tensor) – (B, *, L, H)

  • pos (Optional[torch.Tensor], optional) – (B, *, L) or (B, *, L-x), by default None. If pos has a shorter sequence length than x, x is assumed to have extra tokens, such as CLS tokens, prepended to its sequence. These tokens have no position, and positional biases ensure they do not take part in positional encoding.

  • mask (Optional[torch.Tensor], optional) – by default None, but can be one of:
    (1) (B, *, L) or (B, *, L-x): a boolean mask, expanded to (B, *, NH, L, L) such that no token can attend to positions marked False. The mask is extended to prepended CLS tokens such that they also do not attend to positions marked False.
    (2) (B, *, L, L) or (B, *, L-x, L-x): a floating-point or boolean mask. The mask is extended to prepended CLS tokens such that they attend to all tokens. The same mask is applied to all heads.
    (3) (B, *, NH, L, L) or (B, *, NH, L-x, L-x): as in (2), but different biases/masks can be applied per head.

  • causal – not part of this method's signature; TransformerDecoder always performs causal attention (causal=True)

Returns:

(B, *, L, H)

Return type:

torch.Tensor

class bio_attention.attention.Aggregator(method: Literal['mean', 'max', 'cls'] = 'max')

Aggregator module. Can be used to get a single vector from a sequence-valued input.

Parameters:
  • method (Literal["mean", "max", "cls"], optional) – aggregation method, by default “max”

forward(x: Tensor, mask: Tensor | None = None) → Tensor

Forward pass

Parameters:
  • x (torch.Tensor) – (B, *, L, H)

  • mask (Optional[torch.Tensor], optional) – (B, *, L), by default None

Returns:

(B, *, H)

Return type:

torch.Tensor
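The three aggregation methods can be sketched as masked pooling over the sequence axis. The function name is hypothetical, and the sketch assumes that True in mask marks a valid token and that the CLS token sits at the first position; neither convention is stated above.

```python
import torch


def aggregate(x, mask=None, method="max"):
    """Hypothetical sketch: pool (B, *, L, H) into (B, *, H).

    mask: (B, *, L), assumed True for valid tokens.
    """
    if method == "cls":
        return x[..., 0, :]  # assumes the CLS token is the first position
    if mask is None:
        mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
    m = mask.unsqueeze(-1)  # (B, *, L, 1), broadcasts over H
    if method == "mean":
        # Average over valid tokens only; clamp avoids division by zero
        return (x * m).sum(-2) / m.sum(-2).clamp(min=1)
    if method == "max":
        # Invalid tokens can never win the maximum
        return x.masked_fill(~m, float("-inf")).max(dim=-2).values
    raise ValueError(f"unknown method: {method}")
```

For x = [[1, 2], [3, 0], [100, 100]] with the last token masked out, "mean" gives [2, 1], "max" gives [3, 2] and "cls" gives [1, 2].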