1
1
Fork 0
mirror of https://github.com/NixOS/nix.git synced 2025-11-25 11:49:35 +01:00

* Short-circuiting of function call evaluation.

With maximal laziness, you would expect that a function like this

    fib = n:
      if n == 0 then 0 else
      if n == 1 then 1 else
      builtins.add (fib (builtins.sub n 1)) (fib (builtins.sub n 2));

  can be evaluated efficiently, because maximal laziness should
  implictly memoize the recursive calls to "fib".  However, non-strictness
  interferes with this: the argument "n" is generally not in a form
  that allows the memoization to work (e.g., it will be something like
  (20 - 1 - 2 - 2) rather than 15).  By the time that "n" is
  evaluated (in "if n == 0 ..."), we're already deep in the evaluation
  of the call.

  (Strictness solves this:

      builtins.add (strict fib (builtins.sub n 1)) (strict fib (builtins.sub n 2));

  but that's not a very nice approach.)

  With short-circuiting, the evaluator will check after evaluating a
  term, whether that term is the argument of a function call that
  we're currently evaluating.  If so, it will check to see if the same
  call but with the evaluated argument is in the normal form cache.

  For instance, after evaluating (20 - 1 - 2 - 2) to 15, if we see
  that "fib (20 - 1 - 2 - 2)" is currently being evaluated, we check
  to see if "fib 15" is in the normal form cache.  If so, we unwind
  the stack (by throwing an exception) up to the evalExpr call
  responsible for "fib (20 - 1 - 2 - 2)", which can then immediately
  return the normal form for "fib 15".  And indeed this makes "fib"
  run in O(n) time.

  The overhead for checking the active function calls (which isn't
  very smart yet) seems to be modest, about 2% for "nix-env -qa
  --drv-path --out-path" on Nixpkgs.
This commit is contained in:
Eelco Dolstra 2007-10-12 17:53:47 +00:00
parent 74ce938e18
commit c3a79daaf3

View file

@ -18,6 +18,13 @@ namespace nix {
int cacheTerms;
bool shortCircuit;
#define maxActiveCalls 4096
ATerm activeCalls[maxActiveCalls];
unsigned int activeCallsCount = 0;
EvalState::EvalState()
@ -30,7 +37,10 @@ EvalState::EvalState()
addPrimOps();
if (!string2Int(getEnv("NIX_TERM_CACHE"), cacheTerms)) cacheTerms = 1;
shortCircuit = getEnv("NIX_SHORT_CIRCUIT", "0") == "1";
strictMode = getEnv("NIX_STRICT", "0") == "1";
ATprotectMemory(activeCalls, maxActiveCalls);
}
@ -746,6 +756,33 @@ Expr evalExpr2(EvalState & state, Expr e)
}
class ShortCircuit
{
};
unsigned int fnord;
void maybeShortCircuit(EvalState & state, Expr e, Expr nf)
{
for (unsigned int i = 0; i < activeCallsCount; ++i) {
Expr fun, arg;
if (!matchCall(activeCalls[i], fun, arg)) abort();
if (arg == e) {
//printMsg(lvlError, format("blaat"));
//printMsg(lvlError, format("blaat %1% %2% %3%") % fun % arg % e);
Expr res = state.normalForms.get(makeCall(fun, nf));
if (res) {
fnord++;
//printMsg(lvlError, format("blaat"));
throw ShortCircuit();
}
}
}
}
Expr evalExpr(EvalState & state, Expr e)
{
checkInterrupt();
@ -769,22 +806,86 @@ Expr evalExpr(EvalState & state, Expr e)
previously evaluated expressions. */
Expr nf = state.normalForms.get(e);
if (nf) {
if (nf == makeBlackHole())
throwEvalError("infinite recursion encountered");
//if (nf == makeBlackHole())
// throwEvalError("infinite recursion encountered");
state.nrCached++;
return nf;
}
/* Otherwise, evaluate and memoize. */
state.normalForms.set(e, makeBlackHole());
try {
nf = evalExpr2(state, e);
} catch (Error & err) {
state.normalForms.remove(e);
throw;
Expr fun, arg;
if (shortCircuit && matchCall(e, fun, arg)) {
#if 0
Expr arg2 = state.normalForms.get(arg);
if (arg2) { /* the evaluated argument is now known */
//printMsg(lvlError, "foo");
/* do we know the result of the same function called
with the evaluated argument? */
Expr res = state.normalForms.get(makeCall(fun, arg2));
if (res) { /* woohoo! */
printMsg(lvlError, "dingdong");
state.normalForms.set(e, res);
return res;
}
}
#endif
assert(activeCallsCount < maxActiveCalls);
activeCalls[activeCallsCount++] = e;
//state.normalForms.set(e, makeBlackHole());
try {
nf = evalExpr2(state, e);
}
catch (ShortCircuit & exception) {
//printMsg(lvlError, "catch!");
Expr arg2 = state.normalForms.get(arg);
if (arg2) { /* the evaluated argument is now known */
/* do we know the result of the same function called
with the evaluated argument? */
Expr res = state.normalForms.get(makeCall(fun, arg2));
if (res) { /* woohoo! */
//printMsg(lvlError, "woohoo!");
//printMsg(lvlError, format("woohoo! %1% %2% %3% %4%") % fun % arg % arg2 % res);
activeCallsCount--;
state.normalForms.set(e, res);
maybeShortCircuit(state, e, res);
return res;
}
}
activeCallsCount--;
state.normalForms.remove(e);
throw; /* not for us */
}
catch (...) {
activeCallsCount--;
state.normalForms.remove(e);
throw;
}
activeCallsCount--;
state.normalForms.set(e, nf);
Expr arg2 = state.normalForms.get(arg);
if (arg2) state.normalForms.set(makeCall(fun, arg2), nf);
maybeShortCircuit(state, e, nf);
return nf;
}
else {
/* Otherwise, evaluate and memoize. */
//state.normalForms.set(e, makeBlackHole());
try {
nf = evalExpr2(state, e);
} catch (...) {
state.normalForms.remove(e);
throw;
}
state.normalForms.set(e, nf);
maybeShortCircuit(state, e, nf);
return nf;
}
state.normalForms.set(e, nf);
return nf;
}
@ -884,6 +985,7 @@ void printEvalStats(EvalState & state)
{
char x;
bool showStats = getEnv("NIX_SHOW_STATS", "0") != "0";
printMsg(lvlError, format("FNORD %1%") % fnord);
printMsg(showStats ? lvlInfo : lvlDebug,
format("evaluated %1% expressions, %2% cache hits, %3%%% efficiency, used %4% ATerm bytes, used %5% bytes of stack space")
% state.nrEvaluated % state.nrCached