# ============================
# Databricks: LLM backfill to Delta (STRING keys only)
# ============================
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import httpx
from openai import AzureOpenAI
from pyspark.sql import functions as F, types as T
from delta.tables import DeltaTable

# ---------- CONFIG ----------
TABLE_FQN = "_xxxxx_assetdigital_dev.xxxxxxx.Asset_Digital_AI"
MAX_ROWS_PER_RUN = 50
DEFAULT_MODEL = "gpt-4o"
MAX_WORKERS = 4

# 🔐 Use Databricks secrets in production:
# endpoint = dbutils.secrets.get("aoai-scope", "endpoint")
# api_key = dbutils.secrets.get("aoai-scope", "api-key")
endpoint = "https://xxxxxxx.a03.azurefd.net/rtiodatabricksapps"
api_key = dbutils.secrets.get(scope="digitalAI", key="digitalAIKey")
api_version = "2025-03-01-preview"
verify_tls = True  # keep this a real boolean: bool("false") would also be truthy

if not endpoint:
    raise ValueError("Please set a valid AZURE_OPENAI_ENDPOINT.")
if not api_key or api_key == "REPLACE_WITH_SECRET":
    raise ValueError("Please set AZURE_OPENAI_API_KEY via Databricks secrets or env vars.")

# ---------- Azure OpenAI client + helper ----------
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
    http_client=httpx.Client(verify=verify_tls, timeout=60.0),
)


def ask_llm(
    prompt: str,
    model: str = DEFAULT_MODEL,
    system_prompt: str = "You are a helpful assistant.",
    temperature: float = 0.2,
    max_retries: int = 5,
    backoff_base: float = 1.8,
) -> str:
    """
    Calls Azure OpenAI Chat Completions with exponential-backoff retries.
    Returns the text content of the first choice.
    """
    last_err: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=temperature,
            )
            # The SDK returns typed objects, so the message content is a plain attribute.
            return resp.choices[0].message.content or ""
        except Exception as e:
            last_err = e
            sleep_s = min((backoff_base ** attempt) + (0.1 * attempt), 30)
            time.sleep(sleep_s)
    raise RuntimeError(f"LLM call failed after {max_retries} attempts: {last_err}")
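# Optional connectivity smoke test (a minimal sketch, not part of the backfill;
# the prompt string below is illustrative). Uncomment to verify the endpoint,
# key, and TLS settings cheaply before processing a full batch:
# print(ask_llm("Reply with the single word: pong", max_retries=1))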
spark.conf.set("spark.databricks.delta.optimizeWrite.enabled", "true")
spark.conf.set("spark.sql.shuffle.partitions", "200")

# ---------- Read table & verify schema ----------
df = spark.table(TABLE_FQN)
df.printSchema()

required_cols = {"query_key", "query"}
missing = required_cols - set(df.columns)
if missing:
    raise ValueError(f"Table {TABLE_FQN} is missing required columns: {missing}")

# Expect both keys to be STRING by now
schema_by_name = {f.name: f.dataType for f in df.schema.fields}
if str(schema_by_name.get("query_key")).lower().startswith("binary"):
    raise ValueError("query_key is still BINARY. Please finish migration to STRING first.")
if "answer_key" in df.columns and str(schema_by_name.get("answer_key")).lower().startswith("binary"):
    raise ValueError("answer_key is still BINARY. Please finish migration to STRING first.")

has_llm_col = "llm" in df.columns

# Pending rows: answer_key is NULL or differs from query_key (string comparison).
# If answer_key does not exist yet, every row is pending.
if "answer_key" in df.columns:
    pending_cond = F.col("answer_key").isNull() | (F.col("answer_key") != F.col("query_key"))
else:
    pending_cond = F.lit(True)

select_cols = ["query", F.col("query_key").alias("join_key")] + (["llm"] if has_llm_col else [])
pending_df = df.select(*select_cols).where(pending_cond).limit(MAX_ROWS_PER_RUN)

rows = pending_df.collect()
if not rows:
    print("No pending rows to process.")
else:
    print(f"Processing {len(rows)} rows from {TABLE_FQN}...")

# Prepare tasks
work_items = []
for r in rows:
    q_text = r["query"]
    join_key = r["join_key"]  # STRING
    model = r["llm"] if has_llm_col and r["llm"] else DEFAULT_MODEL
    if q_text and str(q_text).strip():
        work_items.append((join_key, q_text, model))
        print(f"Key: {join_key}")

results = []
errors = []


def _process_item(item):
    join_key, q_text, mdl = item
    try:
        answer_text = ask_llm(prompt=q_text, model=mdl)
        if not answer_text.strip():
            raise ValueError("Empty LLM answer")
        return {
            "join_key": join_key,
            "answer": answer_text,
            "answer_datetime": datetime.utcnow(),
        }, None
    except Exception as e:
        return None, (join_key, str(e))


if work_items:
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = [pool.submit(_process_item, it) for it in work_items]
        for fut in as_completed(futures):
            upd, err = fut.result()
            if upd:
                results.append(upd)
            if err:
                errors.append(err)

if results:
    updates_sdf = spark.createDataFrame(
        results,
        schema=T.StructType([
            T.StructField("join_key", T.StringType(), False),
            T.StructField("answer", T.StringType(), True),
            T.StructField("answer_datetime", T.TimestampType(), True),
        ]),
    ).withColumn("target_qk", F.col("join_key"))

    # Ensure the answer columns exist before the MERGE (adds them if missing)
    existing_cols = set(df.columns)
    to_add = []
    if "answer" not in existing_cols:
        to_add.append("ADD COLUMNS (answer STRING)")
    if "answer_key" not in existing_cols:
        to_add.append("ADD COLUMNS (answer_key STRING)")
    if "answer_datetime" not in existing_cols:
        to_add.append("ADD COLUMNS (answer_datetime TIMESTAMP)")
    for stmt in to_add:
        spark.sql(f"ALTER TABLE {TABLE_FQN} {stmt}")

    # Resolve the Delta table only after any schema changes above
    delta_tbl = DeltaTable.forName(spark, TABLE_FQN)

    # Build update map (string keys); setting answer_key = query_key marks the row as done
    set_map = {
        "answer": "u.answer",
        "answer_datetime": "u.answer_datetime",
        "answer_key": "t.query_key",  # idempotency marker
    }

    (
        delta_tbl.alias("t")
        .merge(updates_sdf.alias("u"), "t.query_key = u.target_qk")
        .whenMatchedUpdate(set=set_map)
        .execute()
    )

    print(f"Updated {len(results)} rows in {TABLE_FQN}.")
    print("Updated keys:")
    for r in results:
        print(f"  - {r['join_key']}")
else:
    print("No updates generated.")

if errors:
    print(f"{len(errors)} errors encountered (skipped). Showing up to 5:")
    for jk, msg in errors[:5]:
        print(f"  - key={jk}: {msg}")
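
# ---------- Optional post-run check ----------
# A read-only sketch (assumes the answer columns now exist on the table):
# re-counts rows still pending so the next run can be sized. Nothing here
# mutates the table.
remaining = (
    spark.table(TABLE_FQN)
    .where(F.col("answer_key").isNull() | (F.col("answer_key") != F.col("query_key")))
    .count()
)
print(f"Pending rows remaining: {remaining}")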