You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.6 KiB
57 lines
1.6 KiB
"""
|
|
Unit tests for data utilities.
|
|
"""
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import torch
|
|
|
|
import sys
|
|
import os
|
|
# Put the repository's ``src`` directory at the front of the import path so
# the ``oneprompt_seg`` import that follows can be resolved from the checkout.
_SRC_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "src")
sys.path.insert(0, _SRC_DIR)
|
|
|
|
from oneprompt_seg.utils.data_utils import random_click, generate_click_prompt
|
|
|
|
|
|
class TestRandomClick:
    """Tests for random_click function."""

    # Number of independent draws for the sampling test. One draw exercises
    # only a single random outcome, so repeat to catch clicks that only
    # occasionally land outside the foreground.
    N_DRAWS = 50

    def test_click_on_foreground(self):
        """Every sampled click must lie inside the foreground square."""
        mask = np.zeros((64, 64))
        mask[20:40, 20:40] = 1

        for _ in range(self.N_DRAWS):
            pt = random_click(mask, point_labels=1, inout=1)
            assert 20 <= pt[0] < 40
            assert 20 <= pt[1] < 40

    def test_empty_mask(self):
        """An all-background mask yields the center of the mask."""
        mask = np.zeros((64, 64))

        pt = random_click(mask, point_labels=1, inout=1)

        # Expected contract: with no valid pixel to sample, the click falls
        # back to the mask center (32, 32 for a 64x64 mask).
        assert pt[0] == mask.shape[0] // 2
        assert pt[1] == mask.shape[1] // 2

    def test_click_coordinates_shape(self):
        """A click is one 2-D coordinate lying inside the mask."""
        mask = np.ones((64, 64))

        pt = random_click(mask, point_labels=1, inout=1)

        assert len(pt) == 2
        # Every pixel is foreground, so any in-bounds coordinate is valid.
        assert 0 <= pt[0] < mask.shape[0]
        assert 0 <= pt[1] < mask.shape[1]
|
|
|
|
|
|
class TestGenerateClickPrompt:
    """Tests for generate_click_prompt function."""

    def test_output_shapes(self):
        """Check output shapes for a batch of 2 volumes with depth 4."""
        batch, depth = 2, 4

        # A seeded local generator makes the random fixture reproducible
        # across runs without touching the global torch RNG, which the
        # function under test may itself draw from.
        gen = torch.Generator().manual_seed(0)
        img = torch.rand(batch, 3, 64, 64, depth, generator=gen)
        msk = torch.rand(batch, 1, 64, 64, depth, generator=gen)
        msk = (msk > 0.5).float()

        out_img, pt, out_msk = generate_click_prompt(img, msk)

        assert out_img.shape == img.shape
        assert pt.shape[0] == batch
        assert pt.shape[-1] == depth
        assert out_msk.shape[0] == batch
|