Source code for pypots.nn.modules.reformer.autoencoder

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from .layers import ReformerLayer


[docs] class ReformerEncoder(nn.Module): def __init__( self, n_steps, n_layers, d_model, n_heads, bucket_size, n_hashes, causal, d_ffn, dropout, ): super().__init__() assert n_steps % (bucket_size * 2) == 0, ( f"Sequence length ({n_steps}) needs to be divisible by target bucket size {bucket_size} x 2" ) self.enc_layer_stack = nn.ModuleList( [ ReformerLayer( d_model, n_heads, bucket_size, n_hashes, causal, d_ffn, dropout, ) for _ in range(n_layers) ] )
[docs] def forward(self, x: torch.Tensor): enc_output = x for layer in self.enc_layer_stack: enc_output = layer(enc_output) return enc_output