# Source code for h3.models.pre_train

import os.path

import rasterio
import numpy as np
import matplotlib.pyplot as plt
import torch

from torchvision.models import ViT_L_16_Weights, vit_l_16
from torchvision.models import swin_v2_b, Swin_V2_B_Weights
from h3.utils.directories import get_xbd_hurricane_dir


def load_image() -> np.ndarray:
    """Read the sample post-disaster Hurricane Florence image from disk.

    The file is looked up under ``hold/images`` inside the directory returned
    by ``get_xbd_hurricane_dir()``.

    Returns:
        The array produced by rasterio's ``read()`` with no band selection,
        i.e. every band in the file (rasterio documents the layout as
        ``(bands, rows, cols)``).
    """
    hurricane_dir = get_xbd_hurricane_dir()
    imagepath = os.path.join(
        hurricane_dir,
        "hold",
        "images",
        "hurricane-florence_00000236_post_disaster.tif",
    )
    # Open through a context manager so the dataset handle is closed as soon
    # as the pixels are read, instead of being left open until garbage
    # collection.
    with rasterio.open(imagepath) as src:
        return src.read()
def load_model():
    """Build the pretrained Swin V2-B network and return it in eval mode.

    The model's type is printed as a side effect.
    """
    # Alternative backbone kept for reference:
    # vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
    pretrained = Swin_V2_B_Weights.DEFAULT
    net = swin_v2_b(weights=pretrained)
    print(type(net))
    # Module.eval() switches to inference mode and returns the module itself.
    return net.eval()
def main():
    """Run the pretrained model on the sample image and print the result.

    Steps: load the model, load the image, convert it to a tensor, apply the
    preprocessing preset bundled with the Swin V2-B weights, run one forward
    pass on a batch of size one, then print the output tensor and its shape.
    Progress messages and the raw image shape are printed along the way.

    No explicit device placement is performed.
    """
    print("loading model")
    model = load_model()
    print("loading image")
    image = load_image()
    print(image.shape)
    image = torch.as_tensor(image)

    # Take the preprocessing preset from the same weights entry that
    # load_model() uses (DEFAULT) so the transforms always match the weights.
    preprocess = Swin_V2_B_Weights.DEFAULT.transforms()

    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        batch = preprocess(image).unsqueeze(0)  # add the batch dimension
        out = model(batch).squeeze(0)  # drop the batch dimension again

    print(out)
    print(out.shape)
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()