Using Gemini for batch inference

You can use Gemini batch inference in Curator to generate synthetic data. This example reannotates the WildChat dataset, but the same approach can be adapted to any data generation task.

Prerequisites

  • Python 3.10+

  • Curator: Install via pip install bespokelabs-curator

  • Gemini (Vertex AI): GCP account with Vertex AI enabled.

  • Google Cloud Bucket: Access to cloud storage.

Steps

1. Set up environment variables

export GOOGLE_CLOUD_REGION=us-central1
export GOOGLE_CLOUD_PROJECT=<projectname>
export GEMINI_BUCKET_NAME=<bucketname>
export GEMINI_API_KEY=<your_api_key>
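
Before starting a batch job, it can help to confirm that these variables are visible to Python. The following is a minimal sketch, assuming all four variables exported above are needed; the helper name missing_env_vars is illustrative and not part of Curator.

```python
import os

# Variable names taken from the export commands above.
REQUIRED_VARS = (
    "GOOGLE_CLOUD_REGION",
    "GOOGLE_CLOUD_PROJECT",
    "GEMINI_BUCKET_NAME",
    "GEMINI_API_KEY",
)


def missing_env_vars(env=None):
    """Return the required variable names that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    if missing:
        print("Missing environment variables:", ", ".join(missing))
    else:
        print("All required environment variables are set.")
```

Running this in the same shell where you exported the variables should report nothing missing.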

2. Authenticate with Application Default Credentials (ADC)

gcloud auth application-default login

3. Create a curator.LLM subclass

Create a class that inherits from curator.LLM. Implement two key methods:

  • prompt(): Generates the prompt for the LLM.

  • parse(): Processes the LLM's response into your desired format.

Here’s the implementation:

"""Example of reannotating the WildChat dataset using curator."""

import logging
from bespokelabs import curator

# To see more detail about how batches are being processed
logger = logging.getLogger("bespokelabs.curator")
logger.setLevel(logging.INFO)


class WildChatReannotator(curator.LLM):
    """A reannotator for the WildChat dataset."""

    def prompt(self, input: dict) -> str:
        """Extract the first message from a conversation to use as the prompt."""
        return input["conversation"][0]["content"]

    def parse(self, input: dict, response: str) -> dict:
        """Parse the model response along with the input to the model into the desired output format."""
        instruction = input["conversation"][0]["content"]
        return {"instruction": instruction, "new_response": response}
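
To see what these two methods do to a single row without calling the model, the same logic can be run as plain functions. This is only a sketch: build_prompt and build_output are hypothetical stand-ins that mirror prompt() and parse(), and the sample row is made up to resemble a WildChat record.

```python
def build_prompt(row: dict) -> str:
    """Same logic as WildChatReannotator.prompt: use the first message."""
    return row["conversation"][0]["content"]


def build_output(row: dict, response: str) -> dict:
    """Same logic as WildChatReannotator.parse: pair the instruction with the new response."""
    return {
        "instruction": row["conversation"][0]["content"],
        "new_response": response,
    }


# A made-up row shaped like a WildChat record (only the fields used here).
sample_row = {
    "conversation": [
        {"role": "user", "content": "what are you?"},
        {"role": "assistant", "content": "An older answer."},
    ]
}

prompt_text = build_prompt(sample_row)
result = build_output(sample_row, "A placeholder model response.")
print(prompt_text)
print(result)
```

Only the first message of each conversation is sent to the model; the original assistant reply is discarded and replaced by the new response.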

4. Configure the Gemini Backend

distiller = WildChatReannotator(
    model_name="gemini-1.5-flash-002",
    backend="gemini",
    batch=True,
)

5. Generate Data

Load a sample of the dataset, run the reannotator over it, and print the results:

from datasets import load_dataset
dataset = load_dataset("allenai/WildChat", split="train")
dataset = dataset.select(range(100))

distilled_dataset = distiller(dataset)
print(distilled_dataset)
print(distilled_dataset[0])
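
Responses can be long, so a truncated preview is often easier to read than the raw rows. The helper below is a sketch, not part of Curator: preview is a hypothetical name, and it assumes each row is a dict with instruction and new_response keys, as returned by parse() above. The demo uses stand-in rows rather than real model output.

```python
def preview(rows, columns=("instruction", "new_response"), width=40, limit=5):
    """Return up to `limit` rows with each selected field cut to `width` characters."""
    out = []
    for i, row in enumerate(rows):
        if i >= limit:
            break
        short = {}
        for col in columns:
            # Show newlines as a literal \n so each field stays on one line.
            text = str(row[col]).replace("\n", "\\n")
            short[col] = text if len(text) <= width else text[: width - 3] + "..."
        out.append(short)
    return out


# Stand-in rows for illustration only.
demo_rows = [
    {"instruction": "what are you?", "new_response": "A placeholder response.\nSecond line."},
    {"instruction": "x" * 100, "new_response": "short"},
]

for row in preview(demo_rows):
    print(row)
```

If the object returned by the distiller can be iterated as rows of dicts, it could be passed to preview in place of demo_rows.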

Example Output

Using the above example, the output might look like this:

instruction                                        new_response
Write a very long, elaborate, descriptive and ...  Scene: Omelette Apocalypse\n\n**INT. DINER...
what are you?                                      I am a large language model, trained by Google
