Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
min_version,
optional_import,
pprint_edges,
safe_eval,
)

validate, _ = optional_import("jsonschema", name="validate")
Expand Down Expand Up @@ -161,7 +162,7 @@ def _get_fake_spatial_shape(shape: Sequence[str | int], p: int = 1, n: int = 1,
for c in _get_var_names(i):
if c not in ["p", "n"]:
raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.")
ret.append(eval(i, {"p": p, "n": n}))
ret.append(safe_eval(i, {"p": p, "n": n}))
Comment thread
ericspod marked this conversation as resolved.
else:
raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.")
return tuple(ret)
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
torch_profiler_time_cpu_gpu,
torch_profiler_time_end_to_end,
)
from .safeeval import SAFE_TYPES, safe_eval
from .state_cacher import StateCacher
from .tf32 import detect_default_tf32, has_ampere_or_later
from .type_conversion import (
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _order_template(self, template: np.ndarray) -> np.ndarray:
else:
rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2])

sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths)
sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)

ordering = np.array([template[tuple(e)] for e in sequence])

Expand Down
73 changes: 73 additions & 0 deletions monai/utils/safeeval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ast
from collections.abc import Mapping, Sequence
from typing import Any

__all__ = ["SAFE_TYPES", "safe_eval"]

# default set of safe AST node types
SAFE_TYPES: Sequence[ast.AST] = (
ast.Expression,
ast.Name,
ast.Load,
ast.Constant,
ast.BinOp,
ast.UnaryOp,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.FloorDiv,
ast.Pow,
ast.Mod,
ast.USub,
ast.UAdd,
)


def safe_eval(
expr: str,
globals_vars: Mapping[str, Any] | None = None,
locals_vars: Mapping[str, object] | None = None,
allowed_types: Sequence[type] = SAFE_TYPES,
) -> Any:
"""
Evaluate the Python expression `expr` using `eval`, but only if it is a safe expression in that its parsed AST
contains nodes whose types are given in `allowed_types`. This ensures unsafe node types are excluded, if these
are present in the AST a ValueError is raised. The default set of such types in `SAFE_TYPES` ensures only
expressions with constants and names can be evaluated, so excludes attribute access, indexing, and calls. Code
injection is infeasible through such expressions, so this is a safe and secure way of evaluating simple expressions.

Args:
expr: expression to evaluate, this will be stripped before parsing to avoid indentation complaints
globals_vars: global variable mapping
locals_vars: local variable mapping
allowed_types: sequence of allowed AST types which can be found in `expr` when parsed

Raises:
ValueError: raised when any node in the AST parsed from `expr` has a type not in `allowed_types`

Returns:
The evaluated expression value, using `eval` with `globals_vars` and `locals_vars`
"""
parsed = ast.parse(expr.strip(), mode="eval")

# collect nodes in the AST which aren't permitted and unparse them for inclusion in the exception message
disallowed = [ast.unparse(n) for n in ast.walk(parsed) if not isinstance(n, tuple(allowed_types))]

if disallowed:
raise ValueError(f"Unsafe expression `{expr}` not evaluated, contains disallowed components: {disallowed}")

return eval(expr, dict(globals_vars), locals_vars)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
5 changes: 4 additions & 1 deletion tests/utils/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@


class TestModuleAlias(unittest.TestCase):
"""check that 'import monai.xx.file_name' returns a module"""
"""
Check that 'import monai.xx.file_name' returns a module. Note that this test will fail if a module has the same name
as a member of that module (or any other) which is imported in a `__init__.py` file.
"""

def test_files(self):
src_dir = os.path.dirname(TESTS_PATH)
Expand Down
58 changes: 58 additions & 0 deletions tests/utils/test_safe_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import ast
import unittest

from parameterized import parameterized

from monai.utils import safe_eval

GOOD_EXPRS = [
("1+2", None, None, 3),
(" 1 + 2 ", None, None, 3),
("1+2+x", {"x": 4}, None, 7),
("1+2+x", None, {"x": 4}, 7),
("1*2+x", {"x": 4}, None, 6),
("(1+2)*3", None, None, 9),
("foo+bar", {"foo": 1030}, {"bar": 204}, 1234),
]

BAD_EXPRS = [("foo()",), ("foo.bar",), ("foo[123]",), ("(1,2)",), ("[3,4]",), ("int.__class__.__init__.__globals__",)]


class TestSafeEval(unittest.TestCase):
@parameterized.expand(GOOD_EXPRS)
def test_good_exprs(self, expr, globals_vars, locals_vars, expected):
"""Test valid expressions with globals/locals evaluate to correct values."""
result = safe_eval(expr, globals_vars, locals_vars)
self.assertEqual(result, expected)

@parameterized.expand(BAD_EXPRS)
def test_bad_exprs(self, expr):
"""Test bad expressions correctly raise ValueError."""
with self.assertRaises(ValueError):
safe_eval(expr)

def test_allowed_types(self):
"""Test restricting the allowed list of types."""
allowed = [ast.Expression, ast.Constant, ast.BinOp, ast.Add]
result = safe_eval("1+2", allowed_types=allowed)
self.assertEqual(result, 3)

with self.assertRaises(ValueError):
safe_eval("1*2", allowed_types=allowed)


if __name__ == "__main__":
unittest.main()
Loading