Finetuning a model to identify features of a product
Note: This example requires a GPU for finetuning. If you don't have a machine with GPUs handy, you can use the Colab version below with free T4 GPUs.
We will go through a small example here: generate data with Ollama using Curator, finetune with Unsloth, and then evaluate the finetuned model, again using Curator.
Imagine you are a product wizard at a fictional company called Azanom Inc., and you want to highlight product features in each product's description.
from IPython.display import HTML, display
import re
def display_product(
product_name,
description,
features,
image_url,
):
"""Displays a product give its product_name, features, and description"""
# Product description and features
def highlight_features(text, features):
# Sort features by length in descending order to handle overlapping matches
sorted_features = sorted(features, key=len, reverse=True)
# Create a copy of the text for highlighting
highlighted_text = text
# Replace each feature with its highlighted version
for feature in sorted_features:
pattern = re.compile(re.escape(feature), re.IGNORECASE)
highlighted_text = pattern.sub(
f'<span class="highlight">{feature}</span>',
highlighted_text
)
return highlighted_text
# Create HTML content with CSS styling
html_content = f"""
<style>
.product-container {{
max-width: 800px;
margin: 20px auto;
padding: 30px;
font-family: 'Segoe UI', Arial, sans-serif;
line-height: 1.6;
background: white;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}}
.product-title {{
color: #1d1d1f;
font-size: 28px;
margin-bottom: 20px;
text-align: center;
}}
.product-image {{
width: 100%;
max-width: 600px;
height: auto;
margin: 0 auto 30px;
display: block;
border-radius: 8px;
}}
.product-description {{
color: #333;
font-size: 16px;
margin-bottom: 20px;
}}
.highlight {{
background: linear-gradient(120deg, rgba(37, 99, 235, 0.1) 0%, rgba(37, 99, 235, 0.2) 100%);
border-radius: 4px;
padding: 2px 4px;
transition: background 0.3s ease;
}}
.highlight:hover {{
background: linear-gradient(120deg, rgba(37, 99, 235, 0.2) 0%, rgba(37, 99, 235, 0.3) 100%);
cursor: pointer;
}}
</style>
<div class="product-container">
<h1 class="product-title">{product_name}</h1>
"""
if image_url:
html_content += f'<img class="product-image" src="{image_url}" width="300px" alt="{product_name}">'
html_content += f"""<p class="product-description">
{highlight_features(description, features)}
</p>
</div>"""
display(HTML(html_content))
display_product(
product_name="Apple Airpods Pro",
description="The Apple AirPods Pro are a pair of wireless earbuds that are designed for comfort and convenience. They are lightweight in-ear earbuds and contoured for a comfortable fit, and they sit at an angle for easy access to the controls. The AirPods Pro also have a stem that is 33% shorter than the second generation AirPods, which makes them more compact and easier to store. The AirPods Pro also have a force sensor to easily control music and calls, and they have Spatial Audio with dynamic head tracking, which provides an immersive, three-dimensional listening experience.",
features=[
"lightweight in-ear earbuds",
"contoured design",
"sits at an angle for comfort",
"better direct audio to your ear",
"stem is 33% shorter than the second generation AirPods",
"force sensor to easily control music and calls",
"Spatial Audio with dynamic head tracking",
"immersive, three-dimensional listening experience"],
image_url="https://store.storeimages.cdn-apple.com/4982/as-images.apple.com/is/airpods-pro-2-hero-select-202409_FMT_WHH?wid=750&hei=556&fmt=jpeg&qlt=90&.v=1724041668836")
Given a product and its description, your first instinct is to use GPT-4o to extract the features. But you quickly realize that you don't need a jackhammer to nail this one and want a much cheaper, more scalable alternative.
So let's try to train a 1B model on data generated by an 8B model. This data generation should cost $0, since everything runs locally with Ollama. We can always use bigger models to generate higher-quality data.
Note that we have simplified this example for demonstration purposes.
# Install Python packages
!pip install bespokelabs-curator==0.1.15.post1
!pip install fuzzywuzzy datasets pydantic
!pip install unsloth
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
!pip install bitsandbytes triton unsloth_zoo
# Install ollama
!curl https://www.ollama.com/install.sh | OLLAMA_VERSION="0.5.4" sh
# We need llama3.1:8b for data generation and llama3.2:1b model for finetuning
!ollama pull llama3.1:8b
!ollama pull llama3.2:1b
Import the required libraries
# Make sure the following imports work fine before proceeding, since they are needed for finetuning.
from unsloth import FastLanguageModel
from bespokelabs import curator
import os
import re
import json
import torch
import random
import numpy as np
from typing import List
from fuzzywuzzy import fuzz
from pydantic import BaseModel, Field
from datasets import Dataset, load_dataset
Our goal is to extract features from the product descriptions.
We will use Curator to easily generate a dataset of products and their features. We seed the dataset with personas from PersonaHub and create a product for each persona, which gives us a diverse set of products. To make the data generation process easy for LLMs, we ask the model to wrap each feature in <feature> and </feature> tags in the output description. This way, we get high-quality descriptions for the given features.
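As a quick illustration of why the tags help, recovering the features from a tagged description is a single regex pass; this small standalone snippet (the tagged string is just a made-up example) uses the same pattern the parser below relies on.
import re

# A made-up tagged description, following the format we ask the LLM to produce
tagged = ("They are <feature>lightweight in-ear earbuds</feature> and "
          "<feature>contoured for a comfortable fit</feature>.")

# Pull out the exact feature spans between the tags
features = re.findall(r"<feature>(.*?)</feature>", tagged)
print(features)  # ['lightweight in-ear earbuds', 'contoured for a comfortable fit']

# Drop the tags to recover the clean description
clean_description = tagged.replace("<feature>", "").replace("</feature>", "")
print(clean_description)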
# load the personas dataset
personas = load_dataset("proj-persona/PersonaHub", 'persona')
personas = personas['train'].take(100)
personas[0]
We can then create a ProductCurator object and curate products using the personas.
class ProductCurator(curator.LLM):
# input prompt to the curator
def prompt(self, row):
return f"""Generate a product for the following persona: {row['persona']}
The product should be a product that is relevant to the persona. Give a name, description and features for the product.
Give up to 10 features for the product. Features should be relevant, extremely detailed and useful for the persona.
Then, generate a description for the product. Note that each feature should exactly be mentioned in the description. Do not add any other features or miss any features.
An example output is:
{{
"name": "Apple AirPods Pro",
"features": [
"lightweight in-ear earbuds",
"contoured for a comfortable fit",
"sits at an angle for comfort",
"better direct audio to your ear",
"stem is 33% shorter than the second generation AirPods",
"force sensor to easily control music and calls",
"Spatial Audio with dynamic head tracking",
"immersive, three-dimensional listening experience"
]
"description": "The Apple AirPods Pro are a pair of wireless earbuds that are designed for comfort and convenience. They are <feature>lightweight in-ear earbuds</feature> and <feature>contoured for a comfortable fit</feature>, and each airpod <feature>sits at an angle for comfort</feature>. The AirPods Pro also have a <feature>stem that is 33% shorter than the second generation AirPods</feature>, which makes them more compact and easier to store. The AirPods Pro also have <feature>a force sensor to easily control music and calls</feature>, and they have <feature>Spatial Audio with dynamic head tracking</feature>, which provides an <feature>immersive, three-dimensional listening experience</feature>.",
}}
Ensure each feature in the paragraph matches exactly as written in the description, including the <feature> and </feature> tags.
Make sure your output is a JSON and is in the following format. DO NOT OUTPUT ANYTHING ELSE. INCLUDE THE ```json tag in your response.
```json
{{
"name": "name of the product",
"features": [
"feature 1",
"feature 2",
"feature 3",
...
],
"description": "description of the product"
}}```
"""
def parse(self, row, response):
"""Parse the LLM response to extract the product name, features and description."""
default_response = { "name": "Apple AirPods Pro",
"features": [
"lightweight in-ear earbuds",
"contoured for a comfortable fit",
"sits at an angle for comfort",
"better direct audio to your ear",
"stem is 33% shorter than the second generation AirPods",
"force sensor to easily control music and calls",
"Spatial Audio with dynamic head tracking",
"immersive, three-dimensional listening experience"
],
"description": "The Apple AirPods Pro are a pair of wireless earbuds that are designed for comfort and convenience. They are <feature>lightweight in-ear earbuds</feature> and <feature>contoured for a comfortable fit</feature>, and each airpod <feature>sits at an angle for comfort</feature>. The AirPods Pro also have a <feature>stem that is 33% shorter than the second generation AirPods</feature>, which makes them more compact and easier to store. The AirPods Pro also have <feature>a force sensor to easily control music and calls</feature>, and they have <feature>Spatial Audio with dynamic head tracking</feature>, which provides an <feature>immersive, three-dimensional listening experience</feature>.",
}
if type(response) == type(''):
pattern = r"```json(.*?)```"
match_found = re.findall(pattern, response, re.DOTALL)
if match_found:
json_string = match_found[-1].strip()
try:
response = json.loads(json_string)
except:
response = default_response
else:
response = default_response
else:
response = response.dict()
try:
row['product'] = response['name']
# note that because the LLM isn't perfect, the features in the response may not be fully accurate
# row['original_features'] = response.features
# that's why, we parse the features from the output
pattern = r"<feature>(.*?)</feature>"
matches = re.findall(pattern, response['description'])
if matches:
row['features'] = matches
else:
# backup
row['features'] = response['features']
row['description'] = response['description'].replace('<feature>','').replace('</feature>','')
except:
return []
return row
product_curator = ProductCurator(
model_name="ollama/llama3.1:8b", # Ollama model identifier
backend_params={
"base_url": "http://localhost:11434",
"max_tokens_per_minute": 3000000,
"max_requests_per_minute": 10,
},
)
Next, let's create some products for the personas with ProductCurator. This can take a while; you can use Together.ai or DeepInfra through Curator and LiteLLM to speed this up.
# Generate products for the personas. This will take a while.
# You can use a hosted provider such as Together.ai or DeepInfra, for example, to speed this up.
products = product_curator(personas)
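If local generation is too slow, the same ProductCurator can be pointed at a hosted provider through LiteLLM instead. The sketch below is only an illustration: the together_ai/ model identifier and the provider API-key setup are assumptions to verify against the LiteLLM and provider docs.
# Hypothetical: run the same curator against a hosted Llama-3.1-8B endpoint via LiteLLM.
# Set the provider API key that LiteLLM expects for your provider before running this.
hosted_product_curator = ProductCurator(
    model_name="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # assumed identifier, verify
    backend_params={
        "max_requests_per_minute": 100,
        "max_tokens_per_minute": 300000,
    },
)
# products = hosted_product_curator(personas)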
Here's an example of a generated product:
PERSONA: A Political Analyst specialized in El Salvador's political landscape.
PRODUCT: Salvadoria: El Salvador's Political Landscape Analyzer
DESCRIPTION: Salvadoria is a cutting-edge tool designed specifically for Political Analysts specializing in El Salvador's political landscape. It offers Advanced natural language processing for news articles and social media posts, allowing users to quickly analyze the tone, sentiment, and key themes of online discussions. The customizable keyword alert system enables analysts to track specific topics and hashtags in real-time, ensuring they stay up-to-date on the latest developments. Salvadoria also features an interactive map of El Salvador with election results, demographic data, and key infrastructure information, providing a comprehensive view of the country's political landscape. With access to a comprehensive database of past elections, including voter turnout, candidate performance, and electoral district boundaries, analysts can gain valuable insights into historical trends and patterns. The tool also includes an in-depth analysis of government spending, revenue, and budget allocation by department and agency, allowing users to identify areas of inefficiency or potential corruption. Salvadoria's real-time tracking of public opinion polls, surveys, and focus groups on various political issues keeps analysts informed about shifting public sentiment and policy preferences. Users can customize their dashboard with a range of visualizations and metrics using the customizable dashboard, while also exporting data in CSV format for further analysis or integration with other tools via the ability to export data in CSV format. Regular updates include new data, including special reports on election forecasts, economic indicators, and policy changes, which are integrated seamlessly through the integration with popular spreadsheet software.
FEATURES:
advanced natural language processing for news articles and social media posts
keyword alert system
interactive map of El Salvador with election results, demographic data, and key infrastructure information
comprehensive database of past elections
in-depth analysis of government spending, revenue, and budget allocation by department and agency
real-time tracking of public opinion polls, surveys, and focus groups
customizable dashboard
ability to export data in CSV format
integration with popular spreadsheet software
We can create an EvaluationLLM object to evaluate the performance of our models.
FEATURE_PROMPT = """
You are given a product's name, description and features. You will generate a list of features for the product.
An example input is:
{{
"name": "Apple AirPods Pro",
"description": "The Apple AirPods Pro are a pair of wireless earbuds that are designed for comfort and convenience. They are lightweight in-ear earbuds and contoured for a comfortable fit, and they sit at an angle for easy access to the controls. The AirPods Pro also have a stem that is 33% shorter than the second generation AirPods, which makes them more compact and easier to store. The AirPods Pro also have a force sensor to easily control music and calls, and they have Spatial Audio with dynamic head tracking, which provides an immersive, three-dimensional listening experience.",
}}
An example output is:
{{
"features": [
"lightweight in-ear earbuds",
"contoured for a comfortable fit",
"sit at an angle for easy access to the controls",
"stem is 33% shorter than the second generation AirPods",
"force sensor to easily control music and calls",
"Spatial Audio with dynamic head tracking",
"immersive, three-dimensional listening experience"
]
}}
Now, generate a list of features for the product. You should output all the features that are mentioned in the description exactly as they are written. You should not miss any features, or add any features that are not mentioned in the description.
Your output should be in this format.
```json{{features: ["feature 1","feature 2","feature 3",...]}}```
Product:
Name: {product_name}
Description: {product_description}
Output:
"""
class EvaluationLLM(curator.LLM):
# prompt for evaluation
def prompt(self, row):
return FEATURE_PROMPT.format(product_name=row['product'], product_description=row['description'])
# function to parse the LLM responses given by curator
# this function also contains the logic for evaluation of the LLM responses
def parse(self, row, response):
true_set = set(row['features'])
pred_set = set()
# Fuzzy matching threshold
SIMILARITY_THRESHOLD = 0.85
if type(response) != type(""):
predicted_features = response.features
else:
# string
pattern = r"```json(.*?)```"
match_found = re.findall(pattern, response, re.DOTALL)
if match_found:
json_string = match_found[-1].strip()
try:
out_dict = json.loads(json_string)
predicted_features = out_dict.get("features", [])
except:
print("Incorrect output format..")
predicted_features = []
else:
predicted_features = []
for pred_feature in predicted_features:
# Check if any true feature matches this predicted feature
best_match_score = 0
for true_feature in true_set:
similarity = fuzz.ratio(pred_feature.lower(), true_feature.lower()) / 100.0
best_match_score = max(best_match_score, similarity)
if best_match_score >= SIMILARITY_THRESHOLD:
pred_set.add(pred_feature)
# Calculate metrics
row['true_positives'] = len(pred_set) # Features that matched above threshold
row['false_positives'] = len(predicted_features) - row['true_positives'] # Predicted features that didn't match
row['false_negatives'] = len(true_set) - row['true_positives'] # True features that weren't matched
return row
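To get a feel for how strict the 0.85 fuzzy-matching threshold is, here is a small standalone check (the string pairs are made up for illustration; run it to see the actual scores):
from fuzzywuzzy import fuzz

SIMILARITY_THRESHOLD = 0.85

pairs = [
    # close paraphrase
    ("force sensor to easily control music and calls", "force sensor to control music and calls"),
    # case difference only
    ("Spatial Audio with dynamic head tracking", "spatial audio with dynamic head tracking"),
    # unrelated strings
    ("customizable dashboard", "ability to export data in CSV format"),
]

for predicted, true in pairs:
    score = fuzz.ratio(predicted.lower(), true.lower()) / 100.0
    print(f"{score:.2f}  match={score >= SIMILARITY_THRESHOLD}  {predicted!r} vs {true!r}")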
We also set up some utilities to run the evaluation, calculate precision, recall, and F1 metrics, and tabulate them in a nice format.
# @title Utilities to calculate precision, recall, f1, run evaluations and tabulate results
def calculate_metrics(evaluation):
tp = sum(evaluation['true_positives'])
fp = sum(evaluation['false_positives'])
fn = sum(evaluation['false_negatives'])
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall)
return {'precision': micro_precision,'recall': micro_recall, 'f1':micro_f1}
# Common function to run evaluation on different models
def run_evaluation(model_name, dataset):
evaluator = EvaluationLLM(
model_name=model_name,
backend_params={
"max_requests_per_minute":10000,
"max_tokens_per_minute":30000000
}
)
evaluation = evaluator(dataset)
metrics = calculate_metrics(evaluation)
return evaluation, metrics
# Tabulate eval results
from tabulate import tabulate
def tabulate_eval_results(model_and_metrics):
metrics_names = ['Precision', 'Recall', 'F1']
# Create table data
table_data = []
for model, metrics in model_and_metrics.items():
table_data.append([
model,
f"{metrics['precision']:.3f}",
f"{metrics['recall']:.3f}",
f"{metrics['f1']:.3f}"
])
# Print table
print(tabulate(table_data,
headers=['Model', 'Precision', 'Recall', 'F1'],
tablefmt='grid'))
To prevent bias from using the same model to generate both training and eval data, we won't split the newly created Llama-3.1-8B data into train and test sets. Instead, we will use all of it for training and generate the eval data with a completely different LLM, gpt-4o-mini.
import getpass
os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")
eval_product_curator = ProductCurator(
model_name="gpt-4o-mini"
)
train_dataset = products
test_dataset = eval_product_curator(personas.take(40))
evaluation_results_1b, metrics_1b = run_evaluation('ollama/llama3.2:1b', test_dataset)
evaluation_results_8b, metrics_8b = run_evaluation('ollama/llama3.1:8b', test_dataset)
tabulate_eval_results(model_and_metrics={"llama-3.2-1b": metrics_1b, "llama-3.1-8b": metrics_8b})
+--------------+-------------+----------+-------+
| Model | Precision | Recall | F1 |
+==============+=============+==========+=======+
| llama-3.2-1b | 0.668 | 0.394 | 0.496 |
+--------------+-------------+----------+-------+
| llama-3.1-8b | 0.726 | 0.753 | 0.739 |
+--------------+-------------+----------+-------+
We can see that Llama-3.2-1B is not able to extract the features as well as Llama-3.1-8B (as expected). So, we will finetune the 1B model on the training set.
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
max_seq_length = 1024
load_in_4bit = True
dtype = None # for auto
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
peft_model = FastLanguageModel.get_peft_model(
model,
r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
# Apply the chat template so that the Ollama Modelfile generated later uses the right prompt format
tokenizer = get_chat_template(
tokenizer,
chat_template = "llama-3.1",
)
# Prepare a dataset for finetuning
def formatting_prompts_func(row):
texts = []
features = {'features': row['features']}
messages = [
{"role": "user", "content": FEATURE_PROMPT.format(product_name=row['product'], product_description=row['description'])},
{"role": "assistant", "content": f"```json{json.dumps(features)}```"},
]
text = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = False)
return {'text': text}
ft_dataset = train_dataset.map(formatting_prompts_func, batched = False,)
ft_dataset[0]
{'persona': "A Political Analyst specialized in El Salvador's political landscape.",
'product': 'SalvadorAlert',
'features': ["real-time updates on El Salvador's legislative calendar",
'in-depth analysis of proposed laws and bills',
'topics that matter most to them',
'data visualization of voting patterns and trends',
'comparative analysis of past and present legislative data',
'upcoming hearings and committee meetings',
"detailed information on El Salvador's presidential and congressional elections",
'analysis of public opinion polls and surveys',
'news from local and international sources',
'monitors social media activity of key politicians and influencers'],
'description': "The SalvadorAlert is a cutting-edge tool for political analysts specializing in El Salvador's political landscape. It provides real-time updates on El Salvador's legislative calendar, including key dates and events. The platform offers in-depth analysis of proposed laws and bills, allowing users to stay on top of the latest developments. Users can also customize their alert system to receive priority notifications on topics that matter most to them. The SalvadorAlert features data visualization of voting patterns and trends, providing a clear picture of the current political climate. Additionally, users can access comparative analysis of past and present legislative data to inform their research. The platform also includes alerts on upcoming hearings and committee meetings, ensuring users are always informed. SalvadorAlert provides detailed information on El Salvador's presidential and congressional elections, including key statistics and analysis. Users can also access analysis of public opinion polls and surveys to gauge public sentiment. Furthermore, the SalvadorAlert aggregates news from local and international sources related to El Salvador's politics, providing a comprehensive view of the news cycle. Finally, the platform monitors social media activity of key politicians and influencers, allowing users to stay ahead of the curve.",
'text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n\n You are given a product\'s name, description and features. You will generate a list of features for the product.\n\n An example input is:\n {\n "name": "Apple AirPods Pro",\n "description": "The Apple AirPods Pro are a pair of wireless earbuds that are designed for comfort and convenience. They are lightweight in-ear earbuds and contoured for a comfortable fit, and they sit at an angle for easy access to the controls. The AirPods Pro also have a stem that is 33% shorter than the second generation AirPods, which makes them more compact and easier to store. The AirPods Pro also have a force sensor to easily control music and calls, and they have Spatial Audio with dynamic head tracking, which provides an immersive, three-dimensional listening experience.",\n }\n\n An example output is:\n {\n "features": [\n "lightweight in-ear earbuds",\n "contoured for a comfortable fit",\n "sit at an angle for easy access to the controls",\n "stem is 33% shorter than the second generation AirPods",\n "force sensor to easily control music and calls",\n "Spatial Audio with dynamic head tracking",\n "immersive, three-dimensional listening experience"\n ]\n\n }\n\n Now, generate a list of features for the product. You should output all the features that are mentioned in the description exactly as they are written. You should not miss any features, or add any features that are not mentioned in the description.\n\n Your output should be in this format.\n ```json{features: ["feature 1","feature 2","feature 3",...]}```\n\n Product:\n Name: SalvadorAlert\n Description: The SalvadorAlert is a cutting-edge tool for political analysts specializing in El Salvador\'s political landscape. It provides real-time updates on El Salvador\'s legislative calendar, including key dates and events. The platform offers in-depth analysis of proposed laws and bills, allowing users to stay on top of the latest developments. Users can also customize their alert system to receive priority notifications on topics that matter most to them. The SalvadorAlert features data visualization of voting patterns and trends, providing a clear picture of the current political climate. Additionally, users can access comparative analysis of past and present legislative data to inform their research. The platform also includes alerts on upcoming hearings and committee meetings, ensuring users are always informed. SalvadorAlert provides detailed information on El Salvador\'s presidential and congressional elections, including key statistics and analysis. Users can also access analysis of public opinion polls and surveys to gauge public sentiment. Furthermore, the SalvadorAlert aggregates news from local and international sources related to El Salvador\'s politics, providing a comprehensive view of the news cycle. 
Finally, the platform monitors social media activity of key politicians and influencers, allowing users to stay ahead of the curve.\n Output:\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```json{"features": ["real-time updates on El Salvador\'s legislative calendar", "in-depth analysis of proposed laws and bills", "topics that matter most to them", "data visualization of voting patterns and trends", "comparative analysis of past and present legislative data", "upcoming hearings and committee meetings", "detailed information on El Salvador\'s presidential and congressional elections", "analysis of public opinion polls and surveys", "news from local and international sources", "monitors social media activity of key politicians and influencers"]}```<|eot_id|>'}
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
from unsloth.chat_templates import train_on_responses_only
trainer = SFTTrainer(
model = peft_model,
tokenizer = tokenizer,
train_dataset = ft_dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
dataset_num_proc = 2,
packing = False, # Can make training 5x faster for short sequences.
args = TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
num_train_epochs = 1, # Set this for 1 full training run.
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none", # Use this for WandB etc
),
)
trainer = train_on_responses_only(
trainer,
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
trainer_stats = trainer.train()
# save quantized model
peft_model.save_pretrained_gguf("llama_finetune", tokenizer)
# Unsloth automatically creates a Modelfile!
!cat /content/llama_finetune/Modelfile
# Create an ollama model using the saved Modelfile
!ollama create llama_finetune -f ./llama_finetune/Modelfile
# Verify that the finetuned model exists
!ollama ls
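Before running the full evaluation, you can smoke-test the new model directly from the shell; the prompt below is arbitrary and just checks that the model loads and responds.
# Quick sanity check of the finetuned model (one-off, non-interactive generation)
!ollama run llama_finetune "List the features mentioned in: lightweight in-ear earbuds that sit at an angle for comfort."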
Running the evaluation on the new finetuned model, we find that the F1 score for the Llama-3.2-1B model jumps from 0.496 to 0.688, a significant improvement!
# evaluate the finetuned model
finetuned_llama_evaluation_results, finetuned_llama_metrics = run_evaluation('ollama/llama_finetune', test_dataset)
print(finetuned_llama_evaluation_results[0])
print(finetuned_llama_metrics)
tabulate_eval_results(model_and_metrics={"llama-3.2-1b": metrics_1b, "llama-3.1-8b": metrics_8b, "finetuned-llama-3.2-1b": finetuned_llama_metrics})
+------------------------+-------------+----------+-------+
| Model | Precision | Recall | F1 |
+========================+=============+==========+=======+
| llama-3.2-1b | 0.668 | 0.394 | 0.496 |
+------------------------+-------------+----------+-------+
| llama-3.1-8b | 0.726 | 0.753 | 0.739 |
+------------------------+-------------+----------+-------+
| finetuned-llama-3.2-1b | 0.796 | 0.606 | 0.688 |
+------------------------+-------------+----------+-------+
Just for fun, let's try running our new finetuned model on a new example:
product_name = "Ryobi Circular Saw"
product_description = """ Expand your RYOBI 18V ONE+ System with the RYOBI 18V ONE+ Cordless Circular Saw. Make over 215 fast, clean cuts per charge on the ONE+ Cordless 5 1/2 in. Circular Saw with 4,700 RPM and the included 18T Carbide Tipped Blade. This saw is ideal for cross cuts in 2-by material with 1-11/16 in. maximum depth of cut. Bevel up to 50 degrees to complete a wide variety of cuts and with 1-3/16 in. depth of cut at 45 Degrees of bevel. Purchase the accessory vacuum dust adaptor (sold separately) to connect this saw to your wet/dry vac for quick and easy clean up. Best of all, it is part of the RYOBI ONE+ System of over 300 Cordless Products that all work on the same battery platform. This 18V ONE+ Cordless 5-1/2 in. Circular Saw is backed by the RYOBI 3-Year Manufacturer's Warranty. Battery and charger sold separately."""
class Extractor(curator.LLM):
def prompt(self, input):
return FEATURE_PROMPT.format(product_name=input['product'], product_description=input['description'])
extractor = Extractor(model_name='ollama/llama_finetune')
result = extractor(Dataset.from_list([{'product': product_name, 'description': product_description}]))['response'][0]
print(result)
def get_parsed_response(response):
if type(response) == type(''):
pattern = r"```json(.*?)```"
match_found = re.findall(pattern, response, re.DOTALL)
if match_found:
json_string = match_found[-1].strip()
try:
response = json.loads(json_string)
except:
raise ValueError("Failed to parse")
else:
raise ValueError("Failed to parse")
else:
response = response.dict()
return response
response = get_parsed_response(result)
display_product(
product_name,
product_description,
response['features'],
image_url=''
)
This is not bad for a quick start. In some cases, you will see that the LLM doesn't output the exact text from the description (which happens even for GPT-4o)!
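When a predicted feature isn't an exact substring of the description, the highlighting above silently skips it. One possible workaround (a sketch, not part of the original notebook) is to fuzzily align each predicted feature to the closest span of the description before highlighting; the align_feature_to_text helper below is a hypothetical name and reuses fuzzywuzzy, which we already imported.
def align_feature_to_text(feature, text, threshold=85):
    """Return the substring of `text` that best matches `feature`, or the feature unchanged."""
    words = text.split()
    n = len(feature.split())
    best_span, best_score = feature, 0
    # Compare the feature against every window of roughly the same word length
    for size in (max(n - 1, 1), n, n + 1):
        for i in range(len(words) - size + 1):
            span = " ".join(words[i:i + size])
            score = fuzz.ratio(feature.lower(), span.lower())
            if score > best_score:
                best_span, best_score = span, score
    return best_span if best_score >= threshold else feature

aligned_features = [align_feature_to_text(f, product_description) for f in response['features']]
display_product(product_name, product_description, aligned_features, image_url='')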
Great next steps:
Increase the number of training examples.
Systematically evaluate the error types.
Run this on your local machine and use curator-viewer to visualize your data.
Create more complex strategies for data curation involving multiple curator.LLM stages (see the sketch after this list).
Star https://github.com/bespokelabsai/curator/!
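For the multi-stage idea above, here is one possible shape: a second curator.LLM stage that double-checks the features produced by ProductCurator. This is only a sketch; the FeatureVerifier class, the verified_features column, and the assumption that the verifier's raw response is a string containing a JSON list are ours, not part of the original notebook.
class FeatureVerifier(curator.LLM):
    """Hypothetical second stage: keep only the candidate features that really appear in the description."""

    def prompt(self, row):
        return (
            f"Description: {row['description']}\n"
            f"Candidate features: {json.dumps(row['features'])}\n"
            "Return ONLY a JSON list containing the candidates that are actually mentioned "
            'in the description, e.g. ["feature 1", "feature 2"].'
        )

    def parse(self, row, response):
        text = response if isinstance(response, str) else str(response)
        match = re.search(r"\[.*\]", text, re.DOTALL)
        try:
            verified = json.loads(match.group(0)) if match else row['features']
        except Exception:
            verified = row['features']  # fall back to the unverified features
        row['verified_features'] = verified
        return row

verifier = FeatureVerifier(
    model_name="ollama/llama3.1:8b",
    backend_params={"max_requests_per_minute": 10, "max_tokens_per_minute": 3000000},
)
# verified_products = verifier(products)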