You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
1.6 KiB

"""
Unit tests for data utilities.
"""
import pytest
import numpy as np
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from oneprompt_seg.utils.data_utils import random_click, generate_click_prompt
class TestRandomClick:
    """Tests for the ``random_click`` point-sampling helper."""

    def test_click_on_foreground(self):
        """Every sampled click lands inside the foreground square.

        ``random_click`` is stochastic, so a single draw says little about
        whether it *always* picks a foreground pixel. Sampling several times
        makes an occasional off-target click much more likely to surface.
        """
        lo, hi = 20, 40
        mask = np.zeros((64, 64))
        mask[lo:hi, lo:hi] = 1
        for _ in range(20):
            pt = random_click(mask, point_labels=1, inout=1)
            # The foreground region is a square, so these bounds hold
            # whichever axis order (row/col or x/y) random_click returns.
            assert lo <= pt[0] < hi
            assert lo <= pt[1] < hi

    def test_empty_mask(self):
        """With no pixel matching ``inout``, the click falls back to the centre."""
        size = 64
        mask = np.zeros((size, size))
        pt = random_click(mask, point_labels=1, inout=1)
        # Expected fallback is the centre of the square mask (32 for 64x64).
        assert pt[0] == size // 2
        assert pt[1] == size // 2

    def test_click_coordinates_shape(self):
        """A click is returned as a pair of coordinates."""
        mask = np.ones((64, 64))
        pt = random_click(mask, point_labels=1, inout=1)
        assert len(pt) == 2
class TestGenerateClickPrompt:
    """Shape checks for the ``generate_click_prompt`` helper."""

    def test_output_shapes(self):
        """Outputs preserve the image shape, the batch size and the last (depth) axis."""
        batch, depth = 2, 4
        img = torch.rand(batch, 3, 64, 64, depth)
        # Binary mask: threshold uniform noise at 0.5, then cast back to float.
        msk = (torch.rand(batch, 1, 64, 64, depth) > 0.5).float()

        out_img, pt, out_msk = generate_click_prompt(img, msk)

        assert out_img.shape == img.shape
        assert pt.shape[0] == batch
        assert pt.shape[-1] == depth
        assert out_msk.shape[0] == batch