batch_model_query
- async batch_model_query(*, prompt_info: list[dict], model: str, process_func: callable = None, process_func_params: dict = None, batch_size: int = 10, max_concurrent_requests: int = 5, rate_limit: tuple = (2, 1), results_path: str | Path = None, run_name: str | Path = None, model_params: dict = None, token: str = None, host: str = None, **kwargs) → list[dict]
Asynchronously process a batch of prompts using a language model client, saving intermediate and final results, and handling concurrent API requests with robust error handling and metadata tracking.
This function orchestrates the batch processing pipeline for querying language models. It divides the input prompts into batches, manages concurrency, applies optional post-processing, and saves results in both pickle and parquet formats. Intermediate results are saved after each batch for fault tolerance, and final results are consolidated at the end.
- Parameters:
- prompt_info : list of dict
  List of dictionaries containing prompt information, each with `system` and `user` keys.
- model : str
  The name or identifier of the language model to use.
- process_func : callable, optional
  Function to process each response (see the sketch after this parameter list). If None, the raw message content is returned in the `message` field. Default is `None`.
- process_func_params : dict, optional
  Parameters to pass to `process_func`. Ignored if `process_func` is None. Default is `None`.
- batch_size : int, optional
  Number of prompts to process between intermittent saves. Default is `10`.
- max_concurrent_requests : int, optional
  Maximum number of concurrent API requests allowed. Default is `5`.
- rate_limit : tuple, optional
  Tuple of (max_requests, interval_seconds) specifying the maximum number of requests allowed per interval. Default is `(2, 1)`, meaning 2 requests per 1 second. This is enforced with a rate limiter to avoid exceeding API provider quotas or triggering throttling.
- results_path : str or Path, optional
  Path where intermediate and final result files are saved. If None, results are not saved. Default is `None`.
- run_name : str or Path, optional
  Name used to identify this batch run in saved files. Required if `results_path` is provided. Default is `None`.
- model_params : dict, optional
  Dictionary of model parameters that override the defaults. Supported keys:
  - `max_tokens`: Maximum tokens for the completion (default: 2048)
  - `temperature`: Sampling temperature (default: 0)
  Default is `None`.
- token : str, optional
  API token for authentication. If not provided, it is loaded from the environment or the Databricks context.
- host : str, optional
  API host URL. If not provided, it is loaded from the environment or the Databricks context.
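As a rough illustration of the callable and dict parameters above, here is a minimal sketch (not library code). It assumes `process_func` receives the raw message string and that `process_func_params` is forwarded as keyword arguments; check the library source for the exact calling convention.

```python
def strip_code_fences(message: str, fence: str = "```") -> str:
    """Hypothetical process_func: takes the raw message content and returns
    the value that ends up under `processed_response`."""
    return message.replace(fence, "").strip()

process_func_params = {"fence": "```"}  # forwarded to strip_code_fences (assumed)

model_params = {
    "max_tokens": 1024,   # overrides the 2048 default
    "temperature": 0.2,   # overrides the 0 default
}

rate_limit = (2, 1)  # at most 2 requests per 1-second interval
```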
- Returns:
- list of dict
  A list of dictionaries, each containing the response and associated metadata for a prompt. Each dictionary includes:
  - `message`: Raw response content from the model.
  - `processed_response`: Processed content, if `process_func` is provided.
  - `chat`: Full API response object (or `None` on error).
  - `error`: Error message if an exception occurred, `None` otherwise.
  - `model`: Model name used for generation.
  - `temperature`: Temperature setting used for generation.
  - `max_tokens`: Maximum tokens setting used.
  - `prompt_tokens`: Number of tokens in the prompt (if available).
  - `completion_tokens`: Number of tokens in the completion (if available).
  - `total_tokens`: Total number of tokens used (if available).
  - `timing`: Query execution time in seconds.
  - All original keys from the corresponding entry in `prompt_info`.
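As a quick illustration of working with these keys, the sketch below (not part of the library) splits a result list into successes and failures and totals the reported token usage; it assumes `results` was returned by a prior `batch_model_query` call.

```python
ok = [r for r in results if r["error"] is None]
failed = [r for r in results if r["error"] is not None]

# Token counts may be None when the provider does not report usage.
total_tokens = sum(r["total_tokens"] or 0 for r in ok)
total_time = sum(r["timing"] for r in ok)

print(f"{len(ok)} succeeded, {len(failed)} failed")
print(f"tokens used: {total_tokens}; cumulative query time: {total_time:.1f}s")
```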
Note
- When `process_func` is None, the function returns the raw message content in the `message` field.
- After each batch, results (including the `chat` objects) are saved as pickle files, and a version without the `chat` and `message` keys is saved as a parquet file.
- Intermediate results are saved in a subdirectory named after `run_name`; final results are saved in `results_path`.
- Intermediate files are deleted after successful completion.
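A minimal sketch of reloading the saved artifacts: the exact file names and extensions are not documented here, so the glob patterns below are assumptions to adapt to whatever you find in `results_path`.

```python
import pickle
from pathlib import Path

import pandas as pd

results_path = Path("output_results/")

# Parquet copy: per-prompt metadata without the `chat` and `message` columns.
metadata_df = pd.read_parquet(next(results_path.glob("*.parquet")))

# Pickle copy: full results, including the raw `chat` response objects.
with open(next(results_path.glob("*.pkl")), "rb") as f:  # extension is an assumption
    full_results = pickle.load(f)
```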
Warning
- Ensure that `token` and `host` are set, either via arguments, environment variables, an env dotfile, or the Databricks context.
- If `results_path` is provided, `run_name` must also be specified.
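If you prefer to pass credentials explicitly rather than rely on auto-detection, here is a minimal sketch. It assumes a notebook context (top-level `await`), that `prompt_info` is built as in the examples below, and that the environment variable names are placeholders for whatever your deployment defines.

```python
import os

token = os.environ.get("DATABRICKS_TOKEN")  # assumed variable name
host = os.environ.get("DATABRICKS_HOST")    # assumed variable name

results = await batch_model_query(
    prompt_info=prompt_info,  # built as in the examples below
    model="databricks-llama-4-maverick",
    results_path="output_results/",
    run_name="my_batch_run",  # required because results_path is set
    token=token,
    host=host,
)
```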
Examples
The `batch_model_query` function allows you to send multiple prompts to a specified model asynchronously.

Running in a Notebook
```python
from dbutils_batch_query.model_query import batch_model_query, extract_json_items
from dbutils_batch_query.prompts import load_prompt

user_prompt = load_prompt(
    "path/to/your/prompt_template.md",
    input_text="Some text to analyze.",
)

# Example prompt information (replace with your actual prompts)
prompt_info = [
    {
        "system": "You are an assistant that extracts key information. Respond in a JSON codeblock.",
        "user": user_prompt,
        "id": "query_1",  # Optional: add identifiers or other metadata
    },
    {
        "system": "You are an assistant that summarizes text. Respond in a JSON codeblock.",
        "user": load_prompt(
            "path/to/your/summary_template.md",
            document="Another document to summarize.",
        ),
        "id": "query_2",
    },
]

results = await batch_model_query(
    prompt_info=prompt_info,
    model="databricks-llama-4-maverick",  # Specify your Databricks model endpoint
    process_func=extract_json_items,  # Optional: function to process the raw text response
    batch_size=5,  # Optional: number of prompts between intermittent saves
    max_concurrent_requests=3,  # Optional: max concurrent requests
    rate_limit=(2, 1),  # Optional: max requests per interval (here 2 per second)
    results_path="output_results/",  # Optional: path to save results
    run_name="my_batch_run",  # Optional: identifier for the run
    # token and host are fetched from the environment or dbutils if not provided
)

# Process results
for result in results:
    if result["error"]:
        print(f"Error processing prompt {result.get('id', 'N/A')}: {result['error']}")
    else:
        print(f"Result for prompt {result.get('id', 'N/A')}:")
        # Access the raw message or the processed response
        # print(result["message"])
        print(result["processed_response"])
```
Running in a Python File
```python
import asyncio

from dbutils_batch_query.model_query import (
    batch_model_query,
    extract_json_items,
)
from dbutils_batch_query.prompts import load_prompt

user_prompt = load_prompt(
    "path/to/your/prompt_template.md",
    input_text="Some text to analyze.",
)

# Example prompt information (replace with your actual prompts)
prompt_info = [
    {
        "system": "You are an assistant that extracts key information. Respond in a JSON codeblock.",
        "user": user_prompt,
        "id": "query_1",  # Optional: add identifiers or other metadata
    },
    {
        "system": "You are an assistant that summarizes text. Respond in a JSON codeblock.",
        "user": load_prompt(
            "path/to/your/summary_template.md",
            document="Another document to summarize.",
        ),
        "id": "query_2",
    },
]

results = asyncio.run(
    batch_model_query(
        prompt_info=prompt_info,
        model="databricks-llama-4-maverick",  # Specify your Databricks model endpoint
        process_func=extract_json_items,  # Optional: function to process the raw text response
        batch_size=5,  # Optional: number of prompts between intermittent saves
        max_concurrent_requests=3,  # Optional: max concurrent requests
        rate_limit=(2, 1),  # Optional: max requests per interval (here 2 per second)
        results_path="output_results/",  # Optional: path to save results
        run_name="my_batch_run",  # Optional: identifier for the run
        # token and host are fetched from the environment or dbutils if not provided
    )
)

# Process results
for result in results:
    if result["error"]:
        print(f"Error processing prompt {result.get('id', 'N/A')}: {result['error']}")
    else:
        print(f"Result for prompt {result.get('id', 'N/A')}:")
        # Access the raw message or the processed response
        # print(result["message"])
        print(result["processed_response"])
```