diff --git a/test_main.py b/test_main.py new file mode 100644 index 0000000..12f2c0e --- /dev/null +++ b/test_main.py @@ -0,0 +1,96 @@ +import unittest +from unittest.mock import patch +from main import get_300videos_urls, danmu_about_AI, generate_ + +class TestMain(unittest.TestCase): + """ + 测试 main.py 文件中的函数,确保其在不同情况下的功能正确性。 + """ + + @patch('main.get_300videos_urls') + def test_get_300videos_urls_normal(self, mock_get_urls): + """ + 测试 get_300videos_urls 函数的正常情况。 + """ + mock_get_urls.return_value = ['url1', 'url2', 'url3'] + urls = get_300videos_urls("2024巴黎奥运会") + self.assertEqual(len(urls), 3) + + def test_get_300videos_urls_empty_keyword(self): + """ + 测试 get_300videos_urls 对空关键字的处理。 + """ + with self.assertRaises(ValueError): + get_300videos_urls("") + + @patch('main.get_300videos_urls') + def test_get_300videos_urls_special_chars(self, mock_get_urls): + """ + 测试 get_300videos_urls 函数对特殊字符关键字的处理。 + """ + mock_get_urls.return_value = [] + urls = get_300videos_urls("$%^&*") + self.assertEqual(len(urls), 0) + + @patch('main.get_300videos_urls') + def test_get_300videos_urls_network_error(self, mock_get_urls): + """ + 测试 get_300videos_urls 在网络连接错误情况下的处理。 + """ + mock_get_urls.side_effect = ConnectionError("Network error") + with self.assertRaises(ConnectionError): + get_300videos_urls("2024巴黎奥运会") + + @patch('main.danmu_about_AI') + def test_danmu_about_AI_normal(self, mock_danmu): + """ + 测试 danmu_about_AI 函数的正常行为。 + """ + mock_danmu.return_value = (['danmu1', 'danmu2'], {'智能辅助解说': 2}) + urls = ['url1', 'url2'] + AI_tech = {"智能辅助解说": ["智能", "解说"]} + top_8_danmu, top_8_AI = danmu_about_AI(urls, AI_tech) + self.assertEqual(len(top_8_danmu), 2) + self.assertEqual(top_8_AI['智能辅助解说'], 2) + + @patch('main.danmu_about_AI') + def test_danmu_about_AI_empty_urls(self, mock_danmu): + """ + 测试 danmu_about_AI 对空 URL 列表的处理。 + """ + mock_danmu.return_value = ([], {}) + urls = [] + AI_tech = {"智能辅助解说": ["智能", "解说"]} + top_8_danmu, top_8_AI = danmu_about_AI(urls, AI_tech) + self.assertEqual(len(top_8_danmu), 0) + self.assertEqual(len(top_8_AI), 0) + + @patch('main.generate_') + def test_generate_normal(self, mock_generate): + """ + 测试 generate_ 函数的正常行为。 + """ + mock_generate.return_value = None + result = generate_("danmu_data.xlsx") + self.assertIsNone(result) + + @patch('main.generate_') + def test_generate_invalid_path(self, mock_generate): + """ + 测试 generate_ 对无效文件路径的处理。 + """ + mock_generate.side_effect = FileNotFoundError("Invalid file path") + with self.assertRaises(FileNotFoundError): + generate_("invalid_path.xlsx") + + @patch('main.generate_') + def test_generate_invalid_format(self, mock_generate): + """ + 测试 generate_ 对无效文件格式的处理。 + """ + mock_generate.side_effect = ValueError("Invalid file format") + with self.assertRaises(ValueError): + generate_("danmu_data.txt") + +if __name__ == '__main__': + unittest.main()