batch_model_query

async batch_model_query(*, prompt_info: list[dict], model: str, process_func: callable = None, process_func_params: dict = None, batch_size: int = 10, max_concurrent_requests: int = 5, rate_limit: tuple = (2, 1), results_path: str | Path = None, run_name: str | Path = None, model_params: dict = None, token: str = None, host: str = None, **kwargs) → list[dict]

Asynchronously process a batch of prompts with a language model client, running API requests concurrently, saving intermediate and final results, and recording errors and metadata for each prompt.

This function orchestrates the batch processing pipeline for querying language models. It divides the input prompts into batches, limits concurrency and request rate, applies optional post-processing, and, when results_path is provided, saves results in both pickle and parquet formats. Intermediate results are saved after each batch for fault tolerance, and final results are consolidated at the end.
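The batching, concurrency limit, and per-prompt error capture described above can be sketched with plain asyncio. This is a minimal illustration under stated assumptions, not the library's implementation: `fake_query` is a hypothetical stand-in for the real model call, `run_batches` is a hypothetical name, and `on_batch_done` only marks where intermediate results would be saved. Rate limiting is left out to keep the sketch short.

```python
import asyncio


async def fake_query(prompt: dict) -> dict:
    """Hypothetical stand-in for one model call; echoes the user prompt."""
    await asyncio.sleep(0.01)  # simulate network latency
    if prompt.get("fail"):
        raise RuntimeError("simulated API error")
    return {**prompt, "message": f"echo: {prompt['user']}", "error": None}


async def run_batches(prompt_info, batch_size=10, max_concurrent_requests=5, on_batch_done=None):
    """Process prompts batch by batch, capping in-flight requests with a semaphore."""
    semaphore = asyncio.Semaphore(max_concurrent_requests)  # shared by all batches

    async def guarded(prompt):
        async with semaphore:
            try:
                return await fake_query(prompt)
            except Exception as exc:  # record the failure instead of aborting the batch
                return {**prompt, "message": None, "error": str(exc)}

    results = []
    for start in range(0, len(prompt_info), batch_size):
        batch = prompt_info[start:start + batch_size]
        # gather() returns results in input order, so they line up with prompt_info
        results.extend(await asyncio.gather(*(guarded(p) for p in batch)))
        if on_batch_done is not None:
            on_batch_done(results)  # checkpoint hook, e.g. write intermediate files
    return results


prompts = [{"system": "sys", "user": f"q{i}", "id": i} for i in range(23)]
prompts[5]["fail"] = True  # one prompt that will raise
checkpoints = []
results = asyncio.run(
    run_batches(prompts, batch_size=10, max_concurrent_requests=3,
                on_batch_done=lambda r: checkpoints.append(len(r)))
)
```

Sharing one semaphore across batches keeps at most max_concurrent_requests calls in flight, while the batch loop provides natural checkpoints (after 10, 20, and 23 prompts here) for saving partial results.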

Parameters:
prompt_info : list of dict

List of dictionaries containing prompt information, each with system and user keys. Any additional keys (for example, an id) are copied into the corresponding result.

model : str

The name or identifier of the language model to use.

process_func : callable, optional

Function applied to each raw response; its output is stored in the processed_response field. If None, no post-processing is applied and only the raw content in the message field is available. Default is None.

process_func_params : dict, optional

Parameters to pass to process_func. Ignored if process_func is None. Default is None.

batch_size : int, optional

Number of prompts processed per batch; intermediate results are saved after each batch. Default is 10.

max_concurrent_requests : int, optional

Maximum number of concurrent API requests allowed. Default is 5.

rate_limit : tuple, optional

Tuple of (max_requests, interval_seconds) specifying the maximum number of requests allowed per interval. Default is (2, 1), meaning 2 requests per 1 second. This is enforced using a rate limiter to avoid exceeding API provider quotas or triggering throttling.

results_path : str or Path, optional

Directory in which intermediate and final result files are saved. If None, results are not saved. Default is None.

run_name : str or Path, optional

Name used to identify this batch run in saved files. Required if results_path is provided. Default is None.

model_params : dict, optional

Dictionary of model parameters that override the defaults. Supported keys:

  • max_tokens: Maximum number of tokens in the completion (default: 2048).

  • temperature: Sampling temperature (default: 0).

Default is None.

token : str, optional

API token for authentication. If not provided, it is loaded from environment variables, an env dotfile, or the Databricks context.

host : str, optional

API host URL. If not provided, it is loaded from environment variables, an env dotfile, or the Databricks context.
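The rate_limit tuple (max_requests, interval_seconds) describes a window-based limit: at most max_requests request starts in any interval_seconds span. One common way to enforce such a limit is a sliding-window limiter, sketched below with asyncio. `SlidingWindowRateLimiter` and `demo` are hypothetical names for illustration; this page does not say which limiter the function actually uses.

```python
import asyncio
import collections
import time


class SlidingWindowRateLimiter:
    """Allow at most `max_requests` acquisitions in any `interval`-second window."""

    def __init__(self, max_requests: int, interval: float):
        self.max_requests = max_requests
        self.interval = interval
        self._stamps = collections.deque()  # monotonic times of recent acquisitions
        self._lock = asyncio.Lock()  # waiters acquire one at a time, in order

    async def acquire(self) -> float:
        async with self._lock:
            while True:
                now = time.monotonic()
                # Forget acquisitions that have slid out of the window.
                while self._stamps and now - self._stamps[0] >= self.interval:
                    self._stamps.popleft()
                if len(self._stamps) < self.max_requests:
                    self._stamps.append(now)
                    return now
                # Wait until the oldest acquisition leaves the window, then re-check.
                await asyncio.sleep(self.interval - (now - self._stamps[0]))


async def demo(rate_limit=(2, 0.2), n_requests=6):
    limiter = SlidingWindowRateLimiter(*rate_limit)
    # Each request would be sent right after its acquire() returns.
    return await asyncio.gather(*(limiter.acquire() for _ in range(n_requests)))


stamps = sorted(asyncio.run(demo()))
```

With rate_limit=(2, 0.2), any two acquisitions that are two apart in time order are at least 0.2 seconds apart. Combined with a semaphore for max_concurrent_requests, a limiter like this bounds both how many requests are in flight and how fast new ones start.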

Returns:
list of dict

A list of dictionaries, each containing the response and associated metadata for a prompt. Each dictionary includes:

  • message: Raw response content from the model.

  • processed_response: Processed content if process_func is provided.

  • chat: Full API response object (or None on error).

  • error: Error message if an exception occurred, None otherwise.

  • model: Model name used for generation.

  • temperature: Temperature setting used for generation.

  • max_tokens: Maximum tokens setting used.

  • prompt_tokens: Number of tokens in the prompt (if available).

  • completion_tokens: Number of tokens in the completion (if available).

  • total_tokens: Total number of tokens used (if available).

  • timing: Query execution time in seconds.

  • All original keys from the corresponding entry in prompt_info.
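Because every result carries the fields listed above, simple bookkeeping over a run needs only the standard library. The helper below is an illustrative sketch (`summarize_results` is a hypothetical name, and the sample records are made up); it treats missing token counts as zero because those fields are only populated when available.

```python
def summarize_results(results: list[dict]) -> dict:
    """Count failures and aggregate token usage and timing across results."""
    failed = [r for r in results if r.get("error")]
    total_tokens = sum(r.get("total_tokens") or 0 for r in results)
    total_time = sum(r.get("timing") or 0.0 for r in results)
    return {
        "n_results": len(results),
        "n_failed": len(failed),
        "failed_ids": [r.get("id") for r in failed],  # "id" exists only if you added it
        "total_tokens": total_tokens,
        "mean_timing": total_time / len(results) if results else 0.0,
    }


# Illustrative records shaped like the documented return value (values are made up).
sample = [
    {"id": "query_1", "error": None, "total_tokens": 120, "timing": 1.5},
    {"id": "query_2", "error": "Timeout", "total_tokens": None, "timing": 0.5},
]
summary = summarize_results(sample)
```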

Note

  • When process_func is None, the function returns the raw message content in the message field.

  • After each batch, results (including chat objects) are saved as pickle files, and a version without the chat and message keys is saved as a parquet file.

  • Intermediate results are saved in a subdirectory named after run_name; final results are saved in results_path.

  • Intermediate files are deleted after successful completion.
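The checkpoint-and-cleanup behavior described in these notes can be sketched with the standard library. This is an illustration under assumptions, not the library's code: the helper names and file names (`batch_000.pkl`, `<run_name>.pkl`) are hypothetical, and the parquet step is represented only by building the rows that would be written, since writing parquet from Python typically needs pandas plus pyarrow or fastparquet.

```python
import pickle
import shutil
import tempfile
from pathlib import Path

# Keys the notes say are left out of the parquet copy.
TABULAR_EXCLUDED = ("chat", "message")


def tabular_rows(results: list[dict]) -> list[dict]:
    """Return copies of the results without the keys excluded from the parquet file."""
    return [{k: v for k, v in r.items() if k not in TABULAR_EXCLUDED} for r in results]


def save_checkpoint(results: list[dict], results_path, run_name, batch_index: int) -> Path:
    """Pickle all results so far into results_path/run_name/ (hypothetical file name)."""
    run_dir = Path(results_path) / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    path = run_dir / f"batch_{batch_index:03d}.pkl"
    with open(path, "wb") as fh:
        pickle.dump(results, fh)
    # A parquet copy built from tabular_rows(results) would be written alongside it.
    return path


def finalize(results: list[dict], results_path, run_name) -> Path:
    """Write the final pickle into results_path, then delete the intermediate directory."""
    final_path = Path(results_path) / f"{run_name}.pkl"
    with open(final_path, "wb") as fh:
        pickle.dump(results, fh)
    shutil.rmtree(Path(results_path) / run_name)  # intermediates are no longer needed
    return final_path


root = tempfile.mkdtemp()
demo_results = [{"id": "query_1", "message": "hi", "chat": {"choices": []}, "error": None}]
save_checkpoint(demo_results, root, "my_batch_run", batch_index=0)
final_path = finalize(demo_results, root, "my_batch_run")
```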

Warning

  • Ensure that token and host are set, either via arguments, environment variables, an env dotfile, or the Databricks context.

  • If results_path is provided, run_name must also be specified.
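Both warnings amount to checks you can run before launching a large job. The function below is a hypothetical sketch of them: the environment variable names `DATABRICKS_TOKEN` and `DATABRICKS_HOST` are assumptions (they are common Databricks conventions, but this page does not name the variables the function reads), and the env-dotfile and Databricks-context fallbacks are omitted.

```python
import os


def resolve_config(token=None, host=None, results_path=None, run_name=None):
    """Resolve credentials (argument first, then environment) and validate save settings."""
    token = token or os.environ.get("DATABRICKS_TOKEN")  # assumed variable name
    host = host or os.environ.get("DATABRICKS_HOST")  # assumed variable name
    if not token or not host:
        raise ValueError("token and host must be provided or set in the environment")
    if results_path is not None and not run_name:
        raise ValueError("run_name is required when results_path is provided")
    return token, host
```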

Examples

The batch_model_query function allows you to send multiple prompts to a specified model asynchronously.

Running in a Notebook

from dbutils_batch_query.model_query import batch_model_query, extract_json_items
from dbutils_batch_query.prompts import load_prompt

user_prompt = load_prompt(
    "path/to/your/prompt_template.md", input_text="Some text to analyze."
)
# Example prompt information (replace with your actual prompts)
prompt_info = [
    {
        "system": "You are an assistant that extracts key information. Respond in a JSON codeblock.",
        "user": user_prompt,
        "id": "query_1",  # Optional: Add identifiers or other metadata
    },
    {
        "system": "You are an assistant that summarizes text. Respond in a JSON codeblock.",
        "user": load_prompt(
            "path/to/your/summary_template.md",
            document="Another document to summarize.",
        ),
        "id": "query_2",
    },
]

results = await batch_model_query(
    prompt_info=prompt_info,
    model="databricks-llama-4-maverick",  # Specify your Databricks model endpoint
    process_func=extract_json_items,  # Optional: function to process raw text response
    batch_size=5,  # Optional: prompts per batch (results are saved after each batch)
    max_concurrent_requests=3,  # Optional: max concurrent requests
    rate_limit=(2, 1),  # Optional: at most 2 requests per 1-second interval
    results_path="output_results/",  # Optional: path to save results
    run_name="my_batch_run",  # Optional: identifier for the run
    # token and host are automatically fetched from environment or dbutils if not provided
)

# Process results
for result in results:
    if result["error"]:
        print(f"Error processing prompt {result.get('id', 'N/A')}: {result['error']}")
    else:
        print(f"Result for prompt {result.get('id', 'N/A')}:")
        # Access raw message or processed response
        # print(result["message"])
        print(result["processed_response"])
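Instead of `extract_json_items`, you could pass your own function as process_func. The sketch below assumes process_func receives the raw message text as its first argument and the entries of process_func_params as keyword arguments; that calling convention is not stated on this page, so check it against the source before relying on it. `parse_json_block` and its `default` parameter are hypothetical.

```python
import json
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the pattern readable
# Captures the body of the first fenced block, optionally tagged "json".
FENCE_RE = re.compile(FENCE + r"(?:json)?\s*(.*?)" + FENCE, re.DOTALL)


def parse_json_block(message, default=None):
    """Parse the first fenced JSON block in a reply (or the whole reply); else return default."""
    match = FENCE_RE.search(message or "")
    payload = match.group(1) if match else message
    try:
        return json.loads(payload)
    except (TypeError, json.JSONDecodeError):
        return default


# Under the assumed convention, it would be passed as:
#   process_func=parse_json_block, process_func_params={"default": {}}
```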

Running in a Python File

import asyncio
from dbutils_batch_query.model_query import (
    batch_model_query,
    extract_json_items
)
from dbutils_batch_query.prompts import load_prompt

user_prompt = load_prompt(
    "path/to/your/prompt_template.md", input_text="Some text to analyze."
)
# Example prompt information (replace with your actual prompts)
prompt_info = [
    {
        "system": "You are an assistant that extracts key information. Respond in a JSON codeblock.",
        "user": user_prompt,
        "id": "query_1",  # Optional: Add identifiers or other metadata
    },
    {
        "system": "You are an assistant that summarizes text. Respond in a JSON codeblock.",
        "user": load_prompt(
            "path/to/your/summary_template.md",
            document="Another document to summarize.",
        ),
        "id": "query_2",
    },
]

results = asyncio.run(
    batch_model_query(
        prompt_info=prompt_info,
        model="databricks-llama-4-maverick",  # Specify your Databricks model endpoint
        process_func=extract_json_items,  # Optional: function to process raw text response
        batch_size=5,  # Optional: prompts per batch (results are saved after each batch)
        max_concurrent_requests=3,  # Optional: max concurrent requests
        rate_limit=(2, 1),  # Optional: at most 2 requests per 1-second interval
        results_path="output_results/",  # Optional: path to save results
        run_name="my_batch_run",  # Optional: identifier for the run
        # token and host are automatically fetched from environment or dbutils if not provided
    )
)

# Process results
for result in results:
    if result["error"]:
        print(f"Error processing prompt {result.get('id', 'N/A')}: {result['error']}")
    else:
        print(f"Result for prompt {result.get('id', 'N/A')}:")
        # Access raw message or processed response
        # print(result["message"])
        print(result["processed_response"])
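Because each result keeps all original keys from its prompt_info entry, failed prompts can be collected and resubmitted. The helper below is a hypothetical sketch: it rebuilds prompt entries from failed results by dropping the response and metadata fields listed under Returns, and the commented rerun simply reuses the call pattern from the examples above.

```python
# Fields added to each result according to the Returns section.
RESULT_FIELDS = {
    "message", "processed_response", "chat", "error", "model", "temperature",
    "max_tokens", "prompt_tokens", "completion_tokens", "total_tokens", "timing",
}


def failed_prompts(results: list[dict]) -> list[dict]:
    """Rebuild prompt_info entries for results whose error field is set."""
    return [
        {k: v for k, v in r.items() if k not in RESULT_FIELDS}
        for r in results
        if r.get("error")
    ]


# retry_info = failed_prompts(results)
# if retry_info:
#     retry_results = asyncio.run(
#         batch_model_query(prompt_info=retry_info, model="databricks-llama-4-maverick")
#     )
```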