Skip to content

screenspot-v2评测分数很低 #98

@gaolongxi

Description

@gaolongxi

Hi~
我尝试在两个GUI Grounding的benchmark:ScreenSpot-v2和ScreenSpot上面测试AgentCPM-GUI的表现,使用的是fun_2_bbox.py中的prompt,目前测出来分数很低,screenspot-v2只有59%左右正确率,qwen等模型已经已经在90%左右。我不知道是否是我的推理代码有问题,请问官方在这两个benchmark上面测试过吗?

我的推理代码是:

import os
import re
import io
import json
import base64
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class AgentCPM_GUI():

    def load_model(self, model_name_or_path="model/AgentCPM-GUI"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).to("cuda").eval()
        self.generation_params = {'do_sample': False, 'temperature': 0.0, 'use_cache': True, 'max_new_tokens': 2048}

        # --- system prompt ---
        self.sys_prompt = '''
你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的下一步操作是根据给定的GUI截图和图中某个组件的功能描述点击组件的中心位置。坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000
输入:屏幕截图,功能描述
输出:点击操作,以{\"POINT\":[...,...]}为格式,其中不能存在任何非坐标字符

# Rule
- 输出操作必须遵循Schema约束

# Schema
{
    "required": ["thought"]
}
'''

    def _resize(self, origin_img):
        resolution = origin_img.size
        w,h = resolution
        max_line_res = 1120
        if max_line_res is not None:
            max_line = max_line_res
            if h > max_line:
                w = int(w * max_line / h)
                h = max_line
            if w > max_line:
                h = int(h * max_line / w)
                w = max_line
        img = origin_img.resize((w,h),resample=Image.Resampling.LANCZOS)
        return img

    def set_generation_config(self, **kwargs):
        self.generation_params.update(**kwargs)

    def inference(self, instruction, image_path):

        assert os.path.exists(image_path) and os.path.isfile(image_path), "Invalid input image path."

        image = Image.open(image_path)
        image = self._resize(image)

        messages = [{
            "role": "user",
            "content": [
                f"屏幕上某一组件的功能描述:{instruction}\n当前屏幕截图:",
                image
            ]
        }]
        print("instruction:", instruction)

        output_text = self.model.chat(
            image=None,
            system_prompt=self.sys_prompt,
            msgs=messages,
            tokenizer=self.tokenizer,
            temperature=0.1,
            # do_sample=False
        )
        print("Raw response:", output_text)

        x_rel = y_rel = 0.0
        try:
            match = re.search(r'"POINT"\s*:\s*\[\s*(\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)\s*\]', output_text)
            if match:
                x_rel = float(match.group(1))
                y_rel = float(match.group(2))
        except:
            print("Warning: No POINT parsed.")

        x_norm = x_rel / 1000.0
        y_norm = y_rel / 1000.0

        result = {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": output_text,
            "bbox": None,
            "point": [x_norm, y_norm],
        }

        return result

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions