From 809cee180fb954f380aa21a3a065e63fa55df1bb Mon Sep 17 00:00:00 2001
From: p5lu8poqt <785234984@qq.com>
Date: Tue, 30 Dec 2025 15:08:57 +0800
Subject: [PATCH] ADD file via upload

---
 eval_hle.py | 713 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 713 insertions(+)
 create mode 100644 eval_hle.py

diff --git a/eval_hle.py b/eval_hle.py
new file mode 100644
index 0000000..263af1d
--- /dev/null
+++ b/eval_hle.py
@@ -0,0 +1,713 @@
+import os
+import random
+import time
+import json
+import requests
+import asyncio
+import subprocess
+from tqdm import tqdm
+from transformers import AutoTokenizer
+import sys
+REPO_PATH = os.getenv("REPO_PATH")
+sys.path.append(REPO_PATH)
+from LLM_CALL import get_llm_response
+import multiprocessing as mp
+import argparse
+import logging
+from openai import OpenAI
+logging.disable(logging.CRITICAL)
+
+MODEL_NAME = None
+my_output_dir = None
+MAX_ROUNDS = None
+MODEL_TYPE = None
+MODEL_MAPPING = None
+TOOL_PRICING = None
+vllm_model_configs = None
+with open('tools.json') as f:
+    raw_tools = json.load(f)
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+# Provide your api key
+oss_client = OpenAI(
+    base_url = "https://integrate.api.nvidia.com/v1",
+    api_key = os.getenv("OSS_KEY")
+)
+
+MODEL_MAPPING = {
+    "search-1": "gpt-5",
+    "search-2": "gpt-5-mini",
+    "search-3": "Qwen/Qwen3-32B",
+    "reasoner-1": "gpt-5",
+    "reasoner-2": "gpt-5-mini",
+    "reasoner-3": "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "answer-math-1": "Qwen/Qwen2.5-Math-72B-Instruct",
+    "answer-math-2": "Qwen/Qwen2.5-Math-7B-Instruct",
+    "answer-1": "gpt-5",
+    "answer-2": "gpt-5-mini",
+    "answer-3": "meta-llama/Llama-3.3-70B-Instruct",
+    "answer-4": "Qwen/Qwen3-32B"
+}
+TOOL_PRICING = {
+    "gpt-5": {
+        "input_tokens_per_million": 1.25/1000000,
+        "output_tokens_per_million": 10/1000000
+    },
+    "gpt-5-mini": {
+        "input_tokens_per_million": 0.25/1000000,
+        "output_tokens_per_million": 2/1000000
+    },
+    "Qwen/Qwen3-32B": {
+        "input_tokens_per_million": 0.8/1000000,
+        "output_tokens_per_million": 0.8/1000000
+    },
+    "Qwen/Qwen2.5-Coder-32B-Instruct": {
+        "input_tokens_per_million": 0.8/1000000,
+        "output_tokens_per_million": 0.8/1000000
+    },
+    "Qwen/Qwen2.5-Math-72B-Instruct": {
+        "input_tokens_per_million": 0.9/1000000,
+        "output_tokens_per_million": 0.9/1000000
+    },
+    "Qwen/Qwen2.5-Math-7B-Instruct": {
+        "input_tokens_per_million": 0.2/1000000,
+        "output_tokens_per_million": 0.2/1000000
+    },
+    "meta-llama/Llama-3.3-70B-Instruct": {
+        "input_tokens_per_million": 0.9/1000000,
+        "output_tokens_per_million": 0.9/1000000
+    },
+    "Qwen/Qwen3-8B": {
+        "input_tokens_per_million": 0.2/1000000,
+        "output_tokens_per_million": 0.2/1000000
+    },
+    "code_interpreter_per_second": 0.0000083,
+    "tavily": {
+        "search": 0.01,
+        "extract": 0.002
+    },
+}
+ALL_TOOLS = {
+    "enhance_reasoning": {
+        'model': ["reasoner-1", "reasoner-2", "reasoner-3"]
+    },
+    "answer": {
+        'model': ["answer-math-1", "answer-math-2", "answer-1", "answer-2", "answer-3", "answer-4"]
+    },
+    "search": {
+        "model": ["search-1", "search-2", "search-3"]
+    },
+}
+
+def cut_seq(seq,l):
+    if len(seq)==0:
+        return {
+            'effective_length': 0,
+            'string_after_cut': ''
+        }
+    token_ids = tokenizer(seq)['input_ids']
+    rs = tokenizer.batch_decode(token_ids[-l:], skip_special_tokens=True)
+    return {
+        'effective_length': len(token_ids),
+        'string_after_cut': ''.join(rs)
+    }
+
+def call_tool(arguments):
+    start_time = time.time()
+    if arguments['tool']=='enhance_reasoning':
+        supported_models = [MODEL_MAPPING[m] for m in ALL_TOOLS['enhance_reasoning']['model']]
+        assert arguments['model'] in supported_models,f"Model {arguments['model']} is not supported in enhance_reasoning. Supported models: {supported_models}"
+        prompt = arguments['context_str'].strip()+'\n\n'
+        prompt += f"Question: {arguments['problem']}\nInstead of directly answering the question, please write additional python code that will give intermediate results after execution. Wrap the code within ```python and ```. The code should be self-contained with all the import and initialization."
+        model_name = arguments['model']
+        response = ''
+        if 'gpt-5' in model_name.lower():
+            response = get_llm_response(model=model_name,messages=prompt,return_raw_response=True,temperature=1,max_length=40000)
+        elif 'qwen2.5-coder' in model_name.lower():
+            response = get_llm_response(model=model_name,messages=prompt,return_raw_response=True,model_type='vllm',max_length=8000,temperature=0.2,model_config=arguments['vllm_model_configs'][model_name],model_config_path=arguments['vllm_model_configs']['vllm_model_config_path'],model_config_idx=arguments['eid'])
+            if isinstance(response,str):
+                response = ''
+                while not response:
+                    try:
+                        response = oss_client.chat.completions.create(
+                            model="nvdev/qwen/qwen2.5-coder-32b-instruct",
+                            messages=[{"role":"user","content":prompt}],temperature=0.2,
+                            top_p=0.7,
+                            max_tokens=8000,
+                        )
+                    except Exception as qwen_error:
+                        time.sleep(3)
+        if isinstance(response,str):
+            arguments['generated_code'] = ''
+            arguments['exec_result'] = ''
+            return arguments
+        try:
+            generated_code = response.choices[0].message.content.split('```python')[-1].split('```')[0]
+        except:
+            generated_code = ''
+        if generated_code=='':
+            arguments['generated_code'] = ''
+            arguments['exec_result'] = ''
+            return arguments
+        code_path = str(os.path.join(arguments['cur_output_dir'],f'exec_code_{arguments["id"]}.py'))
+        with open(code_path,'w') as f:
+            f.write(generated_code)
+        exec_result = ''
+        exec_start = time.time()
+        try:
+            exec_result = subprocess.run(['python', code_path], timeout=60, capture_output=True, text=True)
+            exec_time = time.time()-exec_start
+            exec_result = exec_result.stdout
+            with open(os.path.join(arguments['cur_output_dir'],f'exec_out_{arguments["id"]}.txt'),'w') as f:
+                f.write(exec_result)
+        except Exception as e:
+            pass
+        exec_time = time.time() - exec_start
+        arguments['generated_code'] = generated_code
+        arguments['exec_result'] = exec_result
+        return arguments
+
+    elif arguments['tool']=='answer':
+        prompt = arguments['context_str'].strip()+'\n\nProblem:\n'+arguments['problem']
+        response_str = ''
+        pred = ''
+
+        if 'qwen3' in arguments['model'].lower():
+            model_name = arguments['model']
+            messages = [
+                {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
+                {"role": "user", "content": prompt}
+            ]
+            arguments['messages'] = messages
+            response = get_llm_response(model=model_name,messages=messages,return_raw_response=True,model_type='vllm',max_length=8000,temperature=0.2,model_config=arguments['vllm_model_configs'][model_name],model_config_path=arguments['vllm_model_configs']['vllm_model_config_path'],model_config_idx=arguments['eid'])
+            if isinstance(response,str):
+                arguments['response'] = ''
+                arguments['pred'] = ''
+                arguments['correctness'] = False
+                return arguments
+            response_str = response.choices[0].message.content
+            if not isinstance(response_str,str) or not '\\boxed{' in response_str:
+                pred = ''
+            else:
+                pred_components = response.choices[0].message.content.split('\\boxed{')[-1].split('}')[:-1]
+                pred = '}'.join(pred_components).strip()
+        elif 'qwen2.5-math' in arguments['model'].lower():
+            model_name = arguments['model']
+            messages = [
+                {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
+                {"role": "user", "content": prompt}
+            ]
+            arguments['messages'] = messages
+            response = get_llm_response(model=model_name,messages=messages,return_raw_response=True,model_type='vllm',max_length=2000,temperature=0.2,model_config=arguments['vllm_model_configs'][model_name],model_config_path=arguments['vllm_model_configs']['vllm_model_config_path'],model_config_idx=arguments['eid'])
+            if isinstance(response,str):
+                arguments['response'] = ''
+                arguments['pred'] = ''
+                arguments['correctness'] = False
+                return arguments
+            response_str = response.choices[0].message.content
+            if not isinstance(response_str,str) or not '\\boxed{' in response_str:
+                pred = ''
+            else:
+                pred_components = response.choices[0].message.content.split('\\boxed{')[-1].split('}')[:-1]
+                pred = '}'.join(pred_components).strip()
+        elif 'gpt-5' in arguments['model'].lower():
+            model_name = arguments['model']
+            prompt += ("\n\nTake a deep breath and think hard with high reasoning, wrap the thoughts within <think> and </think>, and wrap only the exact answer without any explanation within <answer> and </answer>."
+                       "Output using the following format:\n<think>\n...\n</think>\n<answer>\n...\n</answer>\n")
+            arguments['messages'] = prompt
+            response = get_llm_response(model=model_name,messages=prompt,return_raw_response=True,max_length=40000)
+            if isinstance(response,str):
+                arguments['response'] = ''
+                arguments['pred'] = ''
+                arguments['correctness'] = False
+                return arguments
+            response_str = response.choices[0].message.content
+            if isinstance(response_str,str):
+                pred = response.choices[0].message.content.split('<answer>')[-1].split('</answer>')[0].strip()
+            else:
+                pred = ''
+        elif 'llama-3.3' in arguments['model'].lower():
+            model_name = arguments['model']
+            prompt += "\nWrap the thinking process and explanation between <think> and </think> and wrap only the exact answer without any explanation within <answer> and </answer>."
+            arguments['messages'] = prompt
+            response = get_llm_response(model=model_name,messages=prompt,return_raw_response=True,model_type='vllm',max_length=40000,temperature=0.2,model_config=arguments['vllm_model_configs'][model_name],model_config_path=arguments['vllm_model_configs']['vllm_model_config_path'],model_config_idx=arguments['eid'])
+            if isinstance(response,str):
+                response = ''
+                while not response:
+                    try:
+                        response = oss_client.chat.completions.create(
+                            model="nvdev/meta/llama-3.3-70b-instruct",
+                            messages=[{"role":"user","content":prompt}],temperature=0.2,
+                            top_p=0.7,
+                            max_tokens=40000,
+                        )
+                    except Exception as llama_error:
+                        time.sleep(3)
+            if isinstance(response,str):
+                arguments['response'] = ''
+                arguments['pred'] = ''
+                arguments['correctness'] = False
+                return arguments
+            response_str = response.choices[0].message.content
+            if isinstance(response_str,str):
+                pred = response.choices[0].message.content.split('<answer>')[-1].split('</answer>')[0].strip()
+            else:
+                pred = ''
+
+        if pred.strip()=='' or len(pred.split(' '))>500:
+            correctness = False
+        elif pred.strip().lower()==arguments['answer'].strip().lower():
+            correctness = True
+        else:
+            eval_prompt = (f"Question: {arguments['problem']}\n\n"
+                           f"Student answer: {pred}\n\n"
+                           f"Reference answer: {arguments['answer']}\n\n"
+                           "Assume that the reference answer is correct. Output True if the student answer matches the reference answer. Output False if the student answer does not match the reference answer. Wrap the result within <answer> and </answer>.")
+            eval_response = get_llm_response(model='gpt-5',messages=eval_prompt,temperature=1)
+            eval_result = eval_response.split('<answer>')[-1].split('</answer>')[0]
+            if eval_result.lower()=='true':
+                correctness = True
+            else:
+                correctness = False
+        arguments['response'] = response_str
+        arguments['pred'] = pred
+        arguments['correctness'] = correctness
+        return arguments
+
+    elif arguments['tool']=='search':
+        contents = []
+        prompt = arguments['context_str'].strip()+'\n\n'
+        prompt += f"Question: {arguments['problem']}\nInstead of directly answering the question, please write a query to search for a piece of relevant and missing information. The query should be a few key words about the information to search or a short sentence. Wrap the query within <query> and </query>."
+        cur_query_writer = arguments['model']
+        query_to_call = None
+        if 'gpt-5' in cur_query_writer.lower():
+            response = get_llm_response(model=cur_query_writer,messages=prompt,return_raw_response=True,temperature=1,max_length=40000)
+            if isinstance(response,str) or not response:
+                query_to_call = arguments['problem']
+            else:
+                query_to_call = response.choices[0].message.content.split('<query>')[-1].split('</query>')[0]
+        elif 'qwen3' in cur_query_writer.lower():
+            response = get_llm_response(model=cur_query_writer,messages=prompt,return_raw_response=True,model_type='vllm',max_length=8000,temperature=0.2,model_config=arguments['vllm_model_configs'][cur_query_writer],model_config_path=arguments['vllm_model_configs']['vllm_model_config_path'],model_config_idx=arguments['eid'])
+            if isinstance(response,str):
+                query_to_call = arguments['problem']
+            else:
+                query_to_call = response.choices[0].message.content.split('<query>')[-1].split('</query>')[0]
+        if query_to_call is None or len(query_to_call)<5:
+            pass
+        else:
+            assert len(query_to_call)>=5,f"{query_to_call}"
+            payload = {
+                "queries": [query_to_call[:390]],
+                "topk": 50,
+                "return_scores": True,
+                "eid": arguments['id']
+            }
+            results = None
+            all_vllm_model_configs = arguments['vllm_model_configs']
+            search_try_count = 0
+            while not results:
+                search_try_count += 1
+                try:
+                    cur_model_config = random.choice(all_vllm_model_configs['retrieval'])
+                    results = requests.post(f'http://{cur_model_config["ip_addr"]}:{cur_model_config["port"]}/retrieve', json=payload).json()
+                except Exception as search_error:
+                    time.sleep(3)
+            if results:
+                for r in results[0]:
+                    if 'content' in r['document']:
+                        contents.append(r['document']['content'])
+                    elif 'contents' in r['document']:
+                        contents.append(r['document']['contents'])
+        arguments['query'] = query_to_call
+        arguments['search_results_data'] = contents
+        if 'tokenizer' in arguments:
+            arguments.pop('tokenizer')
+        return arguments
+
+import asyncio
+import contextlib
+from concurrent.futures import ThreadPoolExecutor
+from typing import Iterable, Tuple, Any, Callable
+
+# task_list is an iterable of (func, arg) pairs
+async def run_all(
+    task_list: Iterable[Tuple[Callable[[Any], Any], Any]],
+    concurrency: int = 2,
+    progress: bool = False,
+    return_exceptions: bool = False,
+):
+    loop = asyncio.get_running_loop()
+    sem = asyncio.Semaphore(concurrency)
+
+    # create the executor sized to your concurrency gate
+    with ThreadPoolExecutor(max_workers=concurrency) as executor:
+        # wrap each task so it obeys the semaphore
+        async def run_one(idx: int, func: Callable, arg: Any):
+            async with sem:
+                if asyncio.iscoroutinefunction(func):
+                    res = await func(arg)
+                else:
+                    res = await loop.run_in_executor(executor, func, arg)
+                return idx, 
res, None + + task_list = list(task_list) + tasks = [asyncio.create_task(run_one(i, f, a)) + for i, (f, a) in enumerate(task_list)] + + results = [None] * len(tasks) + + if progress: + from tqdm import tqdm + pbar = tqdm(total=len(tasks)) + else: + pbar = None + + try: + # update progress as tasks complete + for fut in asyncio.as_completed(tasks): + idx, res, err = await fut + if err is None: + results[idx] = res + else: + if return_exceptions: + results[idx] = err + else: + # cancel remaining, then re-raise the first error + for t in tasks: + t.cancel() + with contextlib.suppress(Exception): + await asyncio.gather(*tasks, return_exceptions=True) + raise err + if pbar: + pbar.update(1) + finally: + if pbar: + pbar.close() + + return results + +def run_single(e): + if os.path.isfile(os.path.join(my_output_dir,f"{e['id']}.json")): + return + doc_list = [] + code_list = [] + attempt_list = [] + exp_start_time = time.time() + problem = e['question'] + user_problem = problem + answer = e['answer'] + all_tool_calls = [] + final_correct = False + final_answer_model = None + final_pred = '' + all_tool_responses = {} + all_message_responses = {} + used_tools = [] + for step in range(MAX_ROUNDS): + cur_output_dir = os.path.join(my_output_dir,f"step_{step}") + if not os.path.isdir(os.path.join(cur_output_dir,'tool_return')): + try: + os.makedirs(os.path.join(cur_output_dir,'tool_return')) + except: + pass + tools = [] + for t in raw_tools: + tools.append(t) + doc_str = '' + for doc_idx, doc in enumerate(doc_list): + doc_str += f"Doc {doc_idx+1}: {doc[:1200]} ...\n\n" + code_str = '' + for code_idx, code_piece in enumerate(code_list): + code_str += f"```python\n{code_piece['code']}\n```\n\n```output\n{code_piece['output']}\n```\n\n" + attempt_str = '' + for attempt_idx, attempt in enumerate(attempt_list): + attempt_str += f"Attempt{attempt_idx+1} answer by {attempt['model']}: {attempt['answer']}\n" + str_cut = cut_seq(seq=attempt_str,l=8000) + attempt_str = str_cut['string_after_cut'] + if not attempt_str.startswith('Attempt') and len(attempt_str)>0: + attempt_str = 'Attempt answer: '+attempt_str + str_cut = cut_seq(seq=code_str+attempt_str,l=12000) + code_attempt_str = str_cut['string_after_cut'] + code_attempt_str_len = str_cut['effective_length'] + if not code_attempt_str.startswith('```') and len(code_attempt_str)>0: + code_attempt_str = '```\n'+code_attempt_str + doc_flag = False + problem_length = len(tokenizer(problem)['input_ids']) + if code_attempt_str_len<27000-problem_length: + if code_attempt_str: + context_str = cut_seq(seq=doc_str+"\npython code and execution outputs:\n"+code_attempt_str,l=27000-problem_length) + else: + context_str = cut_seq(seq=doc_str,l=27000-problem_length) + context_str = context_str['string_after_cut'] + if len(doc_str)>0: + doc_flag = True + context_str = 'Documents:\n'+context_str + else: + context_str = code_attempt_str + + removed_tool = None + if len(used_tools)>1 and used_tools[-1]==used_tools[-2]: + updated_tools = [] + removed_tool = used_tools[-1] + for t in tools: + if t['function']['name']!=used_tools[-1]: + updated_tools.append(t) + else: + updated_tools = tools + cur_tool_set = [t['function']['name'] for t in updated_tools] + chat = [ + {"role": "system", "content": "You are good at using tools."}, + {"role": "user", "content": f"Problem: {problem}\n\n{context_str}\n\nChoose an appropriate tool.'"} + ] + response = 
get_llm_response(model=MODEL_NAME,messages=chat,return_raw_response=True,model_type='vllm',model_config=vllm_model_configs[MODEL_NAME],temperature=1,max_length=12000,tools=tools,model_config_path=vllm_model_configs['vllm_model_config_path'],model_config_idx=e['eid'])
+        cache_idx = 0
+        while os.path.isfile(f"input_output/{cache_idx}.json"):
+            cache_idx += 1
+        if isinstance(response,str):
+            continue
+        tool_calls = response.choices[0].message.tool_calls
+        cache_tool_calls = []
+        for one_tool_call in (tool_calls or []):  # tool_calls may be None when the model makes no tool call
+            tool_name = one_tool_call.function.name
+            try:
+                tool_arguments = json.loads(one_tool_call.function.arguments)
+            except:
+                tool_arguments = {}
+            cache_tool_calls.append({
+                'tool_name': tool_name,
+                'tool_arguments': tool_arguments
+            })
+        message_dict = {
+            'content': response.choices[0].message.content,
+            'tool_calls': cache_tool_calls
+        }
+        if not tool_calls:
+            all_tool_calls.append(f'342 invalid tool calls {tool_calls}')
+            continue
+        tool_call_list = []
+        cur_tool_calls = []
+        processed_tools = set()
+        for one_tool_call in tool_calls:
+            tool_name = one_tool_call.function.name
+            try:
+                tool_arguments = json.loads(one_tool_call.function.arguments)
+            except:
+                tool_arguments = {}  # malformed arguments fail the validation below
+            if not tool_name in ALL_TOOLS:
+                cur_tool_calls.append(f'350 invalid tool calls {tool_calls}')
+                continue
+            func_signature = ALL_TOOLS[tool_name]
+            valid_tool_call = True
+            for parameter_name,parameter_values in func_signature.items():
+                if (not parameter_name in tool_arguments):
+                    valid_tool_call = False
+                elif (not tool_arguments[parameter_name] in parameter_values) and parameter_values!='any':
+                    valid_tool_call = False
+            if not valid_tool_call:
+                cur_tool_calls.append(f'360 invalid tool calls {tool_calls}')
+                continue
+
+            if tool_name in processed_tools:
+                continue
+            processed_tools.add(tool_name)
+            tool_call = {
+                'name': tool_name,
+                'arguments': tool_arguments
+            }
+            cur_tool_calls.append([tool_call])
+            expert_model_to_call = MODEL_MAPPING[tool_arguments['model']]
+
+            call_tool_argument = None
+            used_tools.append(tool_name)
+            if tool_name=='enhance_reasoning':
+                if 'qwen2.5-coder' in expert_model_to_call.lower():
+                    max_code_length = 16000
+                    max_context_length = 24000
+                elif 'gpt-5' in expert_model_to_call.lower():
+                    max_code_length = 40000
+                    max_context_length = 120000
+                doc_str = ''
+                for doc_idx, doc in enumerate(doc_list):
+                    if 'qwen2.5-coder' in expert_model_to_call.lower():
+                        doc_str += f"Doc {doc_idx+1}: {doc[:1000]}\n\n"
+                    else:
+                        doc_str += f"Doc {doc_idx+1}: {doc}\n\n"
+                code_str = ''
+                for code_idx, code_piece in enumerate(code_list):
+                    code_str += f"```python\n{code_piece['code']}\n```\n\n```output\n{code_piece['output']}\n```\n\n"
+                str_cut = cut_seq(seq=code_str,l=max_code_length)
+                code_str = str_cut['string_after_cut']
+                code_str_len = str_cut['effective_length']
+                if not code_str.startswith('```') and len(code_str)>0:
+                    code_str = '```\n'+code_str
+                problem_len = len(tokenizer(user_problem)['input_ids'])
+                context_str = cut_seq(seq=doc_str+code_str,l=max_context_length-problem_len)
+                context_str = context_str['string_after_cut']
+                if len(doc_str)>0:
+                    context_str = 'Documents:\n'+context_str
+                call_tool_argument = {
+                    'tool': tool_name,
+                    'model': expert_model_to_call,
+                    'context_str': context_str,
+                    'vllm_model_configs': vllm_model_configs,
+                    'cur_output_dir': cur_output_dir,
+                    'problem': user_problem,
+                    'id': e['id'],
+                    'eid': e['eid']
+                }
+            elif tool_call['name']=='answer':
+                if 'qwen2.5-math' in expert_model_to_call.lower():
+                    max_code_length = 1000
+                    max_context_length = 2000
+                elif 'llama-3.3' in 
expert_model_to_call.lower(): + max_code_length = 10000 + max_context_length = 80000 + elif 'qwen3' in expert_model_to_call.lower(): + max_code_length = 12000 + max_context_length = 24000 + elif 'gpt-5' in expert_model_to_call.lower(): + max_code_length = 40000 + max_context_length = 120000 + doc_str = '' + for doc_idx, doc in enumerate(doc_list): + if 'gpt-5' in expert_model_to_call.lower() or 'llama' in expert_model_to_call.lower(): + doc_str += f"Doc {doc_idx+1}: {doc}\n\n" + else: + doc_str += f"Doc {doc_idx+1}: {doc[:1000]}\n\n" + code_str = '' + for code_idx, code_piece in enumerate(code_list): + code_str += f"```python\n{code_piece['code']}\n```\n\n```output\n{code_piece['output']}\n```\n\n" + str_cut = cut_seq(seq=code_str,l=max_code_length) + code_str = str_cut['string_after_cut'] + code_str_len = str_cut['effective_length'] + if not code_str.startswith('```') and len(code_str)>0: + code_str = '```\n'+code_str + problem_len = len(tokenizer(user_problem)['input_ids']) + context_str = cut_seq(seq=doc_str+code_str,l=max_context_length-problem_len) + context_str = context_str['string_after_cut'] + if len(doc_str)>0: + context_str = 'Documents:\n'+context_str + call_tool_argument = { + 'tool': tool_name, + 'model': expert_model_to_call, + 'context_str': context_str, + 'vllm_model_configs': vllm_model_configs, + 'cur_output_dir': cur_output_dir, + 'problem': user_problem, + 'answer': answer, + 'id': e['id'], + 'eid': e['eid'] + } + elif tool_call['name'] in ['search']: + if 'qwen3' in expert_model_to_call.lower(): + max_code_length = 12000 + max_context_length = 24000 + elif 'gpt-5' in expert_model_to_call.lower(): + max_code_length = 40000 + max_context_length = 120000 + doc_str = '' + for doc_idx, doc in enumerate(doc_list): + if 'gpt-5' in expert_model_to_call.lower(): + doc_str += f"Doc {doc_idx+1}: {doc}\n\n" + else: + doc_str += f"Doc {doc_idx+1}: {doc[:1000]}\n\n" + code_str = '' + for code_idx, code_piece in enumerate(code_list): + code_str += f"```python\n{code_piece['code']}\n```\n\n```output\n{code_piece['output']}\n```\n\n" + str_cut = cut_seq(seq=code_str,l=max_code_length) + code_str = str_cut['string_after_cut'] + code_str_len = str_cut['effective_length'] + if not code_str.startswith('```') and len(code_str)>0: + code_str = '```\n'+code_str + problem_len = len(tokenizer(user_problem)['input_ids']) + context_str = cut_seq(seq=doc_str+code_str,l=max_context_length-problem_len) + context_str = context_str['string_after_cut'] + if len(doc_str)>0: + context_str = 'Documents:\n'+context_str + call_tool_argument = { + 'tool': tool_name, + 'model': expert_model_to_call, + 'context_str': context_str, + 'vllm_model_configs': vllm_model_configs, + 'cur_output_dir': cur_output_dir, + 'problem': user_problem, + 'answer': answer, + 'id': e['id'], + 'eid': e['eid'] + } + tool_call_list.append([call_tool,call_tool_argument]) + break + all_tool_calls.append(cur_tool_calls) + + cache_argument = [] + for t in tool_call_list: + cache_argument.append(t[1]) + if len(tool_call_list)==0: + continue + cur_responses = asyncio.run(run_all(tool_call_list)) + all_tool_responses[f"turn_{step}_response"] = cur_responses + all_message_responses[f"turn_{step}_message"] = message_dict + finish_flag = False + for cur_response in cur_responses: + if cur_response['tool']=='enhance_reasoning': + if len(cur_response['exec_result'].strip())>0: + code_list.append({'code': cur_response['generated_code'], 'output': cur_response['exec_result']}) + elif cur_response['tool']=='answer': + final_correct = 
cur_response['correctness'] + final_answer_model = cur_response['model'] + final_pred = cur_response['pred'].strip() + finish_flag = True + break + elif cur_response['tool']=='search': + for one_doc in cur_response['search_results_data'][::-1]: + if not one_doc in doc_list: + doc_list.append(one_doc) + if finish_flag: + break + + return_dict = { + 'id': e['id'], + 'problem': problem, + 'all_tool_calls': all_tool_calls, + 'all_tool_responses': all_tool_responses, + 'answer': answer, + 'all_message_responses': all_message_responses, + 'correct': final_correct + } + with open(os.path.join(my_output_dir,f"{e['id']}.json"),'w') as f: + json.dump(return_dict,f,indent=2) + return return_dict + +if __name__=='__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--model_name', type=str) + parser.add_argument('--output_dir', type=str) + parser.add_argument('--model_config', type=str) + parser.add_argument('--max_rounds', type=int, default=50) + parser.add_argument('--model_type', type=str, default='Qwen/Qwen3-8B') + parser.add_argument('--example_path', type=str) + args = parser.parse_args() + + # global MODEL_NAME + MODEL_NAME = args.model_name + # global MODEL_TYPE + MODEL_TYPE = args.model_type + # global my_output_dir + my_output_dir = args.output_dir + # global MAX_ROUNDS + MAX_ROUNDS = args.max_rounds + if not os.path.isdir(os.path.join(my_output_dir,'answer_cache')): + os.makedirs(os.path.join(my_output_dir,'answer_cache')) + # global vllm_model_configs + with open(args.model_config) as f: + vllm_model_configs = json.load(f) + + with open(args.example_path) as f: + lines = f.readlines() + examples = [] + for eid,l in enumerate(lines): + raw_example = json.loads(l) + raw_example['eid'] = eid + examples.append([run_single,raw_example]) + + tool_call_results = asyncio.run(run_all(examples)) \ No newline at end of file