# Source code for h3.models.opti_utils

from __future__ import annotations

import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor

from numba import jit

from PIL import Image
from tqdm import tqdm
from h3.utils.directories import get_processed_data_dir

from typing import Callable


@jit(forceobj=True)
def open_img(path: str, transform: Callable):
    """Open the image file at ``path`` and return it after ``transform``.

    The function is wrapped by numba's ``jit`` in object mode
    (``forceobj=True``).

    Args:
        path: Location of the image file, handed to ``PIL.Image.open``.
        transform: Callable invoked with the opened PIL image.

    Returns:
        Whatever ``transform`` returns for the opened image.
    """
    return transform(Image.open(path))


def load_full_ram(lst_path: list, transform: Callable) -> list:
    """Load every image in ``lst_path`` into memory, applying ``transform``.

    Each path is opened through ``open_img`` on a ``ThreadPoolExecutor``,
    and a tqdm progress bar advances as the results are collected.

    Args:
        lst_path: Paths of the images to load; must support ``len()``.
        transform: Callable applied to each opened image.

    Returns:
        The transformed images, in the same order as ``lst_path``.
    """
    with ThreadPoolExecutor() as pool:
        # Executor.map yields results in input order. Binding ``transform``
        # in the worker call avoids building a parallel list of N copies.
        loaded = pool.map(lambda path: open_img(path, transform), lst_path)
        # Consume the iterator once, straight into the returned list.
        return list(tqdm(loaded, total=len(lst_path)))
def main():
    """Load a slice of the processed PNG images into RAM and print its size.

    PNG files are searched recursively under
    ``<processed data dir>/processed_xbd/geotiffs_zoom/images``. They are
    preprocessed with the transforms bundled with the
    ``Swin_V2_B_Weights.IMAGENET1K_V1`` weights, and the number of loaded
    images is printed.
    """
    import glob

    from torchvision.models import Swin_V2_B_Weights

    preprocessing = Swin_V2_B_Weights.IMAGENET1K_V1.transforms()

    img_dir = os.path.join(
        get_processed_data_dir(), "processed_xbd", "geotiffs_zoom", "images"
    )
    img_paths = glob.glob(img_dir + '/**/*.png', recursive=True)

    # Only the first 1/96th of the discovered files is loaded.
    subset = img_paths[:len(img_paths) // 96]
    images = load_full_ram(lst_path=subset, transform=preprocessing)
    print(len(images))
# Run the loading routine only when this file is executed as a script.
if __name__ == "__main__":
    main()