Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions custom_components/pyscript/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import sys
import time
import traceback
import weakref

import yaml
Expand Down Expand Up @@ -1090,12 +1091,18 @@ async def ast_while(self, arg):
async def ast_classdef(self, arg):
"""Evaluate class definition."""
bases = [(await self.aeval(base)) for base in arg.bases]
keywords = {kw.arg: await self.aeval(kw.value) for kw in arg.keywords}
metaclass = keywords.pop("metaclass", type(bases[0]) if bases else type)

if self.curr_func and arg.name in self.curr_func.global_names:
sym_table_assign = self.global_sym_table
else:
sym_table_assign = self.sym_table
sym_table_assign[arg.name] = EvalLocalVar(arg.name)
sym_table = {}
if hasattr(metaclass, "__prepare__"):
sym_table = metaclass.__prepare__(arg.name, tuple(bases), **keywords)
else:
sym_table = {}
self.sym_table_stack.append(self.sym_table)
self.sym_table = sym_table
for arg1 in arg.body:
Expand All @@ -1106,11 +1113,17 @@ async def ast_classdef(self, arg):
raise SyntaxError(f"{val.name()} statement outside loop")
self.sym_table = self.sym_table_stack.pop()

decorators = [await self.aeval(dec) for dec in arg.decorator_list]
sym_table["__init__evalfunc_wrap__"] = None
if "__init__" in sym_table:
sym_table["__init__evalfunc_wrap__"] = sym_table["__init__"]
del sym_table["__init__"]
sym_table_assign[arg.name].set(type(arg.name, tuple(bases), sym_table))
cls = metaclass(arg.name, tuple(bases), sym_table, **keywords)
if inspect.iscoroutine(cls):
cls = await cls
for dec_func in reversed(decorators):
cls = await self.call_func(dec_func, None, cls)
sym_table_assign[arg.name].set(cls)

async def ast_functiondef(self, arg, async_func=False):
"""Evaluate function definition."""
Expand Down Expand Up @@ -1487,7 +1500,11 @@ async def ast_augassign(self, arg):
await self.recurse_assign(arg.target, new_val)

async def ast_annassign(self, arg):
"""Execute type hint assignment statement (just ignore the type hint)."""
"""Execute type hint assignment statement and track __annotations__."""
if isinstance(arg.target, ast.Name):
annotations = self.sym_table.setdefault("__annotations__", {})
if arg.annotation:
annotations[arg.target.id] = await self.aeval(arg.annotation)
if arg.value is not None:
rhs = await self.aeval(arg.value)
await self.recurse_assign(arg.target, rhs)
Expand Down Expand Up @@ -1961,19 +1978,25 @@ async def call_func(self, func, func_name, *args, **kwargs):
if isinstance(func, (EvalFunc, EvalFuncVar)):
return await func.call(self, *args, **kwargs)
if inspect.isclass(func) and hasattr(func, "__init__evalfunc_wrap__"):
inst = func()
has_init_wrapper = getattr(func, "__init__evalfunc_wrap__") is not None
inst = func(*args, **kwargs) if not has_init_wrapper else func()
#
# we use weak references when we bind the method calls to the instance inst;
# otherwise these self references cause the object to not be deleted until
# it is later garbage collected
#
inst_weak = weakref.ref(inst)
for name in dir(inst):
value = getattr(inst, name)
try:
value = getattr(inst, name)
except AttributeError:
# same effect as hasattr (which also catches AttributeError)
# dir() may list names that aren't actually accessible attributes
continue
if type(value) is not EvalFuncVar:
continue
setattr(inst, name, EvalFuncVarClassInst(value.get_func(), value.get_ast_ctx(), inst_weak))
if getattr(func, "__init__evalfunc_wrap__") is not None:
if has_init_wrapper:
#
# since our __init__ function is async, call the renamed one
#
Expand Down Expand Up @@ -2197,11 +2220,9 @@ def format_exc(self, exc, lineno=None, col_offset=None, short=False, code_list=N
else:
mesg = f"Exception in <{self.filename}>:\n"
mesg += f"{type(exc).__name__}: {exc}"
#
# to get a more detailed traceback on exception (eg, when chasing an internal
# error), add an "import traceback" above, and uncomment this next line
#
# return mesg + "\n" + traceback.format_exc(-1)

if _LOGGER.isEnabledFor(logging.DEBUG):
mesg += "\n" + traceback.format_exc()
return mesg

def get_exception(self):
Expand Down
94 changes: 94 additions & 0 deletions tests/test_unit_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,100 @@
["x: int = [10, 20]; x", [10, 20]],
["Foo = type('Foo', (), {'x': 100}); Foo.x = 10; Foo.x", 10],
["Foo = type('Foo', (), {'x': 100}); Foo.x += 10; Foo.x", 110],
[
"""
from enum import IntEnum

class TestIntMode(IntEnum):
VAL1 = 1
VAL2 = 2
VAL3 = 3
[TestIntMode.VAL2 == 2, isinstance(TestIntMode.VAL3, IntEnum)]
""",
[True, True],
],
[
"""
from enum import StrEnum

class TestStrEnum(StrEnum):
VAL1 = "val1"
VAL2 = "val2"
VAL3 = "val3"
[TestStrEnum.VAL2 == "val2", isinstance(TestStrEnum.VAL3, StrEnum)]
""",
[True, True],
],
[
"""
from enum import Enum, EnumMeta

class Color(Enum):
RED = 1
BLUE = 2
[type(Color) is EnumMeta, isinstance(Color.RED, Color), list(Color.__members__.keys())]
""",
[True, True, ["RED", "BLUE"]],
],
[
"""
from dataclasses import dataclass

@dataclass()
class DT:
name: str
num: int = 32
obj1 = DT(name="abc")
obj2 = DT("xyz", 5)
[obj1.name, obj1.num, obj2.name, obj2.num]
""",
["abc", 32, "xyz", 5],
],
[
"""
class Meta(type):
def __new__(mcls, name, bases, ns, flag=False):
ns["flag"] = flag
return type.__new__(mcls, name, bases, ns)

class Foo(metaclass=Meta, flag=True):
pass
[Foo.flag, isinstance(Foo, Meta)]
""",
[True, True],
],
[
"""
def deco(label):
def wrap(cls):
cls.labels.append(label)
return cls
return wrap

@deco("first")
@deco("second")
class Decorated:
labels = []
Decorated.labels
""",
["second", "first"],
],
[
"""
hits = []

def anno():
hits.append("ok")
return int

class Annotated:
a: anno()
b: int = 3
c = "skip"
[hits, Annotated.__annotations__, Annotated.b, hasattr(Annotated, "c")]
""",
[["ok"], {"a": int, "b": int}, 3, True],
],
["Foo = [type('Foo', (), {'x': 100})]; Foo[0].x = 10; Foo[0].x", 10],
["Foo = [type('Foo', (), {'x': [100, 101]})]; Foo[0].x[1] = 10; Foo[0].x", [100, 10]],
["Foo = [type('Foo', (), {'x': [0, [[100, 101]]]})]; Foo[0].x[1][0][1] = 10; Foo[0].x[1]", [[100, 10]]],
Expand Down
Loading