Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ repos:
additional_dependencies:
- fastapi
- pytest
- sqlalchemy

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.15.18'
Expand Down
9 changes: 3 additions & 6 deletions src/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import html
from typing import TYPE_CHECKING

from config import get_config
from database.schema.base import UntypedRow
from schemas.datasets.openml import DatasetFileFormat

if TYPE_CHECKING:
from sqlalchemy.engine import Row


def _str_to_bool(string: str) -> bool:
if string.casefold() in ["true", "1", "yes", "y"]:
Expand All @@ -17,7 +14,7 @@ def _str_to_bool(string: str) -> bool:
raise ValueError(msg)


def _format_parquet_url(dataset: Row) -> str | None:
def _format_parquet_url(dataset: UntypedRow) -> str | None:
if dataset.format.lower() != DatasetFileFormat.ARFF:
return None

Expand All @@ -27,7 +24,7 @@ def _format_parquet_url(dataset: Row) -> str | None:
return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq"


def _format_dataset_url(dataset: Row) -> str:
def _format_dataset_url(dataset: UntypedRow) -> str:
base_url = get_config().routing.server_url
filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
return f"{base_url}data/v1/download/{dataset.file_id}/{filename}"
Expand Down
14 changes: 8 additions & 6 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
from schemas.datasets.openml import DatasetStatus, Feature

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
async def get(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand All @@ -35,7 +35,7 @@ async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
return row.one_or_none()


async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None:
async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand All @@ -53,7 +53,7 @@ async def get_tag(
dataset_id: Identifier,
tag: TagString,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
return (
await connection.execute(
text(
Expand Down Expand Up @@ -111,6 +111,8 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
},
)
except IntegrityError as e:
if e.orig is None:
raise
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
Expand All @@ -122,7 +124,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
async def get_description(
id_: Identifier,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
"""Get the most recent description for the dataset."""
row = await connection.execute(
text(
Expand Down Expand Up @@ -160,7 +162,7 @@ async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetSta
async def get_latest_processing_update(
dataset_id: Identifier,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand Down
15 changes: 8 additions & 7 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from sqlalchemy import Row, text
from sqlalchemy import text

from core.formatting import _str_to_bool
from database.schema.base import UntypedRow
from schemas.datasets.openml import EstimationProcedure

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
async def get_math_functions(
function_type: str,
connection: AsyncConnection,
) -> Sequence[UntypedRow]:
rows = await connection.execute(
text(
"""
Expand All @@ -21,10 +25,7 @@ async def get_math_functions(function_type: str, connection: AsyncConnection) ->
),
parameters={"function_type": function_type},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
Expand Down
27 changes: 13 additions & 14 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from sqlalchemy import Row, text
from sqlalchemy import text

from database.schema.base import UntypedRow
from routers.types import Identifier

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
Expand All @@ -20,10 +21,7 @@ async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence
),
parameters={"flow_id": for_flow},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
Expand All @@ -41,7 +39,7 @@ async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
return [tag.tag for tag in tag_rows]


async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
Expand All @@ -52,13 +50,14 @@ async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequenc
),
parameters={"flow_id": flow_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
async def get_by_name(
name: str,
external_version: str,
expdb: AsyncConnection,
) -> UntypedRow | None:
"""Get flow by name and external version."""
row = await expdb.execute(
text(
Expand All @@ -73,7 +72,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection)
return row.one_or_none()


async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
row = await expdb.execute(
text(
"""
Expand Down
24 changes: 11 additions & 13 deletions src/database/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, bindparam, text
from sqlalchemy import bindparam, text

from database.schema.base import UntypedRow
from routers.types import Identifier

if TYPE_CHECKING:
Expand All @@ -26,7 +27,7 @@ async def exist(id_: Identifier, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())


async def get(run_id: Identifier, expdb: AsyncConnection) -> Row | None:
async def get(run_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
"""Fetch the core run row from the `run` table.

Returns the row if found, or None if no run with `run_id` exists.
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]:
return [row.tag for row in rows.all()]


async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch the dataset(s) used as input for a run, with name and url.

Joins `input_data` with `dataset` to include the dataset name and ARFF URL.
Expand All @@ -79,10 +80,10 @@ async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch output files attached to a run from the `runfile` table.

Typical entries include the description XML and predictions ARFF.
Expand All @@ -98,15 +99,15 @@ async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_evaluations(
run_id: int,
expdb: AsyncConnection,
*,
evaluation_engine_ids: list[int],
) -> list[Row]:
) -> list[UntypedRow]:
"""Fetch evaluation metric results for a run.

Joins `evaluation` with `math_function` to resolve the metric name
Expand Down Expand Up @@ -138,10 +139,10 @@ async def get_evaluations(
query,
parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[UntypedRow]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
Expand All @@ -153,7 +154,4 @@ async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
),
parameters={"run_id": run_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()
1 change: 1 addition & 0 deletions src/database/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Defines Object-Relational Mappings (ORM)."""
52 changes: 52 additions & 0 deletions src/database/schema/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Base classes for all ORM classes.

When defining a new ORM class, use both `Base` and one of the `DeferredReflection` subclasses to
make sure that the class is populated with attributes that may not be defined explicitly.
For example, when creating a new mapping for a table from the `openml_expdb` database, use:

class ClassName(ExpDBReflected, Base):
__tablename__ = "class_names"

# any columns you wanted mapped explicitly
...

"""

from typing import Any

from sqlalchemy import Row
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import DeclarativeBase

from database.setup import expdb_database, user_database

UntypedRow = Row[Any]


class Base(DeclarativeBase):
"""Base class for all ORM classes."""


class ExpDBReflected(DeferredReflection):
"""Base class for ORM classes to map onto a table in the `openml_expdb` database."""

__abstract__ = True


class UserDBReflected(DeferredReflection):
"""Base class for ORM classes to map onto a table in the `openml` database."""

__abstract__ = True


async def reflect_db_schemas() -> None:
"""Populate defined ORM classes with attributes defined from columns in the database.

For example, the `dataset` class would automatically get a `collection_date` attribute,
even if it wasn't explicitly declared in the class definition,
because the `openml_expdb.dataset` table has a column `collection_date`.
"""
async with user_database().connect() as connection:
await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it.
async with expdb_database().connect() as connection:
await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above.
30 changes: 30 additions & 0 deletions src/database/schema/tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""ORM classes for the *_tag tables (task_tag, ...)."""

from datetime import datetime

from sqlalchemy import FetchedValue
from sqlalchemy.orm import Mapped, mapped_column

from database.schema.base import Base, ExpDBReflected
from routers.types import Identifier, TagString


class Tag:
"""Base class for all of the *_tag tables."""

# The identifier of the entity that is tagged (e.g., dataset id, task id)
entity_id: Mapped[Identifier] = mapped_column("id", primary_key=True)
tag: Mapped[TagString] = mapped_column(primary_key=True)
uploader_id: Mapped[Identifier] = mapped_column("uploader")
creation_date: Mapped[datetime] = mapped_column("date", server_default=FetchedValue())


class TaskTag(ExpDBReflected, Tag, Base):
"""Tags belonging to a task."""

__tablename__ = "task_tag"

@property
def task_id(self) -> Identifier:
"""Identifier of the task which is tagged by this tag."""
return self.entity_id
Loading
Loading