"""Unit tests for data utilities."""

import pytest
import numpy as np
import torch

import sys
import os

# Make the in-repo ``src`` package importable without installing it.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from oneprompt_seg.utils.data_utils import random_click, generate_click_prompt


@pytest.fixture(autouse=True)
def _seed_rngs():
    """Seed the global RNGs before every test so random draws are reproducible.

    The helpers under test appear to sample randomly (TODO confirm which RNG
    ``random_click`` / ``generate_click_prompt`` actually use); seeding the
    stdlib ``random`` module, NumPy and torch covers the common cases. The
    assertions below are meant to hold for any seed -- seeding only makes a
    failing run replayable.
    """
    import random

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)


class TestRandomClick:
    """Tests for random_click function."""

    def test_click_on_foreground(self):
        """Test that click is generated on foreground."""
        mask = np.zeros((64, 64))
        mask[20:40, 20:40] = 1
        pt = random_click(mask, point_labels=1, inout=1)
        # The foreground square spans [20, 40) on both axes, so these checks
        # hold whether pt is ordered (row, col) or (x, y).
        assert 20 <= pt[0] < 40
        assert 20 <= pt[1] < 40

    def test_empty_mask(self):
        """Test click generation with empty mask."""
        mask = np.zeros((64, 64))
        pt = random_click(mask, point_labels=1, inout=1)
        # With no valid point to sample, the expected fallback is the mask
        # centre: 64 // 2 == 32 on both axes.
        assert pt[0] == 32
        assert pt[1] == 32

    def test_click_coordinates_shape(self):
        """Test that click returns two coordinates inside the mask."""
        mask = np.ones((64, 64))
        pt = random_click(mask, point_labels=1, inout=1)
        assert len(pt) == 2
        # Every pixel is foreground, so any valid click must lie within the
        # 64x64 mask bounds.
        assert all(0 <= coord < 64 for coord in pt)


class TestGenerateClickPrompt:
    """Tests for generate_click_prompt function."""

    def test_output_shapes(self):
        """Test output tensor shapes."""
        # Inputs are 5-D with a trailing depth axis of size 4.
        img = torch.rand(2, 3, 64, 64, 4)
        msk = torch.rand(2, 1, 64, 64, 4)
        msk = (msk > 0.5).float()

        out_img, pt, out_msk = generate_click_prompt(img, msk)

        assert out_img.shape == img.shape
        assert pt.shape[0] == 2  # batch size
        assert pt.shape[-1] == 4  # depth
        assert out_msk.shape[0] == 2  # batch size