import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image

from .vision import VisionDataset


class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    The parser only looks at text that appears inside a ``<table>`` element.
    Text inside an ``<a>`` tag selects the current image (resolved to a file
    under ``root``); text inside a following ``<li>`` tag is recorded as a
    caption for that image.

    Args:
        root (str or ``pathlib.Path``): Directory that is searched for the
            image files referenced by the page.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Maps the resolved image file path to its list of captions.
        self.annotations: Dict[str, List[str]] = {}

        # Parser state: whether we are inside a <table>, the tag whose start
        # was seen most recently (reset on any end tag), and the image that
        # subsequent <li> captions belong to.
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Remember the tag that was just opened and track ``<table>`` entry."""
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Clear the current tag and track ``<table>`` exit."""
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Interpret a text node as an image reference or a caption.

        Raises:
            IndexError: If the text of an ``<a>`` tag has fewer than two
                ``/``-separated components, or if no file under ``root``
                matches ``<id>_*.jpg`` for the referenced image.
        """
        # Everything outside the table is ignored.
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Captions that follow belong to no image and must be dropped.
            self.current_img = None
        elif self.current_tag == "a":
            # The image id is the second-to-last path component of the link
            # text; the file on disk is "<id>_<something>.jpg".
            img_key = data.split("/")[-2]
            pattern = os.path.join(self.root, img_key + "_*.jpg")
            img_path = glob.glob(pattern)[0]
            self.current_img = img_path
            self.annotations[img_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())


class Flickr8k(VisionDataset):
    """Flickr8k Entities Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict keyed by image file path.
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        self.annotations = parser.annotations

        # Sorted keys give a stable index -> image mapping.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # For Flickr8k the id already is the image path found by the parser.
        img_id = self.ids[index]

        # Image
        img = Image.open(img_id).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)


class Flickr30k(VisionDataset):
    """Flickr30k Entities Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict. Each non-blank line is
        # "<image id><TAB><caption>".
        self.annotations: Dict[str, List[str]] = defaultdict(list)
        with open(self.ann_file) as fh:
            for line_no, raw_line in enumerate(fh, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank lines carry no annotation; skip instead of crashing.
                    continue

                # Split on the first tab only so a caption that itself
                # contains a tab is kept intact.
                img_id, sep, caption = line.partition("\t")
                if not sep:
                    raise ValueError(
                        f"{self.ann_file}:{line_no}: expected '<image id>\\t<caption>', got {line!r}"
                    )

                # The last two characters of the id field are dropped to get
                # the image file name (presumably a '#<n>' caption index —
                # confirm against the annotation file format).
                self.annotations[img_id[:-2]].append(caption)

        # Sorted keys give a stable index -> image mapping.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the dataset root.
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)