Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,11 @@ def _extract_state_from_messages(self, messages: Sequence[Message]) -> tuple[lis
if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json":
try:
uri = content.uri
if uri.startswith("data:application/json;base64,"): # type: ignore[union-attr]
prefix, _, encoded_data = uri.partition(",") # type: ignore[union-attr]
media_type, *parameters = prefix[5:].split(";")
if prefix.startswith("data:") and media_type == "application/json" and "base64" in parameters:
import base64

encoded_data = uri.split(",", 1)[1] # type: ignore[union-attr]
decoded_bytes = base64.b64decode(encoded_data)
state = json.loads(decoded_bytes.decode("utf-8"))
Comment on lines +300 to 304

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 829a628 by validating base64 with base64.b64decode(..., validate=True) and catching binascii.Error in the existing warning/fallthrough path. I also added a malformed-base64 regression test that preserves the original message list and returns state is None.

Validation run locally:

  • uv run pytest packages/ag-ui/tests/ag_ui/test_ag_ui_client.py -q
  • uv run ruff check packages/ag-ui/agent_framework_ag_ui/_client.py packages/ag-ui/tests/ag_ui/test_ag_ui_client.py
  • git diff --check -- python/packages/ag-ui/agent_framework_ag_ui/_client.py python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py


Expand Down
24 changes: 24 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@ async def test_extract_state_from_messages_with_state(self) -> None:
assert result_messages[0].text == "Hello"
assert state == state_data

async def test_extract_state_from_messages_with_parameterized_data_uri(self) -> None:
"""Test state extraction from JSON data URIs with media type parameters."""
import base64

client = StubAGUIChatClient(endpoint="http://localhost:8888/")

state_data = {"key": "value", "count": 42}
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")

messages = [
Message(role="user", contents=["Hello"]),
Message(
role="user",
contents=[Content.from_uri(uri=f"data:application/json;charset=utf-8;base64,{state_b64}")],
),
]

result_messages, state = client.extract_state_from_messages(messages)

assert len(result_messages) == 1
assert result_messages[0].text == "Hello"
assert state == state_data

async def test_extract_state_invalid_json(self) -> None:
"""Test state extraction with invalid JSON."""
import base64
Expand Down