From fdf819951e569cfae4702e89fbf1b801a1ad5aa2 Mon Sep 17 00:00:00 2001 From: YLGH Date: Tue, 8 Apr 2025 22:54:36 +0000 Subject: [PATCH 1/3] do --- chartqa_check.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++ chartqa_verify.py | 106 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 chartqa_check.py create mode 100644 chartqa_verify.py diff --git a/chartqa_check.py b/chartqa_check.py new file mode 100644 index 0000000..39a9592 --- /dev/null +++ b/chartqa_check.py @@ -0,0 +1,119 @@ +import asyncio +import os +import pickle + +from tqdm import tqdm +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) +import os + + +def parse_answer(fw_answer): + fw_answer = fw_answer.split("FINAL ANSWER:")[1] + return fw_answer.strip() + + +# Define the retry strategy +retry_strategy = retry( + stop=stop_after_attempt(5), # Stop after 5 attempts + wait=wait_exponential(multiplier=1, min=4, max=10), # Exponential backoff + retry=retry_if_exception_type(Exception), # Retry on any exception +) + + +@retry_strategy +async def fetch_responses( + client, + pred, + ref, + semaphore, +): + # Construct the prompt for ChatGPT + prompt = f""" + Are these two answers equivalent? + + They don't need to be an exact match, just close enough is correct. + Consider percentages (%) equivalent to their decimal form (e.g., 50% = 0.5). + + Please consider things correct even if it missing a unit. For example '13 years' is equivalent to '13'. + + Please consider things correct even if one is missing a %. For example '30%' is equivalent to '30'. Which is also equivalent to '0.3'. + + 1.6 million t. should match 1.6 + + Only reply with Yes or No. 
 + + Answer 1: {pred} + Answer 2: {ref} + """ + + async with semaphore: + response = await client.chat.completions.create( + model="accounts/fireworks/models/deepseek-v3-0324", + messages=[ + {"role": "user", "content": prompt}, + ], + temperature=0.0, + max_tokens=128, + ) + return response.choices[0].message.content.strip() + + +async def main(): + references = [] + predictions = [] + for i in range(2500): + file_path = f"/home/yingliu/llm_eval_meta/fw_maverick_chartqa/response_{i}.pkl" + if not os.path.exists(file_path): + print(f"{i=}: file not found") + continue + ans = pickle.load(open(file_path, "rb")) + fw_raw_response = ans[0] + answer = ans[1] + references.append(answer) + predictions.append(parse_answer(fw_raw_response)) + + from openai import AsyncOpenAI + + tasks = [] + from tqdm import tqdm + + client = AsyncOpenAI( + base_url="https://api.fireworks.ai/inference/v1", + api_key=os.environ.get("FIREWORKS_API_KEY"), + timeout=None, + ) + semaphore = asyncio.Semaphore(64) + for pred, ref in tqdm( + zip(predictions, references), + total=len(predictions), + desc="Checking equivalence", + ): + tasks.append(asyncio.create_task(fetch_responses(client, pred, ref, semaphore))) + + for future in tqdm( + asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks" + ): + await future + + correct = 0 + incorrect = 0 + for idx, task in enumerate(tasks): + if task.result() == "Yes": + correct += 1 + elif task.result() == "No": + incorrect += 1 + print(f"Incorrect: {idx}") + print("Prediction:", predictions[idx]) + print("Reference:", references[idx]) + else: + print(f"Error: {task.result()}", idx) + print(f"% CORRECT = {correct / len(tasks)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/chartqa_verify.py b/chartqa_verify.py new file mode 100644 index 0000000..b7b4cf8 --- /dev/null +++ b/chartqa_verify.py @@ -0,0 +1,106 @@ +import pickle + +predictions = [] +references = [] + + +def parse_answer(fw_answer): + try: +
fw_answer = fw_answer.lower().split("final answer:")[1] + except: + return "" + return fw_answer.strip() + + +for i in range(2500): + file_path = f"/home/yingliu/llm_eval_meta/chartqa-do/response_{i}.pkl" + (fw_ans, ref) = pickle.load(open(file_path, "rb")) + predictions.append(parse_answer(fw_ans)) + references.append(ref) + + +def check_equivalence(pred, ref): + pred = ( + pred.lower() + .removeprefix("**") + .removesuffix("**") + .removesuffix("%") + .removesuffix(" billion") + .replace(",", "") + .strip() + ) + ref = ( + ref.lower() + .removesuffix("%") + .replace(",", "") + .removesuffix("]") + .removeprefix("[") + .strip() + ) + try: + if abs(float(pred) - float(ref) * 100) < 0.000001: + return True + if abs(float(pred) * 100 - float(ref)) < 0.000001: + return True + if abs((float(pred) - float(ref))) / float(ref) < 0.010001: + return True + + except: + pass + return pred == ref or pred == ref + "%" + + +# Test with the first example +equivalences = [] +from tqdm import tqdm + +correct = 0 +for pred, ref in tqdm( + zip(predictions, references), total=len(predictions), desc="Checking equivalence" +): + if check_equivalence(pred, ref): + correct += 1 + else: + print(pred, "|", ref) + +num_hand_check_correct = 34 +print(f"{(correct + num_hand_check_correct)/len(predictions):0.4f}") + +""" +Hand verification of close things that regex cannot catch: + +Disagreement democrat Democrat (scores 60 to 100) +Disagreement light beige gray +Disagreement dark blue Blue +Disagreement 5.25 trillion 5.25 +Disagreement teal Teal Blue +Disagreement $24,688.3 24688.3 +Disagreement italy, 22% [Italy , 22] +Disagreement $9,546.35 9545.35 +Disagreement 2014-2016 [2014, 2016] +Disagreement neither No +Disagreement 2003-2004 [2003, 2004] +Disagreement tend to favor one side. 
Tend to favor one side +Disagreement 0.41 0.414285714 +Disagreement 18-29 Ages 18-29 +Disagreement 213k 213 +Disagreement 151k 151 +Disagreement austria and chile [Austria, Chile] +Disagreement facebook messenger Facebook Messenger* +Disagreement increased 35% Increased +Disagreement increases increasing +Disagreement staying alert and t... Staying alert and taking precautions +Disagreement germany vs. united states [Germany,United States] +Disagreement roku tv Robku TV +Disagreement estimated revenue in billion u.s. dollars. Estimated revenue in billion U.S. dollars +Disagreement blue light blue +Disagreement blue light blue +Disagreement blue light blue +Disagreement ** dark blue Navy blue +Disagreement ** rabbit Rabbit** +Disagreement metro / small bus Metro / small bus* +Disagreement casual consumers [Casual consumers**, Non-consumers] +Disagreement 42.47%, 54.91% [42.47, 54.91] +Disagreement 40-59 yrs 40-59 years +Disagreement -14% 14 +""" From 4c75a107b9da43db9374eaf48259ce782f73abda Mon Sep 17 00:00:00 2001 From: YLGH Date: Tue, 8 Apr 2025 22:55:27 +0000 Subject: [PATCH 2/3] ok --- chartqa_check.py | 119 ----------------------------------------------- 1 file changed, 119 deletions(-) delete mode 100644 chartqa_check.py diff --git a/chartqa_check.py b/chartqa_check.py deleted file mode 100644 index 39a9592..0000000 --- a/chartqa_check.py +++ /dev/null @@ -1,119 +0,0 @@ -import asyncio -import os -import pickle - -from tqdm import tqdm -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, -) -import os - - -def parse_answer(fw_answer): - fw_answer = fw_answer.split("FINAL ANSWER:")[1] - return fw_answer.strip() - - -# Define the retry strategy -retry_strategy = retry( - stop=stop_after_attempt(5), # Stop after 5 attempts - wait=wait_exponential(multiplier=1, min=4, max=10), # Exponential backoff - retry=retry_if_exception_type(Exception), # Retry on any exception -) - - -@retry_strategy -async def 
fetch_responses( - client, - pred, - ref, - semaphore, -): - # Construct the prompt for ChatGPT - prompt = f""" - Are these two answers equivalent? - - They don't need to be an exact match, just close enough is correct. - Consider percentages (%) equivalent to their decimal form (e.g., 50% = 0.5). - - Please consider things correct even if it missing a unit. For example '13 years' is equivalent to '13'. - - Please consider things correct even if one is missing a %. For example '30%' is equivalent to '30'. Which is also equivalent to '0.3'. - - 1.6 million t. should match 1.6 - - Only reply with Yes or No. - - Answer 1: {pred} - Answer 2: {ref} - """ - - async with semaphore: - response = await client.chat.completions.create( - model="accounts/fireworks/models/deepseek-v3-0324", - messages=[ - {"role": "user", "content": prompt}, - ], - temperature=0.0, - max_tokens=128, - ) - return response.choices[0].message.content.strip() - - -async def main(): - references = [] - predictions = [] - for i in range(2500): - file_path = f"/home/yingliu/llm_eval_meta/fw_maverick_chartqa/response_{i}.pkl" - if not os.path.exists(file_path): - print(f"{i=}: file not found") - continue - ans = pickle.load(open(file_path, "rb")) - fw_raw_response = ans[0] - answer = ans[1] - references.append(answer) - predictions.append(parse_answer(fw_raw_response)) - - from openai import AsyncOpenAI - - tasks = [] - from tqdm import tqdm - - client = AsyncOpenAI( - base_url="https://api.fireworks.ai/inference/v1", - api_key=os.environ.get("FIREWORKS_API_KEY"), - timeout=None, - ) - semaphore = asyncio.Semaphore(64) - for pred, ref in tqdm( - zip(predictions, references), - total=len(predictions), - desc="Checking equivalence", - ): - tasks.append(asyncio.create_task(fetch_responses(client, pred, ref, semaphore))) - - for future in tqdm( - asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks" - ): - await future - - correct = 0 - incorrect = 0 - for idx, task in
enumerate(tasks): - if task.result() == "Yes": - correct += 1 - elif task.result() == "No": - incorrect += 1 - print(f"Incorrect: {idx}") - print("Prediction:", predictions[idx]) - print("Reference:", references[idx]) - else: - print(f"Error: {task.result()}", idx) - print(f"% CORRECT = {correct / len(tasks)}") - - -if __name__ == "__main__": - asyncio.run(main()) From d738a969601330b871070f897ff03df27d60e7a9 Mon Sep 17 00:00:00 2001 From: YLGH Date: Tue, 8 Apr 2025 22:55:50 +0000 Subject: [PATCH 3/3] add maker --- chartqa.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 chartqa.py diff --git a/chartqa.py b/chartqa.py new file mode 100644 index 0000000..50a54fb --- /dev/null +++ b/chartqa.py @@ -0,0 +1,196 @@ +import argparse +import asyncio +import base64 +from dataclasses import dataclass +from io import BytesIO +import pickle +from typing import Any +from datasets import load_dataset +from openai import AsyncOpenAI +import openai +from tqdm import tqdm +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) +import os + +# Define the retry strategy +retry_strategy = retry( + stop=stop_after_attempt(5), # Stop after 5 attempts + wait=wait_exponential(multiplier=1, min=4, max=10), # Exponential backoff + retry=retry_if_exception_type(Exception), # Retry on any exception +) + + +# Define the fetch_responses function with retry strategy +@retry_strategy +async def fetch_responses(prompt, semaphore, index, model_name, output_dir, max_tokens): + output_file = os.path.join(output_dir, f"response_{index}.pkl") + if os.path.exists(output_file): + print(f"File {output_file} already exists, skipping.") + return + + headers = {"Content-Type": "application/json"} + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{prompt.image}"}, + }, + { + "type": "text", + "text": ( + 
f"""{prompt.question} + Analyze the image and question carefully, using step-by-step reasoning. + First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format: + Final Answer: + where follows the following instructions: + - should be a single phrase or number. + - should not paraphrase or reformat the text in the image. + - If is a ratio, it should be a decimal value like 0.25 instead of 1:4. + - If the question is a Yes/No question, should be Yes/No. + - If is a number, it should not contain any units. + - If is a percentage, it should include a % sign. + - If is an entity, it should include the full label from the graph. + IMPORTANT: Remember, to end your answer with Final Answer: .""" + ), + }, + ], + } + ] + + data = { + "messages": messages, + "model": model_name, + "temperature": 0.0, + "max_tokens": max_tokens, + "stream": False, + } + + url = "http://0.0.0.0:80/v1/chat/completions" + + import httpx + + async with semaphore: + async with httpx.AsyncClient(timeout=None) as client: + try: + response = await client.post(url, headers=headers, json=data) + response.raise_for_status() + result = response.json() + with open(output_file, "wb") as f: + print( + f"Dumping response to {output_file} with prompt.answer={prompt.answer}" + ) + pickle.dump( + ( + result["choices"][0]["message"]["content"], + prompt.answer, + ), + f, + ) + except Exception as e: + print(f"GOT ERROR", e) + raise e + + +def get_client(provider): + if provider == "fw": + return AsyncOpenAI(base_url="https://api.fireworks.ai/inference/v1/") + elif provider == "tg": + return AsyncOpenAI(base_url="https://api.together.xyz/v1") + elif provider == "local_fw": + # return AsyncOpenAI(base_url="http://0.0.0.0:80/v1") + import httpx + + return httpx.AsyncClient(timeout=None) + else: + raise ValueError(f"Invalid provider: {provider}") + + +@dataclass +class Entry: + type: str + question: str + answer: str + image: Any + + +def 
pil_image_to_base64(pil_img): + buffered = BytesIO() + pil_img.save(buffered, format="PNG") # Saved as PNG + img_bytes = buffered.getvalue() + img_base64 = base64.b64encode(img_bytes).decode("utf-8") + return img_base64 + + +# Define the main function +async def main(args): + ds = load_dataset("lmms-lab/ChartQA") + + if args.num_examples is None: + args.num_examples = len(ds["test"]) + + prompts = [] + count = 0 + for type, question, answer, image in zip( + ds["test"]["type"], + ds["test"]["question"], + ds["test"]["answer"], + ds["test"]["image"], + ): + prompts.append(Entry(type, question, answer, pil_image_to_base64(image))) + count += 1 + if count >= args.num_examples: + break + + os.makedirs(args.output_dir, exist_ok=True) + + semaphore = asyncio.Semaphore(args.concurrency) + # client = get_client(args.provider) + + tasks = [] + for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")): + tasks.append( + fetch_responses( + prompt, + semaphore, + idx, + args.model_name, + args.output_dir, + max_tokens=8192, + ) + ) + + for future in tqdm( + asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks" + ): + await future + + +# Entry point for the script +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script to run model with specified parameters." + ) + parser.add_argument( + "--provider", + type=str, + required=True, + help="Provider name (e.g., fw, tg, or local_fw)", + ) + parser.add_argument( + "--num-examples", type=int, default=None, help="Number of examples to process" + ) + parser.add_argument("--concurrency", type=int, default=16) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument( + "--output-dir", type=str, required=True, help="Directory to save responses" + ) + + args = parser.parse_args() + asyncio.run(main(args))