Source code for h3.models.vision_transformer

import torch.nn as nn

from torchvision.models import ViT_L_16_Weights, vit_l_16
from torchvision.models import swin_v2_b, Swin_V2_B_Weights

from torchvision.models.swin_transformer import SwinTransformer
from typing import Callable


[docs]def get_model(name) -> nn.Module: model = swin_v2_b(weights=Swin_V2_B_Weights.DEFAULT) model.eval() return model
[docs]def get_preprocess(name) -> Callable: return name.weigth.transforms()