Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 327 additions & 0 deletions src/unstract/clone/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def get_post_schema(self, entity_path: str) -> frozenset[str]:
self._post_schema_cache[entity_path] = writable
return writable

def probe(self, path: str) -> bool:
    """Report whether a feature's route exists on this deployment.

    Issues a GET against ``path``. A ``PlatformAPIError`` carrying status
    404 means the route is absent (feature not built into this deployment)
    and yields ``False``; a request that completes without raising yields
    ``True``. Any other ``PlatformAPIError`` propagates unchanged, so a
    genuine failure is never mistaken for "feature missing".
    """
    try:
        self._request("GET", path)
    except PlatformAPIError as exc:
        route_missing = exc.status_code == 404
        if not route_missing:
            raise
        return False
    else:
        return True

# ----- org users & groups -----

def list_users(self) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -238,6 +253,18 @@ def list_profiles(self, tool_id: str) -> list[dict[str, Any]]:
result = self._request("GET", f"prompt-studio/prompt-studio-profile/{tool_id}/")
return result if isinstance(result, list) else result.get("results", [])

def list_prompts(self, tool_id: str) -> list[dict[str, Any]]:
    """List a tool's prompts (``prompt_id`` + ``prompt_key`` per row).

    Used to map source prompt ids to the target prompts created by
    ``import_project`` / ``sync_prompts`` (matched by ``prompt_key``),
    so prompt-scoped cloud config can remap its FKs.

    ``tool_id`` is sent as the ``tool_id`` query parameter. A list response
    is returned as-is; any other response is treated as an envelope and its
    ``results`` entry is returned, falling back to ``[]`` when the body is
    empty/``None`` or carries no ``results`` key.
    """
    result = self._request(
        "GET", "prompt-studio/prompt/", params={"tool_id": tool_id}
    )
    if isinstance(result, list):
        return result
    # Guard a None/empty body the same way the lookup / agentic list helpers
    # do, instead of failing with AttributeError on ``None.get``.
    return (result or {}).get("results", [])

def export_project(self, tool_id: str) -> dict[str, Any]:
"""Export a prompt-studio project as a portable JSON blob.

Expand Down Expand Up @@ -533,3 +560,303 @@ def create_api_key(self, payload: dict[str, Any]) -> dict[str, Any]:
and cannot be carried over from source.
"""
return self._request("POST", "api/keys/api/", json=payload)

# ----- lookups (cloud-only) -----

def list_lookup_definitions(self) -> list[dict[str, Any]]:
    """Return the org's lookup definitions.

    This route is also used as the lookups capability-probe path. A list
    response is returned unchanged; otherwise the ``results`` entry of the
    (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", "lookups/definitions/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def get_lookup_definition(self, lookup_id: str) -> dict[str, Any]:
    """GET one lookup definition's detail record.

    Per the backend contract noted here, the detail inlines the draft
    content: ``prompt_template``, ``draft_version_id``, ``input_vars``, and
    ``adapters`` (a dict with ``llm`` / ``x2text`` adapter UUIDs, either
    possibly ``None``).
    """
    detail = self._request("GET", f"lookups/definitions/{lookup_id}/")
    return detail

def create_lookup_definition(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new lookup definition and return the backend's response.

    The backend auto-creates an empty DRAFT version with default adapters;
    fill it in afterwards through the draft / adapters / file endpoints.
    """
    created = self._request("POST", "lookups/definitions/", json=payload)
    return created

def update_lookup_draft_template(
    self, lookup_id: str, prompt_template: str
) -> dict[str, Any]:
    """PATCH the draft version's prompt template for ``lookup_id``.

    The body sent is ``{"prompt_template": prompt_template}``; the
    backend's response is returned as-is.
    """
    body = {"prompt_template": prompt_template}
    return self._request(
        "PATCH", f"lookups/definitions/{lookup_id}/draft/", json=body
    )

def update_lookup_draft_adapters(
    self, lookup_id: str, adapters: dict[str, str]
) -> dict[str, Any]:
    """PATCH the draft version's LLM and/or X2Text adapters by target UUID.

    ``adapters`` is sent verbatim as the JSON body. It may carry either or
    both of ``llm`` / ``x2text``; per the backend contract noted here, a key
    that is absent leaves the existing draft adapter untouched.
    """
    route = f"lookups/definitions/{lookup_id}/adapters/"
    return self._request("PATCH", route, json=adapters)

def list_lookup_files(self, lookup_id: str) -> list[dict[str, Any]]:
    """Return the draft reference files attached to a lookup.

    Rows carry ``file_id``, ``file_name`` and ``file_size``. A list
    response is returned unchanged; otherwise the ``results`` entry of the
    (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", f"lookups/definitions/{lookup_id}/files/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def download_lookup_file(self, lookup_id: str, file_id: str) -> bytes:
    """Fetch the original bytes of one lookup reference file.

    The content route serves a raw ``HttpResponse`` body rather than a JSON
    envelope, so this talks to the session directly instead of going
    through the JSON-decoding ``_request`` path. Any status outside
    200-299 raises ``PlatformAPIError`` carrying the status code and the
    first 2000 characters of the response text.
    """
    url = self._url(f"lookups/definitions/{lookup_id}/files/{file_id}/content/")
    logger.debug("GET %s", url)
    resp = self._session.get(url, timeout=self.timeout, verify=self.verify)
    succeeded = 200 <= resp.status_code < 300
    if succeeded:
        return resp.content
    # Truncate the body so a large error page does not bloat the exception.
    raise PlatformAPIError(
        f"GET lookups/definitions/{lookup_id}/files/{file_id}/content/ "
        f"returned {resp.status_code}",
        status_code=resp.status_code,
        body=resp.text[:2000],
    )

def upload_lookup_file(
    self, lookup_id: str, file_name: str, data: bytes, mime_type: str
) -> dict[str, Any]:
    """POST a reference file into a lookup's draft version.

    The file travels as a multipart ``file`` part built from
    ``(file_name, data, mime_type)``. Per the backend contract noted here,
    the server writes the bytes to storage, creates the row, and dispatches
    re-extraction itself. The draft enforces a unique filename per version,
    so callers pre-check via ``list_lookup_files`` to avoid a 409.
    """
    multipart = {"file": (file_name, data, mime_type)}
    route = f"lookups/definitions/{lookup_id}/files/"
    return self._request("POST", route, files=multipart)

def list_lookup_assignments(self) -> list[dict[str, Any]]:
    """Return the org's PromptLookupAssignment rows.

    Each row carries ``assignment_id``, ``prompt`` (src ToolStudioPrompt
    uuid), ``version`` (src LookupVersion uuid), ``lookup_definition``
    (src lookup_id), ``is_draft_version``, and ``variable_mappings``.
    A list response is returned unchanged; otherwise the ``results`` entry
    of the (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", "lookups/assignments/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def create_lookup_assignment(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new prompt-lookup assignment and return the response.

    Writable fields: ``prompt``, ``lookup_definition`` (required),
    ``version``, ``variable_mappings``. The backend enforces one assignment
    per prompt (``one_lookup_per_prompt``), so callers pre-check the
    target's existing assignments.
    """
    created = self._request("POST", "lookups/assignments/", json=payload)
    return created

# ----- manual review / HITL (cloud-only) -----
#
# Each workflow can hold one RuleEngine row per ``rule_type`` (DB / API)
# and one HITLSettings row. The "using_workflow" GET routes take the
# workflow id in the URL path and wrap the row in ``{"data": ...}``;
# they 404 (rules) / 500 (settings) when none exists — callers treat a
# missing row as "nothing to clone", not an error.

# The ``rule_type`` values a workflow can hold a RuleEngine row for.
MR_RULE_TYPES: tuple[str, ...] = ("DB", "API")

def get_review_rule(
self, workflow_id: str, rule_type: str
) -> dict[str, Any] | None:
"""Fetch a workflow's RuleEngine row for one ``rule_type``.

Returns the rule dict (with nested ``confidence_filters``) or ``None``
when no rule of that type exists (backend answers 404).
"""
try:
body = self._request(
"GET",
f"manual_review/rule_engine/workflow/{workflow_id}/",
params={"rule_type": rule_type},
)
except PlatformAPIError as e:
if e.status_code == 404:
return None
raise
return (body or {}).get("data")

def create_review_rule(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a RuleEngine row (with nested ``confidence_filters``).

    Writable fields: ``workflow`` (required), ``rule_type``,
    ``percentage``, ``rule_string``, ``rule_json``, ``rule_logic``,
    ``confidence_filters``. Rows are unique per
    (workflow, rule_type, organization).
    """
    created = self._request("POST", "manual_review/rule_engine/", json=payload)
    return created

def get_review_settings(self, workflow_id: str) -> dict[str, Any] | None:
"""Fetch a workflow's HITLSettings row, or ``None`` if absent.

The backend's ``settings_using_workflow`` raises ``DoesNotExist``
(→ 500) when no row exists rather than 404, so any error here is
treated as "no settings to clone".
"""
try:
body = self._request(
"GET", f"manual_review/settings/workflow/{workflow_id}/"
)
except PlatformAPIError:
return None
return (body or {}).get("data")
Comment thread
greptile-apps[bot] marked this conversation as resolved.

def create_review_settings(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a HITLSettings row and return the backend's response.

    Writable fields: ``workflow`` (OneToOne, required), ``sync_with``,
    ``ttl_hours``.
    """
    created = self._request("POST", "manual_review/settings/", json=payload)
    return created

def list_auto_approval_settings(self) -> list[dict[str, Any]]:
    """Return the org-level AutoApprovalSettings rows (0 or 1 per org).

    The route is a plain ModelViewSet ``list`` that 200s with no query
    params, so it doubles as the manual-review capability-probe path.
    A list response is returned unchanged; otherwise the ``results`` entry
    of the (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", "manual_review/auto_approval_settings/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def create_auto_approval_settings(
    self, payload: dict[str, Any]
) -> dict[str, Any]:
    """POST the org-level AutoApprovalSettings row.

    Writable fields: ``auto_approved_document_classes``,
    ``auto_approved_users``. ``organization`` is server-set, and the row is
    unique per organization.
    """
    route = "manual_review/auto_approval_settings/"
    return self._request("POST", route, json=payload)

def list_review_api_keys(self) -> list[dict[str, Any]]:
    """Return the org's ReviewApiKey rows.

    A list response is returned unchanged; otherwise the ``results`` entry
    of the (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", "manual_review/api/keys/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def create_review_api_key(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new ReviewApiKey and return the backend's response.

    The ``api_key`` secret is server-minted (``uuid4``, non-editable) and
    cannot be carried over from source. Writable fields: ``class_name``,
    ``description``, ``is_active``.
    """
    # NOTE(review): this create route is singular (``api/key/``) while the
    # list route is plural (``api/keys/``) — confirm against the backend urls.
    created = self._request("POST", "manual_review/api/key/", json=payload)
    return created

# ----- agentic studio (cloud-only) -----

def list_agentic_projects(self) -> list[dict[str, Any]]:
    """Return the org's agentic projects (also the capability-probe path).

    The source platform key is a service account, which sees every
    project. Rows carry ``id``, ``name``, ``description``, the four adapter
    FK ids (``llm_connector_id`` / ``agent_llm_connector_id`` /
    ``lightweight_llm_connector_id`` / ``text_extractor_connector_id``),
    and ``canary_fields``. A list response is returned unchanged; otherwise
    the ``results`` entry of the (possibly empty) envelope is returned,
    defaulting to ``[]``.
    """
    payload = self._request("GET", "agentic/projects/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def create_agentic_project(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new agentic project; the returned row carries its ``id``."""
    created = self._request("POST", "agentic/projects/", json=payload)
    return created

def list_agentic_prompt_versions(
self, *, project_id: str | None = None
) -> list[dict[str, Any]]:
"""List agentic prompt versions, optionally scoped to a project.

Rows carry ``id``, ``project``, ``version``, ``prompt_text``,
``accuracy``, ``is_active``, and the self-FK ``parent_version``.
"""
params: dict[str, Any] = {}
if project_id is not None:
params["project_id"] = project_id
result = self._request("GET", "agentic/prompt-versions/", params=params)
return result if isinstance(result, list) else (result or {}).get("results", [])

def create_agentic_prompt_version(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new agentic prompt version.

    The endpoint is flat: the owning ``project`` travels in the body.
    """
    created = self._request("POST", "agentic/prompt-versions/", json=payload)
    return created

def list_agentic_schemas(
self, *, project_id: str | None = None
) -> list[dict[str, Any]]:
"""List agentic schemas, optionally scoped to a project.

Rows carry ``id``, ``project``, ``json_schema``, ``version``,
``is_active``.
"""
params: dict[str, Any] = {}
if project_id is not None:
params["project_id"] = project_id
result = self._request("GET", "agentic/schemas/", params=params)
return result if isinstance(result, list) else (result or {}).get("results", [])

def create_agentic_schema(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST a new agentic schema.

    The endpoint is flat: the owning ``project`` travels in the body.
    """
    created = self._request("POST", "agentic/schemas/", json=payload)
    return created

def list_agentic_settings(self) -> list[dict[str, Any]]:
    """Return agentic settings: org-wide key/value rows (no project FK).

    A list response is returned unchanged; otherwise the ``results`` entry
    of the (possibly empty) envelope is returned, defaulting to ``[]``.
    """
    payload = self._request("GET", "agentic/settings/")
    if isinstance(payload, list):
        return payload
    envelope = payload or {}
    return envelope.get("results", [])

def create_agentic_setting(self, payload: dict[str, Any]) -> dict[str, Any]:
    """POST an org-wide agentic setting (``key`` is globally unique)."""
    created = self._request("POST", "agentic/settings/", json=payload)
    return created

def update_agentic_setting(
    self, setting_id: str, payload: dict[str, Any]
) -> dict[str, Any]:
    """PATCH the agentic setting identified by ``setting_id``.

    ``payload`` is sent verbatim as the JSON body; the backend's response
    is returned as-is.
    """
    route = f"agentic/settings/{setting_id}/"
    return self._request("PATCH", route, json=payload)

def export_agentic_project(self, project_id: str, *, force: bool = True) -> Any:
    """Republish ``AgenticStudioRegistry`` from a project's active schema + prompt.

    Mirror of ``export_custom_tool``. The backend requires an active schema
    and an active prompt; ``force`` is sent as ``force_export``, which
    bypasses the wizard-completion guard. The export is always requested
    unshared (``is_shared_with_org`` false, no ``user_ids``). Nothing is
    recorded here — the caller re-reads the registry to learn the new id.
    """
    export_request = {
        "is_shared_with_org": False,
        "user_ids": [],
        "force_export": force,
    }
    route = f"agentic/projects/{project_id}/export/"
    return self._request("POST", route, json=export_request)

def list_agentic_registries(
self, *, agentic_project: str | None = None
) -> list[dict[str, Any]]:
"""List AgenticStudioRegistry rows. The list endpoint returns nothing
unless a filter is supplied; pass ``agentic_project`` to look up the
registry id for a given project.
"""
params: dict[str, Any] = {}
if agentic_project is not None:
params["agentic_project"] = agentic_project
result = self._request("GET", "agentic-studio-registry/", params=params)
return result if isinstance(result, list) else (result or {}).get("results", [])
17 changes: 17 additions & 0 deletions src/unstract/clone/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,20 @@ class CloneContext:
# touches them once per endpoint, never per resource).
share_cache: dict[str, Any] = field(default_factory=dict)
share_cache_lock: threading.Lock = field(default_factory=threading.Lock)
# Capability-probe memo: (id(client), feature_path) -> present?. Probed
# once per (deployment, feature) so cloud-phase gating costs one GET total.
# NOTE(review): keys use id(client), which is only unique while that client
# object is alive — presumably clients live for the whole run; confirm.
probe_cache: dict[tuple[int, str], bool] = field(default_factory=dict)

def feature_present(self, client: "PlatformClient", path: str) -> bool:
    """Tell whether ``path`` (a feature's list endpoint) exists on ``client``.

    The answer is memoised in ``probe_cache`` under ``(id(client), path)``,
    so ``client.probe`` runs at most once per key for this context. The
    memo is a plain dict with no lock: the original note states probing
    runs in the single-threaded orchestrator loop, before any
    ``parallel_map`` fan-out.
    """
    memo_key = (id(client), path)
    # A stored False is a valid answer, so compare against None explicitly.
    if (known := self.probe_cache.get(memo_key)) is not None:
        return known
    answer = client.probe(path)
    self.probe_cache[memo_key] = answer
    return answer
Loading
Loading