No need to recheck close method before calling it

A to-be-closed variable is constant and it must have a close metamethod
when it is created. A program has to go out of its way (e.g., by
changing the variable's metamethod) to invalidate that check. So,
it is not worth to test that again. If the program tampers with the
metamethod, Lua will raise a regular error when attempting to call it.
This commit is contained in:
Roberto Ierusalimschy
2020-12-29 10:23:02 -03:00
parent 6188f3a654
commit 59e565d955
2 changed files with 63 additions and 27 deletions

40
lfunc.c
View File

@@ -101,31 +101,32 @@ UpVal *luaF_findupval (lua_State *L, StkId level) {
/* /*
** Prepare closing method plus its arguments for object 'obj' with ** Call closing method for object 'obj' with error message 'err'.
** error message 'err'. (This function assumes EXTRA_STACK.) ** (This function assumes EXTRA_STACK.)
*/ */
static int prepclosingmethod (lua_State *L, TValue *obj, TValue *err) { static void callclosemethod (lua_State *L, TValue *obj, TValue *err) {
StkId top = L->top; StkId top = L->top;
const TValue *tm = luaT_gettmbyobj(L, obj, TM_CLOSE); const TValue *tm = luaT_gettmbyobj(L, obj, TM_CLOSE);
if (ttisnil(tm)) /* no metamethod? */
return 0; /* nothing to call */
setobj2s(L, top, tm); /* will call metamethod... */ setobj2s(L, top, tm); /* will call metamethod... */
setobj2s(L, top + 1, obj); /* with 'self' as the 1st argument */ setobj2s(L, top + 1, obj); /* with 'self' as the 1st argument */
setobj2s(L, top + 2, err); /* and error msg. as 2nd argument */ setobj2s(L, top + 2, err); /* and error msg. as 2nd argument */
L->top = top + 3; /* add function and arguments */ L->top = top + 3; /* add function and arguments */
return 1; luaD_callnoyield(L, top, 0); /* call method */
} }
/* /*
** Raise an error with message 'msg', inserting the name of the ** Check whether 'obj' has a close metamethod and raise an error
** local variable at position 'level' in the stack. ** if not.
*/ */
static void varerror (lua_State *L, StkId level, const char *msg) { static void checkclosemth (lua_State *L, StkId level, const TValue *obj) {
int idx = cast_int(level - L->ci->func); const TValue *tm = luaT_gettmbyobj(L, obj, TM_CLOSE);
if (ttisnil(tm)) { /* no metamethod? */
int idx = cast_int(level - L->ci->func); /* variable index */
const char *vname = luaG_findlocal(L, L->ci, idx, NULL); const char *vname = luaG_findlocal(L, L->ci, idx, NULL);
if (vname == NULL) vname = "?"; if (vname == NULL) vname = "?";
luaG_runerror(L, msg, vname); luaG_runerror(L, "variable '%s' got a non-closable value", vname);
}
} }
@@ -136,7 +137,7 @@ static void varerror (lua_State *L, StkId level, const char *msg) {
** the 'level' of the upvalue being closed, as everything after that ** the 'level' of the upvalue being closed, as everything after that
** won't be used again. ** won't be used again.
*/ */
static void callclosemth (lua_State *L, StkId level, int status) { static void prepcallclosemth (lua_State *L, StkId level, int status) {
TValue *uv = s2v(level); /* value being closed */ TValue *uv = s2v(level); /* value being closed */
TValue *errobj; TValue *errobj;
if (status == CLOSEKTOP) if (status == CLOSEKTOP)
@@ -145,10 +146,7 @@ static void callclosemth (lua_State *L, StkId level, int status) {
errobj = s2v(level + 1); /* error object goes after 'uv' */ errobj = s2v(level + 1); /* error object goes after 'uv' */
luaD_seterrorobj(L, status, level + 1); /* set error object */ luaD_seterrorobj(L, status, level + 1); /* set error object */
} }
if (prepclosingmethod(L, uv, errobj)) /* something to call? */ callclosemethod(L, uv, errobj);
luaD_callnoyield(L, L->top - 3, 0); /* call method */
else if (!l_isfalse(uv)) /* non-closable non-false value? */
varerror(L, level, "attempt to close non-closable variable '%s'");
} }
@@ -171,16 +169,12 @@ void luaF_newtbcupval (lua_State *L, StkId level) {
lua_assert(L->openupval == NULL || uplevel(L->openupval) < level); lua_assert(L->openupval == NULL || uplevel(L->openupval) < level);
if (!l_isfalse(obj)) { /* false doesn't need to be closed */ if (!l_isfalse(obj)) { /* false doesn't need to be closed */
int status; int status;
const TValue *tm = luaT_gettmbyobj(L, obj, TM_CLOSE); checkclosemth(L, level, obj);
if (ttisnil(tm)) /* no metamethod? */
varerror(L, level, "variable '%s' got a non-closable value");
status = luaD_rawrunprotected(L, trynewtbcupval, level); status = luaD_rawrunprotected(L, trynewtbcupval, level);
if (unlikely(status != LUA_OK)) { /* memory error creating upvalue? */ if (unlikely(status != LUA_OK)) { /* memory error creating upvalue? */
lua_assert(status == LUA_ERRMEM); lua_assert(status == LUA_ERRMEM);
luaD_seterrorobj(L, LUA_ERRMEM, level + 1); /* save error message */ luaD_seterrorobj(L, LUA_ERRMEM, level + 1); /* save error message */
/* next call must succeed, as object is closable */ callclosemethod(L, s2v(level), s2v(level + 1));
prepclosingmethod(L, s2v(level), s2v(level + 1));
luaD_callnoyield(L, L->top - 3, 0); /* call method */
luaD_throw(L, LUA_ERRMEM); /* throw memory error */ luaD_throw(L, LUA_ERRMEM); /* throw memory error */
} }
} }
@@ -215,7 +209,7 @@ void luaF_close (lua_State *L, StkId level, int status) {
} }
if (uv->tbc && status != NOCLOSINGMETH) { if (uv->tbc && status != NOCLOSINGMETH) {
ptrdiff_t levelrel = savestack(L, level); ptrdiff_t levelrel = savestack(L, level);
callclosemth(L, upl, status); /* may change the stack */ prepcallclosemth(L, upl, status); /* may change the stack */
level = restorestack(L, levelrel); level = restorestack(L, levelrel);
} }
} }

View File

@@ -459,8 +459,50 @@ do -- errors due to non-closable values
getmetatable(xyz).__close = nil -- remove metamethod getmetatable(xyz).__close = nil -- remove metamethod
end end
local stat, msg = pcall(foo) local stat, msg = pcall(foo)
assert(not stat and assert(not stat and string.find(msg, "attempt to call a nil value"))
string.find(msg, "attempt to close non%-closable variable 'xyz'")) end
do -- tbc inside close methods
local track = {}
local function foo ()
local x <close> = func2close(function ()
local xx <close> = func2close(function (_, msg)
assert(msg == nil)
track[#track + 1] = "xx"
end)
track[#track + 1] = "x"
end)
track[#track + 1] = "foo"
return 20, 30, 40
end
local a, b, c, d = foo()
assert(a == 20 and b == 30 and c == 40 and d == nil)
assert(track[1] == "foo" and track[2] == "x" and track[3] == "xx")
-- again, with errors
local track = {}
local function foo ()
local x0 <close> = func2close(function (_, msg)
assert(msg == 202)
track[#track + 1] = "x0"
end)
local x <close> = func2close(function ()
local xx <close> = func2close(function (_, msg)
assert(msg == 101)
track[#track + 1] = "xx"
error(202)
end)
track[#track + 1] = "x"
error(101)
end)
track[#track + 1] = "foo"
return 20, 30, 40
end
local st, msg = pcall(foo)
assert(not st and msg == 202)
assert(track[1] == "foo" and track[2] == "x" and track[3] == "xx" and
track[4] == "x0")
end end