# Source code for osc.models.utils

"""
Miscellaneous model building blocks: an MLP block and global pooling layers.
"""

from typing import Optional, Tuple

import einops.layers.torch
from torch import nn as nn


class MLP(nn.Module):
    """Two-layer feed-forward block.

    Computes ``drop2(fc2(drop1(act(fc1(x)))))``.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer. Defaults to ``in_features``.
        out_features: Size of the output's last dimension. Defaults to
            ``in_features``.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            builds the hidden activation. Defaults to ``nn.GELU``.
        hidden_bias: Whether ``fc1`` has a bias term.
        out_bias: Whether ``fc2`` has a bias term.
        dropout: Either one probability used for both dropout layers, or a
            pair ``(p_hidden, p_out)`` given as a tuple or list.

    Raises:
        ValueError: If ``dropout`` is a tuple/list whose length is not 2.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        activation=nn.GELU,
        hidden_bias: bool = True,
        out_bias: bool = True,
        dropout=0.0,
    ):
        super().__init__()
        if hidden_features is None:
            hidden_features = in_features
        if out_features is None:
            out_features = in_features

        # One probability for both dropout layers, or an explicit pair.
        # Lists are accepted as well as tuples (e.g. values read from a config).
        # A sequence of the wrong length is rejected instead of being silently
        # truncated or passed through to nn.Dropout.
        if isinstance(dropout, (tuple, list)):
            if len(dropout) != 2:
                raise ValueError(
                    f"dropout must be a float or a pair of floats, got {dropout!r}"
                )
            drop_hidden, drop_out = dropout
        else:
            drop_hidden = drop_out = dropout

        # Registration order (fc1, act, drop1, fc2, drop2) is kept so that
        # state_dict keys and the module repr stay the same.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=hidden_bias)
        self.act = activation()
        self.drop1 = nn.Dropout(drop_hidden)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=out_bias)
        self.drop2 = nn.Dropout(drop_out)

    def forward(self, x):
        """Apply fc1 -> act -> drop1 -> fc2 -> drop2 to ``x`` and return the result.

        ``nn.Linear`` acts on the last dimension, so ``x`` is expected to have
        ``in_features`` as its last dimension. The output has ``out_features``
        there instead.
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
# Parameter-free pooling layers for rank-3 inputs laid out as "B N C": the
# middle axis N is reduced away, leaving "B C". Presumably B is batch,
# N is tokens/points and C is channels — confirm against callers.
# global_avg_pool takes the mean over N; global_max_pool takes the maximum.
global_avg_pool = einops.layers.torch.Reduce(
    pattern="B N C -> B C",
    reduction="mean",
)
global_max_pool = einops.layers.torch.Reduce(
    pattern="B N C -> B C",
    reduction="max",
)