ChatGPT少样本实验初探

要回答自回归模型是否能取代自编码模型的问题...

使用ChatGPT进行few-shot实验

1. 调查

  • Zero-Shot and Few-shot learning method using ChatGPT on problem sets : [Few-Shot-ChatGPT(https://github.com/Showndarya/Few-Shot-ChatGPT)
  • bullet: A Zero-Shot / Few-Shot Learning, LLM Based, text classification framework bullet

在少样本文本分类方面, 有人已经使用相同的模型进行了实验[1], 结果如下表:

在表中, ChatGPT的性能还是比较可观的, 但遗憾的是这篇论文并没有披露更多细节, 代码也未开源.

[1] Li, X. et al.: PDAMeta: Meta-Learning Framework with Progressive Data Augmentation for Few-Shot Text Classification. In: Proceedings of the 2024 Joint International Conference on Computational Linguistics, Language Resources and Evaluation (LREC-COLING 2024). pp. 12668–12678 ELRA and ICCL, Torino, Italia (2024).

2. 测试提示词

system_prompt =
"""
        You are a professinal few-shot text classifier. You will be given {n}-way {k}-shot examples and a set of query texts.
        Each example corresponds to one of the {n} classes. 

        Your task is to understand the class category from the supported texts, and then predict the correct class label for each query based on the provided examples of each class.

        - Each class will be labeled as "Class i: [text k]", where i is the class number and each text is an example that help you understand the class i.
        - Each query will be labeled as "Query Text: [text]" and needs to be classified into one of the {n} classes.
        - Respond with only the predicted class labels (1, 2, ..., n), in the same order as the queries.
        - Do not generate any additional text.
        - Now, take a deep breath, focus on the classification task and complete it with high accuracy.
"""
user_prompt = 
"""
        Examples:
    {Example}

        Queries:
    {Query}
    """

3. 实验代码

3.1 GPT分类器

class FSTCClassifier(BaseModel,extra="allow"):
    n_support: int = 5
    n_query: int = 5
    n_class: int = 2
    base_url: str = "https://api.openai.com/v1"
    api_key: str = ""
    context: str = """
        You are a helpful AI assistant, skilled in classifying passages and texts into one of several categories based on examples.
    """
    """Initial prompt context."""
    model: str = "gpt-3.5-turbo"
    MAX_LENGTH: int = 16385
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.api_key == "":
            self.api_key = os.environ.get("OPENAI_API_KEY")
        if os.environ.get("OPENAI_API_BASE"):
            self.base_url = os.environ.get("OPENAI_API_BASE")
        # logger.info(f"API Key: {self.api_key}")
        logger.info(f"Base URL: {self.base_url}")
        logger.info(f"Model: {self.model}")
        self.init_results()
    def init_results(self):
        self.results = {
            "prediction": [],
            "accuracy": [],
            "f1": [],
            "failed": 0,
            "success": 0
        }
        
    @pydantic.computed_field()
    @property
    def api(self) -> int:
        """Instantiates and returns an OpenAI client object."""
        client = OpenAI(api_key=self.api_key,base_url=self.base_url)
        client.timeout = 20

        return client


    def validate_response(self, resp: str):
        # split by "\n"
        predictions_1 = resp.strip().split("\n")
        # split by ","
        predictions_2 = resp.strip().split(",")
        predictions = predictions_1 if len(predictions_1) > len(predictions_2) else predictions_2 
        
        predictions = [p.strip() for p in predictions if p.strip().isdigit()]  # Ensure numeric output
        # minus 1
        predictions = list(map(lambda x: str(int(str(x).strip())-1), predictions))
        expected_count = self.n_class * self.n_query
        if len(predictions) != expected_count:
            logger.debug(f"Mismatch: Got {len(predictions)} predictions, expected {expected_count}")
            return None
        return predictions
	
	# API请求失败时自动重试
    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(
            exception_types=(
                openai._exceptions.RateLimitError,
                openai._exceptions.APIConnectionError,
                ReadTimeout,
            )
        ),
        wait=tenacity.wait_random(min=2, max=10),
        stop=tenacity.stop_after_attempt(5),
        reraise=True,
    )
    def predict(self, examples:dict, queries:dict, temperature:float = 0.1,
                 top_p: float = 0.9, max_tokens: int = 100):

        prompt = FSTCPromptRequest(
            n_support=self.n_support, n_query=self.n_query, n_class=self.n_class,
            )
        prompt.format(examples, queries)
        if len(prompt) > self.MAX_LENGTH:
            self.results["failed"] += 1
            logger.error(f"Prompt length exceeds maximum limit of {self.MAX_LENGTH}")
            return None
        logger.info("Prompt:{}".format(str(prompt)))
        response = self.api.chat.completions.create(
            model=self.model,
            # prompt=str(prompt),  # Ensure
            messages=[
                {"role": "system", "content": prompt.get_system_prompt()},
                {"role": "user", "content": prompt.user_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        logger.debug(f"Response: {response}")
        
        resp = response.choices[0].message.content.strip()
        
        pred_labels = self.validate_response(resp)
        if pred_labels is None:
            self.results["failed"] += 1
            return None
        logger.debug(f"\nPredicted Labels: {pred_labels}\nExpected Labels:   {prompt.labels}")
        
        acc = accuracy_score(prompt.labels, pred_labels)
        f1 = f1_score(prompt.labels, pred_labels, average="macro")
        
        self.results['prediction'].append(pred_labels)
        self.results['accuracy'].append(acc)
        self.results['f1'].append(f1)
        self.results["success"] += 1
        logger.debug(f"Accuracy: {acc}, F1: {f1}")
        return acc, f1
    
    def batch_predict(self, examples:List[dict], queries:List[dict], temperature:float = 0.1,
                 top_p: float = 0.9, max_tokens: int = 100, test_count:int = 1):
        self.init_results()
        for i in range(test_count):
            for example, query in tqdm.tqdm(zip(examples, queries),total=len(examples)):

                result = self.predict(example, query, temperature, top_p, max_tokens)
                if result is None:
                    logger.debug("A failure.")
                    
            logger.debug(f"Test Round {i} Results: {result}")
        result = self.caculate_score()
        return result
    
    def caculate_score(self, results=None):
        if not results:
            results = self.results
        return {
            "accuracy": sum(results['accuracy'])/len(results['accuracy']),
            "f1": sum(results['f1'])/len(results['f1']),
            "success": results["success"],
            "failed": results["failed"]
        }

3.2 数据集

def load_dataset(n_class,n_support,n_query,dataset, cv,eval=False):
    data_path = "data/{}/few_shot/{}/test.json".format(dataset, cv)
    labels_path = "data/{}/few_shot/{}/labels.test.txt".format(dataset, cv)
    fewshotdataset = FewShotDataset(
        data_path=data_path,
        labels_path=labels_path,
        n_classes=n_class,n_support=n_support,n_query=n_query,
        process_type=2, # test
        eval=eval
    )
    return fewshotdataset
def get_examples(fewshotdataset, n_example, n_query):
    examples = {}
    queries = {}
    
    # Fetch an episode from the dataset
    episode = fewshotdataset.get_episode()
    
    # Support examples (training examples)
    support = episode['xs']
    
    # Query examples (test examples)
    query = episode['xq']
    
    # Loop through support and query examples
    for label_idx, support_samples in enumerate(support):
        # For each label (first dimension), we gather the sentences from support set
        examples[label_idx] = [sample['sentence'] for sample in support_samples[:n_example]]
        
    for label_idx, query_samples in enumerate(query):
        # For each label (first dimension), we gather the sentences from query set
        queries[label_idx] = [sample['sentence'] for sample in query_samples[:n_query]]
    
    return examples, queries

4. 实验

4.1 实验设置
n_test_episodes = 100 # 一个数据集请求100次
n_test_count = 3 # 每次请求GPT分类3次
temperature = 0.1
n_top = 0.9
max_token = 100
model_name = "gpt-3.5-turbo"

classfifier = FSTCClassifier(n_support=n_support, n_query=n_query, n_class=n_class, model=model_name)
# classfifier.predict(e,q)
results = classfifier.batch_predict(e_list,q_list,temperature,n_top,max_token,n_test_count)
results
4.2 实验输出

HuffPost实验结果:

INFO:root:Base URL: [https://aihubmix.com/v1](https://aihubmix.com/v1) 
INFO:root:Model: gpt-3.5-turbo 0it [00:00, ?it/s]
INFO:root:Predicted Labels: ['2', '1', '2', '3', '4'] Expected Labels: ['0', '1', '2', '3', '4'] INFO:root:Accuracy: 0.8, F1: 0.7333333333333333 
1it [00:01, 1.50s/it]
INFO:root:Predicted Labels: ['2', '1', '1', '1', '0'] Expected Labels: ['0', '1', '2', '3', '4'] INFO:root:Accuracy: 0.2, F1: 0.1
2it [00:02, 1.46s/it]
INFO:root:Predicted Labels: ['0', '1', '1', '3', '4'] Expected Labels: ['0', '1', '2', '3', '4'] INFO:root:Accuracy: 0.8, F1: 0.7333333333333333 
3it [00:04, 1.42s/it]
INFO:root:Predicted Labels: ['2', '1', '3', '0', '4'] Expected Labels: ['0', '1', '2', '3', '4'] INFO:root:Accuracy: 0.4, F1: 0.4 
4it [00:05, 1.46s/it]
INFO:root:Predicted Labels: ['0', '1', '2', '3', '4'] 
Expected Labels: ['0', '1', '2', '3', '4'] INFO:root:Accuracy: 1.0, F1: 1.0
INFO:root:Test Round 0 Results: {'accuracy': 0.64, 'f1': np.float64(0.5933333333333333), 'success': 5, 'failed': 0}
...
{'accuracy': 0.65, 'f1': np.float64(0.6), 'success': 20, 'failed': 0}

出乎意料的是, ChatGPT的gpt3.5模型在这套实验环境和提示词下并不能实现较好的性能.

从实验输出看, ChatGPT的表现并不稳定, 我没有办法来复现这种结果, 这套提示词已经是我参考现有项目, 重复修改多次实现的了.

除此之外, GPT也可能不能完成任务, 如输出的内容完全不符合要求, 比如在20News数据集上, GPT的失败比率达到了惊人的121/379.

Id Creation Time Tags cv (last) accuracy (last) f1 (last) failed (last) success (last) max_token (last) n_class (last) n_support (last) n_query (last) n_test_count (last) n_test_episodes (last) model_name temperature (last) top_p (last) cvs dataset
GPTFSTC-88 2024-09-24 16:06:00 [“SUMMARY”] 0.212233643171414 0.144973995908919 121 379 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 5 20News
GPTFSTC-87 2024-09-24 16:02:22 [“CV”] 5 0.201632519527256 0.134054008369798 30 70 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 20News
GPTFSTC-86 2024-09-24 15:59:07 [“CV”] 4 0.241230769230769 0.167721123321123 28 72 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 20News
GPTFSTC-85 2024-09-24 15:55:31 [“CV”] 3 0.196209150326797 0.130757936507937 19 81 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 20News
GPTFSTC-84 2024-09-24 15:51:44 [“CV”] 2 0.209166666666667 0.146027777777778 16 84 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 20News
GPTFSTC-83 2024-09-24 15:47:55 [“CV”] 1 0.212929110105581 0.146309133567957 28 72 100 5 5 1 2 50 gpt-3.5-turbo 0.1 0.9 20News
相比之下, 使用BERT预训练模型的实验结果如下:
Id Name cv Tags best_acc (last) Creation Time lr early_stop n_query
CON-2600 0.832 | 0.8204 - 0.8231 | 0.8167 | 0.8333 | 0.8226 5 [“D_GROW”,“PN_ATT”,“K5C5”,“MODEL”,“PN”,“LB1”,“CN”,“F1_KNN”,“TP”,“TP_V2”] 0.852239975750446 2024-06-14 09:23:33 0.000001 3 5
CON-2599 [24-06-14_08:47:19][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] 4 [“MODEL”,“K5C5”,“F1_KNN”,“CN”,“LB1”,“PN”,“D_GROW”,“TP_V2”,“TP”,“PN_ATT”] 0.812279977023602 2024-06-14 08:51:08 0.000001 3 5
CON-2598 [24-06-14_07:51:54][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] 3 [“MODEL”,“TP_V2”,“TP”,“PN_ATT”,“F1_KNN”,“D_GROW”,“K5C5”,“LB1”,“CN”,“PN”] 0.804999976098538 2024-06-14 07:55:43 0.000001 3 5
CON-2597 [24-06-14_07:14:52][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] 2 [“MODEL”,“LB1”,“F1_KNN”,“D_GROW”,“TP”,“CN”,“K5C5”,“PN”,“PN_ATT”,“TP_V2”] 0.895439977049828 2024-06-14 07:18:41 0.000001 3 5
CON-2596 [24-06-14_06:42:29][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] 1 [“PN”,“MODEL”,“TP_V2”,“CN”,“F1_KNN”,“K5C5”,“PN_ATT”,“LB1”,“D_GROW”,“TP”] 0.795159977197647 2024-06-14 06:46:17 0.000001 3 5

我们大概可以这样总结几点:

  • GPT的表现与提示词有非常大的关系, 并且提示工程也并不简单
  • GPT的表现并不稳定, 如果要使用则需要设计一套完善的错误处理程序和输入处理程序
  • GPT的表现是否正常也与输入内容有较大关系

然后经常会问到的一个问题:

自从OpenAI的GPT-3问世以来,自回归大模型的强势发展似乎让BERT等自编码语言模型显得有些过时. 特别是在文本分类任务中, GPT等自回归模型展现出了显著的优势.

在少样本学习领域, OpenAI的论文Language Models are Few-Shot Learners提供了强有力的支持, 使得自回归模型在这一领域占据了重要地位. 这似乎让传统的少样本研究失去了价值.

然而, 继续研究包括传统模型在内的少样本学习, 特别是少样本文本分类任务, 仍有其必要性. 那么, 元学习等方法相比GPT有哪些优势呢?

这个问题本质上是自回归模型与自编码模型的比较.

我们可以从这几方面进行比较:

  1. 定义
代表 主要任务 结构 常见使用场景
自编码模型 BERT 理解 双向Transformer 文本分类 NER NSP
自回归模型 GPT 生成 单向Transformer 文本生成 翻译
  1. #TODO 待续

5. 未来

对目前的实验设置, 有以下几点不足:

  • 样例顺序问题, GPT可能会察觉其中的规律
  • 先验知识, 是否需要给GPT充足的先验知识
  • 会话问题, 在每轮实验中, 是否需要在单次会话中进行多次询问, 这或许会让GPT更加熟悉这个任务
  • 提示词问题, GPT的表现与提示词有非常大的关系, 未来需要再改善提示词