You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

79 lines
2.4 KiB

"""
Unit tests for metrics module.
"""
import pytest
import numpy as np
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from oneprompt_seg.utils.metrics import iou, dice_coeff, eval_seg
class TestIoU:
    """Tests for the ``iou`` metric on integer NumPy masks of shape (1, 64, 64)."""

    def test_perfect_overlap(self):
        """IoU of two identical all-ones masks is 1."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.ones((1, 64, 64), dtype=np.int32)
        result = iou(pred, target)
        assert result == pytest.approx(1.0, abs=1e-5)

    def test_no_overlap(self):
        """IoU of an all-ones mask against an all-zeros mask is 0."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)
        result = iou(pred, target)
        assert result == pytest.approx(0.0, abs=1e-5)

    def test_partial_overlap(self):
        """IoU of two overlapping horizontal bands.

        ``pred`` covers rows 0-31 and ``target`` covers rows 16-47 (all
        columns), so the intersection is 16 rows and the union is 48 rows:
        IoU = 16 / 48 = 1/3.
        """
        pred = np.zeros((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)
        pred[0, :32, :] = 1
        target[0, 16:48, :] = 1
        result = iou(pred, target)
        assert 0 < result < 1
        # The range check alone would accept a wrong formula; pin the exact value.
        # Any smoothing term small enough for test_no_overlap to pass keeps this
        # well within the tolerance.
        assert result == pytest.approx(1 / 3, abs=1e-4)
class TestDiceCoeff:
    """Tests for the ``dice_coeff`` metric on float torch masks of shape (1, 64, 64)."""

    def test_perfect_overlap(self):
        """Dice of two identical all-ones masks is 1."""
        pred = torch.ones(1, 64, 64)
        target = torch.ones(1, 64, 64)
        result = dice_coeff(pred, target)
        assert result.item() == pytest.approx(1.0, abs=1e-3)

    def test_no_overlap(self):
        """Dice of an all-ones mask against an all-zeros mask is 0."""
        pred = torch.ones(1, 64, 64)
        target = torch.zeros(1, 64, 64)
        result = dice_coeff(pred, target)
        assert result.item() == pytest.approx(0.0, abs=1e-3)

    def test_partial_overlap(self):
        """Dice of two half-image bands that overlap by a quarter of the image.

        ``pred`` covers rows 0-31 and ``target`` rows 16-47 (all columns):
        |A| = |B| = 2048 and |A ∩ B| = 1024, so Dice = 2 * 1024 / 4096 = 0.5.
        """
        pred = torch.zeros(1, 64, 64)
        target = torch.zeros(1, 64, 64)
        pred[0, :32, :] = 1
        target[0, 16:48, :] = 1
        result = dice_coeff(pred, target)
        # Same tolerance as the other Dice tests; any epsilon small enough for
        # test_no_overlap to pass shifts this value by far less than 1e-3.
        assert result.item() == pytest.approx(0.5, abs=1e-3)
class TestEvalSeg:
    """Tests for the ``eval_seg`` function."""

    @staticmethod
    def _rand(*shape, seed):
        """Return a uniform [0, 1) tensor of ``shape`` drawn from a seeded generator.

        A dedicated generator keeps inputs identical across runs (so failures are
        reproducible) without touching PyTorch's global RNG state.
        """
        generator = torch.Generator().manual_seed(seed)
        return torch.rand(*shape, generator=generator)

    def test_single_channel(self):
        """Single-channel input yields two metrics: IoU and Dice."""
        # Distinct seeds so prediction and target are not identical tensors.
        pred = self._rand(2, 1, 64, 64, seed=0)
        target = self._rand(2, 1, 64, 64, seed=1)
        threshold = (0.5,)
        result = eval_seg(pred, target, threshold)
        assert len(result) == 2  # IoU and Dice

    def test_two_channel(self):
        """Two-channel input yields four metrics (IoU and Dice per channel)."""
        pred = self._rand(2, 2, 64, 64, seed=2)
        target = self._rand(2, 2, 64, 64, seed=3)
        threshold = (0.5,)
        result = eval_seg(pred, target, threshold)
        assert len(result) == 4  # IoU_d, IoU_c, Dice_d, Dice_c