You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.4 KiB
79 lines
2.4 KiB
"""
|
|
Unit tests for metrics module.
|
|
"""
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import torch
|
|
|
|
import sys
|
|
import os
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
|
|
|
from oneprompt_seg.utils.metrics import iou, dice_coeff, eval_seg
|
|
|
|
|
|
class TestIoU:
    """Tests for the IoU (intersection-over-union) metric on int32 numpy masks."""

    def test_perfect_overlap(self):
        """Identical masks must give IoU == 1."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.ones((1, 64, 64), dtype=np.int32)

        result = iou(pred, target)

        assert result == pytest.approx(1.0, abs=1e-5)

    def test_no_overlap(self):
        """A full prediction against an empty target has no intersection: IoU == 0."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)

        result = iou(pred, target)

        assert result == pytest.approx(0.0, abs=1e-5)

    def test_partial_overlap(self):
        """Overlapping horizontal bands must give exactly intersection / union = 1/3.

        ``pred`` covers rows 0-31 and ``target`` covers rows 16-47, so the
        intersection spans 16 rows and the union 48 rows. Pinning the exact
        value (instead of only ``0 < result < 1``) catches implementations
        that compute a different overlap score, e.g. Dice, which is 0.5 here.
        """
        pred = np.zeros((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)
        pred[0, :32, :] = 1
        target[0, 16:48, :] = 1

        result = iou(pred, target)

        assert result == pytest.approx(1.0 / 3.0, abs=1e-5)
|
|
|
|
|
|
class TestDiceCoeff:
    """Tests for the Dice coefficient on float torch masks."""

    def test_perfect_overlap(self):
        """Identical masks must give Dice == 1."""
        pred = torch.ones(1, 64, 64)
        target = torch.ones(1, 64, 64)

        result = dice_coeff(pred, target)

        assert result.item() == pytest.approx(1.0, abs=1e-3)

    def test_no_overlap(self):
        """A full prediction against an empty target has no intersection: Dice == 0."""
        pred = torch.ones(1, 64, 64)
        target = torch.zeros(1, 64, 64)

        result = dice_coeff(pred, target)

        assert result.item() == pytest.approx(0.0, abs=1e-3)

    def test_partial_overlap(self):
        """Overlapping horizontal bands must give 2*intersection / (|pred| + |target|) = 0.5.

        ``pred`` covers rows 0-31 and ``target`` covers rows 16-47: each mask
        has 32 rows and they share 16, so Dice = 2 * 16 / (32 + 32) = 0.5.
        """
        pred = torch.zeros(1, 64, 64)
        target = torch.zeros(1, 64, 64)
        pred[0, :32, :] = 1
        target[0, 16:48, :] = 1

        result = dice_coeff(pred, target)

        assert result.item() == pytest.approx(0.5, abs=1e-3)
|
|
|
|
|
|
class TestEvalSeg:
    """Tests for the eval_seg function."""

    @staticmethod
    def _assert_scores_in_unit_interval(metrics, tol=1e-6):
        """Assert every returned metric converts to a float within [0, 1] (up to ``tol``).

        IoU and Dice are ratios bounded by [0, 1]; the chained comparison is
        also False for NaN, so a NaN score fails here too.
        """
        for value in metrics:
            score = float(value)
            assert -tol <= score <= 1.0 + tol, f"metric out of range: {score}"

    def test_single_channel(self):
        """Single-channel input yields two metrics (IoU, Dice), each in [0, 1]."""
        # A local, fixed-seed generator makes the inputs reproducible across
        # runs without mutating torch's global RNG state.
        gen = torch.Generator().manual_seed(0)
        pred = torch.rand(2, 1, 64, 64, generator=gen)
        target = torch.rand(2, 1, 64, 64, generator=gen)
        threshold = (0.5,)

        result = eval_seg(pred, target, threshold)

        assert len(result) == 2  # IoU and Dice
        self._assert_scores_in_unit_interval(result)

    def test_two_channel(self):
        """Two-channel input yields four metrics, each in [0, 1]."""
        gen = torch.Generator().manual_seed(0)
        pred = torch.rand(2, 2, 64, 64, generator=gen)
        target = torch.rand(2, 2, 64, 64, generator=gen)
        threshold = (0.5,)

        result = eval_seg(pred, target, threshold)

        assert len(result) == 4  # IoU_d, IoU_c, Dice_d, Dice_c
        self._assert_scores_in_unit_interval(result)
|