Handles '__close' errors in coroutines in "coroutine style"

Errors in '__close' metamethods in coroutines are handled by the same
logic that handles other errors, through 'recover'.
This commit is contained in:
Roberto Ierusalimschy
2020-12-30 11:20:22 -03:00
parent 553b37ce4f
commit ce101dcaf7
2 changed files with 85 additions and 22 deletions

66
ldo.c
View File

@@ -103,7 +103,7 @@ void luaD_seterrorobj (lua_State *L, int errcode, StkId oldtop) {
break; break;
} }
default: { default: {
lua_assert(errcode >= LUA_ERRRUN); /* real error */ lua_assert(errorstatus(errcode)); /* real error */
setobjs2s(L, oldtop, L->top - 1); /* error message on current top */ setobjs2s(L, oldtop, L->top - 1); /* error message on current top */
break; break;
} }
@@ -593,15 +593,11 @@ static void finishCcall (lua_State *L, int status) {
/* /*
** Executes "full continuation" (everything in the stack) of a ** Executes "full continuation" (everything in the stack) of a
** previously interrupted coroutine until the stack is empty (or another ** previously interrupted coroutine until the stack is empty (or another
** interruption long-jumps out of the loop). If the coroutine is ** interruption long-jumps out of the loop).
** recovering from an error, 'ud' points to the error status, which must
** be passed to the first continuation function (otherwise the default
** status is LUA_YIELD).
*/ */
static void unroll (lua_State *L, void *ud) { static void unroll (lua_State *L, void *ud) {
CallInfo *ci; CallInfo *ci;
if (ud != NULL) /* error status? */ UNUSED(ud);
finishCcall(L, *(int *)ud); /* finish 'lua_pcallk' callee */
while ((ci = L->ci) != &L->base_ci) { /* something in the stack */ while ((ci = L->ci) != &L->base_ci) { /* something in the stack */
if (!isLua(ci)) /* C function? */ if (!isLua(ci)) /* C function? */
finishCcall(L, LUA_YIELD); /* complete its execution */ finishCcall(L, LUA_YIELD); /* complete its execution */
@@ -628,21 +624,36 @@ static CallInfo *findpcall (lua_State *L) {
/* /*
** Recovers from an error in a coroutine. Finds a recover point (if ** Auxiliary structure to call 'recover' in protected mode.
** there is one) and completes the execution of the interrupted
** 'luaD_pcall'. If there is no recover point, returns zero.
*/ */
static int recover (lua_State *L, int status) { struct RecoverS {
CallInfo *ci = findpcall(L); int status;
if (ci == NULL) return 0; /* no recovery point */ CallInfo *ci;
};
/*
** Recovers from an error in a coroutine: completes the execution of the
** interrupted 'luaD_pcall', completes the interrupted C function which
** called 'lua_pcallk', and continues running the coroutine. If there is
** an error in 'luaF_close', this function will be called again and the
** coroutine will continue from where it left.
*/
static void recover (lua_State *L, void *ud) {
struct RecoverS *r = cast(struct RecoverS *, ud);
int status = r->status;
CallInfo *ci = r->ci; /* recover point */
StkId func = restorestack(L, ci->u2.funcidx);
/* "finish" luaD_pcall */ /* "finish" luaD_pcall */
L->ci = ci; L->ci = ci;
L->allowhook = getoah(ci->callstatus); /* restore original 'allowhook' */ L->allowhook = getoah(ci->callstatus); /* restore original 'allowhook' */
status = luaD_closeprotected(L, ci->u2.funcidx, status); luaF_close(L, func, status); /* may change the stack */
luaD_seterrorobj(L, status, restorestack(L, ci->u2.funcidx)); func = restorestack(L, ci->u2.funcidx);
luaD_seterrorobj(L, status, func);
luaD_shrinkstack(L); /* restore stack size in case of overflow */ luaD_shrinkstack(L); /* restore stack size in case of overflow */
L->errfunc = ci->u.c.old_errfunc; L->errfunc = ci->u.c.old_errfunc;
return 1; /* continue running the coroutine */ finishCcall(L, status); /* finish 'lua_pcallk' callee */
unroll(L, NULL); /* continue running the coroutine */
} }
@@ -692,6 +703,24 @@ static void resume (lua_State *L, void *ud) {
} }
} }
/*
** Calls 'recover' in protected mode, repeating while there are
** recoverable errors, that is, errors inside a protected call. (Any
** error interrupts 'recover', and this loop protects it again so it
** can continue.) Stops with a normal end (status == LUA_OK), an yield
** (status == LUA_YIELD), or an unprotected error ('findpcall' doesn't
** find a recover point).
*/
static int p_recover (lua_State *L, int status) {
struct RecoverS r;
r.status = status;
while (errorstatus(status) && (r.ci = findpcall(L)) != NULL)
r.status = luaD_rawrunprotected(L, recover, &r);
return r.status;
}
LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs, LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs,
int *nresults) { int *nresults) {
int status; int status;
@@ -709,10 +738,7 @@ LUA_API int lua_resume (lua_State *L, lua_State *from, int nargs,
api_checknelems(L, (L->status == LUA_OK) ? nargs + 1 : nargs); api_checknelems(L, (L->status == LUA_OK) ? nargs + 1 : nargs);
status = luaD_rawrunprotected(L, resume, &nargs); status = luaD_rawrunprotected(L, resume, &nargs);
/* continue running after recoverable errors */ /* continue running after recoverable errors */
while (errorstatus(status) && recover(L, status)) { status = p_recover(L, status);
/* unroll continuation */
status = luaD_rawrunprotected(L, unroll, &status);
}
if (likely(!errorstatus(status))) if (likely(!errorstatus(status)))
lua_assert(status == L->status); /* normal end or yield */ lua_assert(status == L->status); /* normal end or yield */
else { /* unrecoverable error */ else { /* unrecoverable error */

View File

@@ -123,7 +123,7 @@ assert(#a == 22 and a[#a] == 79)
x, a = nil x, a = nil
-- coroutine closing print("to-be-closed variables in coroutines")
local function func2close (f) local function func2close (f)
return setmetatable({}, {__close = f}) return setmetatable({}, {__close = f})
@@ -189,7 +189,6 @@ do
local st, msg = coroutine.close(co) local st, msg = coroutine.close(co)
assert(st == false and coroutine.status(co) == "dead" and msg == 200) assert(st == false and coroutine.status(co) == "dead" and msg == 200)
assert(x == 200) assert(x == 200)
end end
do do
@@ -207,6 +206,44 @@ do
local st1, st2, err = coroutine.resume(co) local st1, st2, err = coroutine.resume(co)
assert(st1 and not st2 and err == 43) assert(st1 and not st2 and err == 43)
assert(X == 43 and Y.name == "pcall") assert(X == 43 and Y.name == "pcall")
-- recovering from errors in __close metamethods
local track = {}
local function h (o)
local hv <close> = o
return 1
end
local function foo ()
local x <close> = func2close(function(_,msg)
track[#track + 1] = msg or false
error(20)
end)
local y <close> = func2close(function(_,msg)
track[#track + 1] = msg or false
return 1000
end)
local z <close> = func2close(function(_,msg)
track[#track + 1] = msg or false
error(10)
end)
coroutine.yield(1)
h(func2close(function(_,msg)
track[#track + 1] = msg or false
error(2)
end))
end
local co = coroutine.create(pcall)
local st, res = coroutine.resume(co, foo) -- call 'foo' protected
assert(st and res == 1) -- yield 1
local st, res1, res2 = coroutine.resume(co) -- continue
assert(coroutine.status(co) == "dead")
assert(st and not res1 and res2 == 20) -- last error (20)
assert(track[1] == false and track[2] == 2 and track[3] == 10 and
track[4] == 10)
end end