import unittest from unittest.mock import patch from main import get_300videos_urls, danmu_about_AI, generate_ class TestMain(unittest.TestCase): """ 测试 main.py 文件中的函数,确保其在不同情况下的功能正确性。 """ @patch('main.get_300videos_urls') def test_get_300videos_urls_normal(self, mock_get_urls): """ 测试 get_300videos_urls 函数的正常情况。 """ mock_get_urls.return_value = ['url1', 'url2', 'url3'] urls = get_300videos_urls("2024巴黎奥运会") self.assertEqual(len(urls), 3) def test_get_300videos_urls_empty_keyword(self): """ 测试 get_300videos_urls 对空关键字的处理。 """ with self.assertRaises(ValueError): get_300videos_urls("") @patch('main.get_300videos_urls') def test_get_300videos_urls_special_chars(self, mock_get_urls): """ 测试 get_300videos_urls 函数对特殊字符关键字的处理。 """ mock_get_urls.return_value = [] urls = get_300videos_urls("$%^&*") self.assertEqual(len(urls), 0) @patch('main.get_300videos_urls') def test_get_300videos_urls_network_error(self, mock_get_urls): """ 测试 get_300videos_urls 在网络连接错误情况下的处理。 """ mock_get_urls.side_effect = ConnectionError("Network error") with self.assertRaises(ConnectionError): get_300videos_urls("2024巴黎奥运会") @patch('main.danmu_about_AI') def test_danmu_about_AI_normal(self, mock_danmu): """ 测试 danmu_about_AI 函数的正常行为。 """ mock_danmu.return_value = (['danmu1', 'danmu2'], {'智能辅助解说': 2}) urls = ['url1', 'url2'] AI_tech = {"智能辅助解说": ["智能", "解说"]} top_8_danmu, top_8_AI = danmu_about_AI(urls, AI_tech) self.assertEqual(len(top_8_danmu), 2) self.assertEqual(top_8_AI['智能辅助解说'], 2) @patch('main.danmu_about_AI') def test_danmu_about_AI_empty_urls(self, mock_danmu): """ 测试 danmu_about_AI 对空 URL 列表的处理。 """ mock_danmu.return_value = ([], {}) urls = [] AI_tech = {"智能辅助解说": ["智能", "解说"]} top_8_danmu, top_8_AI = danmu_about_AI(urls, AI_tech) self.assertEqual(len(top_8_danmu), 0) self.assertEqual(len(top_8_AI), 0) @patch('main.generate_') def test_generate_normal(self, mock_generate): """ 测试 generate_ 函数的正常行为。 """ mock_generate.return_value = None result = generate_("danmu_data.xlsx") self.assertIsNone(result) @patch('main.generate_') def test_generate_invalid_path(self, mock_generate): """ 测试 generate_ 对无效文件路径的处理。 """ mock_generate.side_effect = FileNotFoundError("Invalid file path") with self.assertRaises(FileNotFoundError): generate_("invalid_path.xlsx") @patch('main.generate_') def test_generate_invalid_format(self, mock_generate): """ 测试 generate_ 对无效文件格式的处理。 """ mock_generate.side_effect = ValueError("Invalid file format") with self.assertRaises(ValueError): generate_("danmu_data.txt") if __name__ == '__main__': unittest.main()