Skip to content

Chart QA #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions chartqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import argparse
import asyncio
import base64
from dataclasses import dataclass
from io import BytesIO
import pickle
from typing import Any
from datasets import load_dataset
from openai import AsyncOpenAI
import openai
from tqdm import tqdm
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import os

# Define the retry strategy
retry_strategy = retry(
stop=stop_after_attempt(5), # Stop after 5 attempts
wait=wait_exponential(multiplier=1, min=4, max=10), # Exponential backoff
retry=retry_if_exception_type(Exception), # Retry on any exception
)


# Define the fetch_responses function with retry strategy
@retry_strategy
async def fetch_responses(prompt, semaphore, index, model_name, output_dir, max_tokens):
output_file = os.path.join(output_dir, f"response_{index}.pkl")
if os.path.exists(output_file):
print(f"File {output_file} already exists, skipping.")
return

headers = {"Content-Type": "application/json"}
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{prompt.image}"},
},
{
"type": "text",
"text": (
f"""{prompt.question}
Analyze the image and question carefully, using step-by-step reasoning.
First, describe any image provided in detail. Then, present your reasoning. And finally your final answer in this format:
Final Answer: <answer>
where <answer> follows the following instructions:
- <answer> should be a single phrase or number.
- <answer> should not paraphrase or reformat the text in the image.
- If <answer> is a ratio, it should be a decimal value like 0.25 instead of 1:4.
- If the question is a Yes/No question, <answer> should be Yes/No.
- If <answer> is a number, it should not contain any units.
- If <answer> is a percentage, it should include a % sign.
- If <answer> is an entity, it should include the full label from the graph.
IMPORTANT: Remember, to end your answer with Final Answer: <answer>."""
),
},
],
}
]

data = {
"messages": messages,
"model": model_name,
"temperature": 0.0,
"max_tokens": max_tokens,
"stream": False,
}

url = "http://0.0.0.0:80/v1/chat/completions"

import httpx

async with semaphore:
async with httpx.AsyncClient(timeout=None) as client:
try:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
result = response.json()
with open(output_file, "wb") as f:
print(
f"Dumping response to {output_file} with prompt.answer={prompt.answer}"
)
pickle.dump(
(
result["choices"][0]["message"]["content"],
prompt.answer,
),
f,
)
except Exception as e:
print(f"GOT ERROR", e)
raise e


def get_client(provider):
if provider == "fw":
return AsyncOpenAI(base_url="https://api.fireworks.ai/inference/v1/")
elif provider == "tg":
return AsyncOpenAI(base_url="https://api.together.xyz/v1")
elif provider == "local_fw":
# return AsyncOpenAI(base_url="http://0.0.0.0:80/v1")
import httpx

return httpx.AsyncClient(timeout=None)
else:
raise ValueError(f"Invalid provider: {provider}")


@dataclass
class Entry:
type: str
question: str
answer: str
image: Any


def pil_image_to_base64(pil_img):
buffered = BytesIO()
pil_img.save(buffered, format="PNG") # Saved as PNG
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
return img_base64


# Define the main function
async def main(args):
ds = load_dataset("lmms-lab/ChartQA")

if args.num_examples is None:
args.num_examples = len(ds["test"])

prompts = []
count = 0
for type, question, answer, image in zip(
ds["test"]["type"],
ds["test"]["question"],
ds["test"]["answer"],
ds["test"]["image"],
):
prompts.append(Entry(type, question, answer, pil_image_to_base64(image)))
count += 1
if count >= args.num_examples:
break

os.makedirs(args.output_dir, exist_ok=True)

semaphore = asyncio.Semaphore(args.concurrency)
# client = get_client(args.provider)

tasks = []
for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
tasks.append(
fetch_responses(
prompt,
semaphore,
idx,
args.model_name,
args.output_dir,
max_tokens=8192,
)
)

for future in tqdm(
asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
):
await future


# Entry point for the script
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script to run model with specified parameters."
)
parser.add_argument(
"--provider",
type=str,
required=True,
help="Provider name (e.g., fw, tg, or local_fw)",
)
parser.add_argument(
"--num-examples", type=int, default=None, help="Number of examples to process"
)
parser.add_argument("--concurrency", type=int, default=16)
parser.add_argument("--model-name", type=str, required=True)
parser.add_argument(
"--output-dir", type=str, required=True, help="Directory to save responses"
)

args = parser.parse_args()
asyncio.run(main(args))
106 changes: 106 additions & 0 deletions chartqa_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pickle

predictions = []
references = []


def parse_answer(fw_answer):
try:
fw_answer = fw_answer.lower().split("final answer:")[1]
except:
return ""
return fw_answer.strip()


for i in range(2500):
file_path = f"/home/yingliu/llm_eval_meta/chartqa-do/response_{i}.pkl"
(fw_ans, ref) = pickle.load(open(file_path, "rb"))
predictions.append(parse_answer(fw_ans))
references.append(ref)


def check_equivalence(pred, ref):
pred = (
pred.lower()
.removeprefix("**")
.removesuffix("**")
.removesuffix("%")
.removesuffix(" billion")
.replace(",", "")
.strip()
)
ref = (
ref.lower()
.removesuffix("%")
.replace(",", "")
.removesuffix("]")
.removeprefix("[")
.strip()
)
try:
if abs(float(pred) - float(ref) * 100) < 0.000001:
return True
if abs(float(pred) * 100 - float(ref)) < 0.000001:
return True
if abs((float(pred) - float(ref))) / float(ref) < 0.010001:
return True

except:
pass
return pred == ref or pred == ref + "%"


# Test with the first example
equivalences = []
from tqdm import tqdm

correct = 0
for pred, ref in tqdm(
zip(predictions, references), total=len(predictions), desc="Checking equivalence"
):
if check_equivalence(pred, ref):
correct += 1
else:
print(pred, "|", ref)

num_hand_check_correct = 34
print(f"{(correct + num_hand_check_correct)/len(predictions):0.4f}")

"""
Hand verification of close things that regex cannot catch:

Disagreement democrat Democrat (scores 60 to 100)
Disagreement light beige gray
Disagreement dark blue Blue
Disagreement 5.25 trillion 5.25
Disagreement teal Teal Blue
Disagreement $24,688.3 24688.3
Disagreement italy, 22% [Italy , 22]
Disagreement $9,546.35 9545.35
Disagreement 2014-2016 [2014, 2016]
Disagreement neither No
Disagreement 2003-2004 [2003, 2004]
Disagreement tend to favor one side. Tend to favor one side
Disagreement 0.41 0.414285714
Disagreement 18-29 Ages 18-29
Disagreement 213k 213
Disagreement 151k 151
Disagreement austria and chile [Austria, Chile]
Disagreement facebook messenger Facebook Messenger*
Disagreement increased 35% Increased
Disagreement increases increasing
Disagreement staying alert and t... Staying alert and taking precautions
Disagreement germany vs. united states [Germany,United States]
Disagreement roku tv Robku TV
Disagreement estimated revenue in billion u.s. dollars. Estimated revenue in billion U.S. dollars
Disagreement blue light blue
Disagreement blue light blue
Disagreement blue light blue
Disagreement ** dark blue Navy blue
Disagreement ** rabbit Rabbit**
Disagreement metro / small bus Metro / small bus*
Disagreement casual consumers [Casual consumers**, Non-consumers]
Disagreement 42.47%, 54.91% [42.47, 54.91]
Disagreement 40-59 yrs 40-59 years
Disagreement -14% 14
"""