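"""Layers for the TiDE model (``pypots.nn.modules.tide.layers``).

Provides a LayerNorm with an optional bias and the residual MLP block ``ResBlock``.
"""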

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """Layer normalization whose additive shift can be switched off.

    When this was written, ``nn.LayerNorm`` offered no simple ``bias=False``
    option. This module therefore keeps its own learnable scale and optional
    shift, and calls :func:`torch.nn.functional.layer_norm` directly.

    Parameters
    ----------
    ndim :
        Shape of the normalized trailing dimension(s). It is passed to
        ``torch.ones`` / ``torch.zeros`` to build the learnable scale and shift.
    bias :
        If truthy, a learnable shift initialised to zeros is created.
        Otherwise ``self.bias`` is ``None`` and no shift is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Learnable per-feature scale, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            # Learnable per-feature shift, initialised to zeros.
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            # Plain ``None`` attribute (not a registered parameter), so
            # F.layer_norm applies no shift.
            self.bias = None

    def forward(self, x):
        """Normalize ``x`` over its trailing dims matching ``self.weight``'s shape.

        A fixed epsilon of 1e-5 is used for numerical stability.
        """
        normalized_shape = self.weight.shape
        return F.layer_norm(
            x,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )


class ResBlock(nn.Module):
    """Residual MLP block: a two-layer MLP branch plus a linear skip, then LayerNorm.

    The input flows through ``fc1 -> ReLU -> fc2 -> dropout``. In parallel, a
    separate linear projection ``fc3`` maps the input directly to ``output_dim``
    so both paths can be summed. The sum is then normalized with ``LayerNorm``.

    Parameters
    ----------
    input_dim :
        Size of the last dimension of the input.
    hidden_dim :
        Width of the hidden layer in the MLP branch.
    output_dim :
        Size of the last dimension of the output.
    dropout :
        Dropout probability applied to the MLP branch output (default 0.1).
    bias :
        Whether the linear layers and the layer norm carry a bias term.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()
        # Submodule creation order is significant: it fixes parameter
        # registration order and the RNG draws used for weight initialisation.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        # Skip path: project the input to output_dim so it can be added to the
        # MLP branch even when input_dim != output_dim.
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)

    def forward(self, x):
        """Apply the block to ``x``, whose last dimension must be ``input_dim``.

        Returns a tensor with the same leading dimensions as ``x`` and a last
        dimension of ``output_dim``.
        """
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        # Residual connection: add the linearly projected input.
        out = out + self.fc3(x)
        out = self.ln(out)
        return out