Few-Shot Experiments with ChatGPT
1. Survey
- Zero-Shot and Few-Shot learning method using ChatGPT on problem sets: [Few-Shot-ChatGPT](https://github.com/Showndarya/Few-Shot-ChatGPT)
- bullet: A Zero-Shot / Few-Shot learning, LLM-based text classification framework
For few-shot text classification, others have already run experiments with the same model [1]; their results are shown in the table below:
Judging from that table, ChatGPT's performance is quite respectable. Unfortunately, the paper does not disclose further details, and its code is not open source.
[1] Li, X. et al.: PDAMeta: Meta-Learning Framework with Progressive Data Augmentation for Few-Shot Text Classification. In: Proceedings of the 2024 Joint International Conference on Computational Linguistics, Language Resources and Evaluation (LREC-COLING 2024), pp. 12668–12678. ELRA and ICCL, Torino, Italia (2024).
2. Test Prompts
```python
system_prompt = """
You are a professional few-shot text classifier. You will be given {n}-way {k}-shot examples and a set of query texts.
Each example corresponds to one of the {n} classes.
Your task is to understand the class category from the support texts, and then predict the correct class label for each query based on the provided examples of each class.
- Each class will be labeled as "Class i: [text k]", where i is the class number and each text is an example that helps you understand class i.
- Each query will be labeled as "Query Text: [text]" and needs to be classified into one of the {n} classes.
- Respond with only the predicted class labels (1, 2, ..., n), in the same order as the queries.
- Do not generate any additional text.
- Now, take a deep breath, focus on the classification task and complete it with high accuracy.
"""
user_prompt = """
Examples:
{Example}
Queries:
{Query}
"""
```
3. Experiment Code
3.1 GPT Classifier
```python
import logging
import os
from typing import List

import openai
import tenacity
import tqdm
from httpx import ReadTimeout
from openai import OpenAI
from pydantic import BaseModel
from sklearn.metrics import accuracy_score, f1_score

logger = logging.getLogger(__name__)


class FSTCClassifier(BaseModel, extra="allow"):
    n_support: int = 5
    n_query: int = 5
    n_class: int = 2
    base_url: str = "https://api.openai.com/v1"
    api_key: str = ""
    context: str = """
    You are a helpful AI assistant, skilled in classifying passages and texts into one of several categories based on examples.
    """
    """Initial prompt context."""
    model: str = "gpt-3.5-turbo"
    MAX_LENGTH: int = 16385  # rough prompt-length guard (gpt-3.5-turbo context size)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fall back to environment variables when no credentials are passed in.
        if self.api_key == "":
            self.api_key = os.environ.get("OPENAI_API_KEY")
        if os.environ.get("OPENAI_API_BASE"):
            self.base_url = os.environ.get("OPENAI_API_BASE")
        # logger.info(f"API Key: {self.api_key}")
        logger.info(f"Base URL: {self.base_url}")
        logger.info(f"Model: {self.model}")
        self.init_results()
    def init_results(self):
        """Reset the accumulated results and counters."""
        self.results = {
            "prediction": [],
            "accuracy": [],
            "f1": [],
            "failed": 0,
            "success": 0,
        }
    @property
    def api(self) -> OpenAI:
        """Instantiates and returns an OpenAI client object."""
        return OpenAI(api_key=self.api_key, base_url=self.base_url, timeout=20)
    def validate_response(self, resp: str):
        """Parse the model's reply into a list of 0-indexed label strings."""
        # The model may separate labels with newlines or commas; try both.
        predictions_1 = resp.strip().split("\n")
        predictions_2 = resp.strip().split(",")
        predictions = predictions_1 if len(predictions_1) > len(predictions_2) else predictions_2
        predictions = [p.strip() for p in predictions if p.strip().isdigit()]  # keep numeric output only
        # The prompt numbers classes from 1; shift to 0-indexed to match the dataset labels.
        predictions = [str(int(p) - 1) for p in predictions]
        expected_count = self.n_class * self.n_query
        if len(predictions) != expected_count:
            logger.debug(f"Mismatch: Got {len(predictions)} predictions, expected {expected_count}")
            return None
        return predictions
    # Automatically retry when the API request fails.
    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(
            exception_types=(
                openai.RateLimitError,
                openai.APIConnectionError,
                ReadTimeout,
            )
        ),
        wait=tenacity.wait_random(min=2, max=10),
        stop=tenacity.stop_after_attempt(5),
        reraise=True,
    )
    def predict(self, examples: dict, queries: dict, temperature: float = 0.1,
                top_p: float = 0.9, max_tokens: int = 100):
        prompt = FSTCPromptRequest(
            n_support=self.n_support, n_query=self.n_query, n_class=self.n_class,
        )
        prompt.format(examples, queries)
        # Rough guard against over-long prompts.
        if len(prompt) > self.MAX_LENGTH:
            self.results["failed"] += 1
            logger.error(f"Prompt length exceeds maximum limit of {self.MAX_LENGTH}")
            return None
        logger.info(f"Prompt: {prompt}")
        response = self.api.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": prompt.get_system_prompt()},
                {"role": "user", "content": prompt.user_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        logger.debug(f"Response: {response}")
        resp = response.choices[0].message.content.strip()
        pred_labels = self.validate_response(resp)
        if pred_labels is None:
            self.results["failed"] += 1
            return None
        logger.debug(f"\nPredicted Labels: {pred_labels}\nExpected Labels: {prompt.labels}")
        acc = accuracy_score(prompt.labels, pred_labels)
        f1 = f1_score(prompt.labels, pred_labels, average="macro")
        self.results["prediction"].append(pred_labels)
        self.results["accuracy"].append(acc)
        self.results["f1"].append(f1)
        self.results["success"] += 1
        logger.debug(f"Accuracy: {acc}, F1: {f1}")
        return acc, f1
    def batch_predict(self, examples: List[dict], queries: List[dict], temperature: float = 0.1,
                      top_p: float = 0.9, max_tokens: int = 100, test_count: int = 1):
        self.init_results()
        for i in range(test_count):
            for example, query in tqdm.tqdm(zip(examples, queries), total=len(examples)):
                result = self.predict(example, query, temperature, top_p, max_tokens)
                if result is None:
                    logger.debug("A failure.")
            logger.debug(f"Test Round {i} Results: {self.calculate_score()}")
        return self.calculate_score()
    def calculate_score(self, results=None):
        """Aggregate per-episode metrics into mean accuracy / F1."""
        if not results:
            results = self.results
        return {
            "accuracy": sum(results["accuracy"]) / len(results["accuracy"]),
            "f1": sum(results["f1"]) / len(results["f1"]),
            "success": results["success"],
            "failed": results["failed"],
        }
```
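The `FSTCPromptRequest` class is not shown in this post. The following is a minimal sketch reconstructed purely from its call sites above (`.format()`, `len()`, `str()`, `.get_system_prompt()`, `.user_prompt`, `.labels`); the field names and rendering details are my assumptions, not the original implementation:

```python
class FSTCPromptRequest(BaseModel, extra="allow"):
    """Hypothetical reconstruction; only the interface used by FSTCClassifier is implemented."""
    n_support: int = 5
    n_query: int = 5
    n_class: int = 2
    user_prompt: str = ""
    labels: List[str] = []

    def format(self, examples: dict, queries: dict):
        # Render support examples and queries into the user prompt template of Section 2.
        example_lines = []
        for label, texts in sorted(examples.items()):
            for text in texts[: self.n_support]:
                example_lines.append(f"Class {label + 1}: [{text}]")
        query_lines, labels = [], []
        for label, texts in sorted(queries.items()):
            for text in texts[: self.n_query]:
                query_lines.append(f"Query Text: [{text}]")
                labels.append(str(label))  # ground-truth labels, 0-indexed
        self.labels = labels
        self.user_prompt = user_prompt.format(
            Example="\n".join(example_lines), Query="\n".join(query_lines)
        )

    def get_system_prompt(self) -> str:
        # Fill the {n}-way / {k}-shot placeholders of the system prompt.
        return system_prompt.replace("{n}", str(self.n_class)).replace("{k}", str(self.n_support))

    def __str__(self):
        return self.get_system_prompt() + self.user_prompt

    def __len__(self):
        # Character count; predict() compares this against MAX_LENGTH as a rough guard.
        return len(str(self))
```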
3.2 Dataset
```python
def load_dataset(n_class, n_support, n_query, dataset, cv, eval=False):
    data_path = "data/{}/few_shot/{}/test.json".format(dataset, cv)
    labels_path = "data/{}/few_shot/{}/labels.test.txt".format(dataset, cv)
    fewshotdataset = FewShotDataset(
        data_path=data_path,
        labels_path=labels_path,
        n_classes=n_class, n_support=n_support, n_query=n_query,
        process_type=2,  # test
        eval=eval,
    )
    return fewshotdataset
def get_examples(fewshotdataset, n_example, n_query):
    examples = {}
    queries = {}
    # Fetch an episode from the dataset
    episode = fewshotdataset.get_episode()
    # Support examples (training examples)
    support = episode['xs']
    # Query examples (test examples)
    query = episode['xq']
    # For each label (first dimension), gather the sentences from the support set
    for label_idx, support_samples in enumerate(support):
        examples[label_idx] = [sample['sentence'] for sample in support_samples[:n_example]]
    # For each label (first dimension), gather the sentences from the query set
    for label_idx, query_samples in enumerate(query):
        queries[label_idx] = [sample['sentence'] for sample in query_samples[:n_query]]
    return examples, queries
```
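Section 4.1 below consumes `e_list` and `q_list`. Here is a minimal sketch of how they could be assembled from the loaders above; the loop itself is my assumption, only `load_dataset` and `get_examples` come from the original code:

```python
# Sample n_test_episodes episodes up front; each episode yields one
# (examples, queries) pair that predict() later classifies in a single request.
n_test_episodes = 100
fewshotdataset = load_dataset(n_class=5, n_support=5, n_query=1,
                              dataset="HuffPost", cv=1)

e_list, q_list = [], []
for _ in range(n_test_episodes):
    e, q = get_examples(fewshotdataset, n_example=5, n_query=1)
    e_list.append(e)
    q_list.append(q)
```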
4. Experiments
4.1 Experiment Settings
```python
n_test_episodes = 100  # number of episodes (API requests) per dataset
n_test_count = 3       # classify the whole episode batch 3 times
temperature = 0.1
n_top = 0.9
max_token = 100
model_name = "gpt-3.5-turbo"

classifier = FSTCClassifier(n_support=n_support, n_query=n_query, n_class=n_class, model=model_name)
# classifier.predict(e, q)
results = classifier.batch_predict(e_list, q_list, temperature, n_top, max_token, n_test_count)
results
```
4.2 Experiment Output
Results on the HuffPost dataset:
```
INFO:root:Base URL: https://aihubmix.com/v1
INFO:root:Model: gpt-3.5-turbo
0it [00:00, ?it/s]
INFO:root:Predicted Labels: ['2', '1', '2', '3', '4']  Expected Labels: ['0', '1', '2', '3', '4']
INFO:root:Accuracy: 0.8, F1: 0.7333333333333333
1it [00:01, 1.50s/it]
INFO:root:Predicted Labels: ['2', '1', '1', '1', '0']  Expected Labels: ['0', '1', '2', '3', '4']
INFO:root:Accuracy: 0.2, F1: 0.1
2it [00:02, 1.46s/it]
INFO:root:Predicted Labels: ['0', '1', '1', '3', '4']  Expected Labels: ['0', '1', '2', '3', '4']
INFO:root:Accuracy: 0.8, F1: 0.7333333333333333
3it [00:04, 1.42s/it]
INFO:root:Predicted Labels: ['2', '1', '3', '0', '4']  Expected Labels: ['0', '1', '2', '3', '4']
INFO:root:Accuracy: 0.4, F1: 0.4
4it [00:05, 1.46s/it]
INFO:root:Predicted Labels: ['0', '1', '2', '3', '4']  Expected Labels: ['0', '1', '2', '3', '4']
INFO:root:Accuracy: 1.0, F1: 1.0
INFO:root:Test Round 0 Results: {'accuracy': 0.64, 'f1': np.float64(0.5933333333333333), 'success': 5, 'failed': 0}
...
{'accuracy': 0.65, 'f1': np.float64(0.6), 'success': 20, 'failed': 0}
```
Unexpectedly, ChatGPT's gpt-3.5 model fails to achieve good performance under this experimental setup and these prompts.
The output also shows that ChatGPT's behavior is unstable. I was unable to reproduce the results of [1], even though this prompt is already the product of repeated revisions based on existing projects.
On top of that, GPT sometimes cannot complete the task at all, producing output that entirely violates the required format: on the 20News dataset, failures reached a startling 121 against 379 successes.
Id | Creation Time | Tags | cv | accuracy | f1 | failed | success | max_token | n_class | n_support | n_query | n_test_count | n_test_episodes | model_name | temperature | top_p | cvs | dataset
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
GPTFSTC-88 | 2024-09-24 16:06:00 | ["SUMMARY"] | | 0.212233643171414 | 0.144973995908919 | 121 | 379 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | 5 | 20News
GPTFSTC-87 | 2024-09-24 16:02:22 | ["CV"] | 5 | 0.201632519527256 | 0.134054008369798 | 30 | 70 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | | 20News
GPTFSTC-86 | 2024-09-24 15:59:07 | ["CV"] | 4 | 0.241230769230769 | 0.167721123321123 | 28 | 72 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | | 20News
GPTFSTC-85 | 2024-09-24 15:55:31 | ["CV"] | 3 | 0.196209150326797 | 0.130757936507937 | 19 | 81 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | | 20News
GPTFSTC-84 | 2024-09-24 15:51:44 | ["CV"] | 2 | 0.209166666666667 | 0.146027777777778 | 16 | 84 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | | 20News
GPTFSTC-83 | 2024-09-24 15:47:55 | ["CV"] | 1 | 0.212929110105581 | 0.146309133567957 | 28 | 72 | 100 | 5 | 5 | 1 | 2 | 50 | gpt-3.5-turbo | 0.1 | 0.9 | | 20News
For comparison, the results with the BERT-based pretrained model are as follows:
Id | Name | cv | Tags | best_acc | Creation Time | lr | early_stop | n_query
---|---|---|---|---|---|---|---|---
CON-2600 | 0.832 / 0.8204 - 0.8231 / 0.8167 / 0.8333 / 0.8226 | 5 | ["D_GROW","PN_ATT","K5C5","MODEL","PN","LB1","CN","F1_KNN","TP","TP_V2"] | 0.852239975750446 | 2024-06-14 09:23:33 | 0.000001 | 3 | 5
CON-2599 | [24-06-14_08:47:19][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] | 4 | ["MODEL","K5C5","F1_KNN","CN","LB1","PN","D_GROW","TP_V2","TP","PN_ATT"] | 0.812279977023602 | 2024-06-14 08:51:08 | 0.000001 | 3 | 5
CON-2598 | [24-06-14_07:51:54][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] | 3 | ["MODEL","TP_V2","TP","PN_ATT","F1_KNN","D_GROW","K5C5","LB1","CN","PN"] | 0.804999976098538 | 2024-06-14 07:55:43 | 0.000001 | 3 | 5
CON-2597 | [24-06-14_07:14:52][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] | 2 | ["MODEL","LB1","F1_KNN","D_GROW","TP","CN","K5C5","PN","PN_ATT","TP_V2"] | 0.895439977049828 | 2024-06-14 07:18:41 | 0.000001 | 3 | 5
CON-2596 | [24-06-14_06:42:29][K5C5][LB:1][PN][MODEL][PN_ATT][D_GROW][TP_V2][F1_KNN] | 1 | ["PN","MODEL","TP_V2","CN","F1_KNN","K5C5","PN_ATT","LB1","D_GROW","TP"] | 0.795159977197647 | 2024-06-14 06:46:17 | 0.000001 | 3 | 5
A few points can be summarized:
- GPT's performance depends heavily on the prompt, and prompt engineering is far from simple.
- GPT's behavior is unstable; using it requires a well-designed error-handling and input-handling pipeline.
- Whether GPT behaves normally also depends strongly on the input content.
This leads to a frequently asked question:
Since OpenAI's GPT-3 appeared, the strong momentum of autoregressive large models seems to have made autoencoding language models such as BERT look somewhat outdated. In text classification tasks in particular, autoregressive models like GPT have shown significant advantages.
In few-shot learning, OpenAI's paper "Language Models are Few-Shot Learners" lent strong support, giving autoregressive models a dominant position in this area. This seems to drain traditional few-shot research of its value.
Nevertheless, there is still a case for continued research on few-shot learning with traditional models, especially few-shot text classification. So what advantages do meta-learning and similar methods have over GPT?
At its core, this question is a comparison between autoregressive and autoencoding models.
We can compare them along the following dimensions:
- Definition (see the sketch after this list)
Type | Representative | Main Task | Architecture | Typical Use Cases
---|---|---|---|---
Autoencoding models | BERT | Understanding | Bidirectional Transformer | Text classification, NER, NSP
Autoregressive models | GPT | Generation | Unidirectional Transformer | Text generation, translation
- #TODO To be continued
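To make the distinction in the definition table concrete, here is a small sketch using Hugging Face transformers (a library not used elsewhere in this post; the model names are just common public checkpoints): BERT fills in a masked token using context on both sides, while GPT-2 can only continue text left to right.

```python
from transformers import pipeline

# Autoencoding (BERT): predict a masked token from bidirectional context.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")
print(fill_mask("Paris is the [MASK] of France.")[0]["token_str"])  # e.g. "capital"

# Autoregressive (GPT-2): generate a continuation from left-to-right context only.
generator = pipeline("text-generation", model="gpt2")
print(generator("Paris is the", max_new_tokens=5)[0]["generated_text"])
```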
5. Future Work
The current experimental setup has the following shortcomings:
- Example ordering: GPT may pick up on regularities in the order of examples and queries (see the shuffling sketch after this list).
- Prior knowledge: should GPT be given richer prior knowledge about the classes?
- Conversation structure: within each round, should multiple queries be issued inside a single conversation? That might make GPT more familiar with the task.
- Prompting: GPT's performance is highly prompt-sensitive, so the prompt needs further improvement.
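For the example-ordering issue, one cheap mitigation is to shuffle the queries before building the prompt and restore the original order afterwards. A minimal sketch; the helper names are hypothetical:

```python
import random

def shuffle_queries(flat_queries: list, seed=None):
    """Shuffle query texts and remember the permutation so it can be undone."""
    order = list(range(len(flat_queries)))
    random.Random(seed).shuffle(order)
    shuffled = [flat_queries[i] for i in order]
    return shuffled, order

def unshuffle_predictions(predictions: list, order: list):
    """Map predictions on the shuffled queries back to the original query order."""
    restored = [None] * len(predictions)
    for shuffled_pos, original_pos in enumerate(order):
        restored[original_pos] = predictions[shuffled_pos]
    return restored

# Usage: texts go out shuffled, labels come back in the original order.
shuffled, order = shuffle_queries(["q1", "q2", "q3"], seed=42)
preds_on_shuffled = ["2", "0", "1"]  # whatever the model returns
print(unshuffle_predictions(preds_on_shuffled, order))
```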