1
1
Fork 0
mirror of https://github.com/NixOS/nix.git synced 2025-11-09 03:56:01 +01:00

libexpr: store ExprLambda data in Expr::alloc

This commit is contained in:
Taeer Bar-Yam 2025-10-27 21:29:17 +01:00
parent 4a2fb18ba0
commit 3a3c062982
10 changed files with 88 additions and 61 deletions

View file

@ -112,7 +112,7 @@ TEST_F(ValuePrintingTests, vLambda)
auto body = ExprInt(0); auto body = ExprInt(0);
auto formals = Formals{}; auto formals = Formals{};
ExprLambda eLambda(posIdx, createSymbol("a"), &formals, &body); ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
Value vLambda; Value vLambda;
vLambda.mkLambda(&env, &eLambda); vLambda.mkLambda(&env, &eLambda);
@ -502,7 +502,7 @@ TEST_F(ValuePrintingTests, ansiColorsLambda)
auto body = ExprInt(0); auto body = ExprInt(0);
auto formals = Formals{}; auto formals = Formals{};
ExprLambda eLambda(posIdx, createSymbol("a"), &formals, &body); ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
Value vLambda; Value vLambda;
vLambda.mkLambda(&env, &eLambda); vLambda.mkLambda(&env, &eLambda);

View file

@ -1496,7 +1496,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
ExprLambda & lambda(*vCur.lambda().fun); ExprLambda & lambda(*vCur.lambda().fun);
auto size = (!lambda.arg ? 0 : 1) + (lambda.hasFormals() ? lambda.formals->formals.size() : 0); auto size = (!lambda.arg ? 0 : 1) + lambda.nFormals;
Env & env2(mem.allocEnv(size)); Env & env2(mem.allocEnv(size));
env2.up = vCur.lambda().env; env2.up = vCur.lambda().env;
@ -1520,7 +1520,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
there is no matching actual argument but the formal there is no matching actual argument but the formal
argument has a default, use the default. */ argument has a default, use the default. */
size_t attrsUsed = 0; size_t attrsUsed = 0;
for (auto & i : lambda.formals->formals) { for (auto & i : lambda.getFormals()) {
auto j = args[0]->attrs()->get(i.name); auto j = args[0]->attrs()->get(i.name);
if (!j) { if (!j) {
if (!i.def) { if (!i.def) {
@ -1542,13 +1542,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
/* Check that each actual argument is listed as a formal /* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */ argument (unless the attribute match specifies a `...'). */
if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs()->size()) { if (!lambda.ellipsis && attrsUsed != args[0]->attrs()->size()) {
/* Nope, so show the first unexpected argument to the /* Nope, so show the first unexpected argument to the
user. */ user. */
for (auto & i : *args[0]->attrs()) for (auto & i : *args[0]->attrs())
if (!lambda.formals->has(i.name)) { if (!lambda.hasFormal(i.name)) {
StringSet formalNames; StringSet formalNames;
for (auto & formal : lambda.formals->formals) for (auto & formal : lambda.getFormals())
formalNames.insert(std::string(symbols[formal.name])); formalNames.insert(std::string(symbols[formal.name]));
auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]); auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]);
error<TypeError>( error<TypeError>(
@ -1752,9 +1752,9 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
return; return;
} }
auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->formals->formals.size()), args.size())); auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->nFormals), args.size()));
if (fun.lambda().fun->formals->ellipsis) { if (fun.lambda().fun->ellipsis) {
// If the formals have an ellipsis (eg the function accepts extra args) pass // If the formals have an ellipsis (eg the function accepts extra args) pass
// all available automatic arguments (which includes arguments specified on // all available automatic arguments (which includes arguments specified on
// the command line via --arg/--argstr) // the command line via --arg/--argstr)
@ -1762,7 +1762,7 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
attrs.insert(v); attrs.insert(v);
} else { } else {
// Otherwise, only pass the arguments that the function accepts // Otherwise, only pass the arguments that the function accepts
for (auto & i : fun.lambda().fun->formals->formals) { for (auto & i : fun.lambda().fun->getFormals()) {
auto j = args.get(i.name); auto j = args.get(i.name);
if (j) { if (j) {
attrs.insert(*j); attrs.insert(*j);

View file

@ -481,16 +481,6 @@ struct Formals
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; }); formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg; return it != formals.end() && it->name == arg;
} }
std::vector<Formal> lexicographicOrder(const SymbolTable & symbols) const
{
std::vector<Formal> result(formals.begin(), formals.end());
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
std::string_view sa = symbols[a.name], sb = symbols[b.name];
return sa < sb;
});
return result;
}
}; };
struct ExprLambda : Expr struct ExprLambda : Expr
@ -498,21 +488,42 @@ struct ExprLambda : Expr
PosIdx pos; PosIdx pos;
Symbol name; Symbol name;
Symbol arg; Symbol arg;
Formals * formals;
bool ellipsis;
uint16_t nFormals;
Formal * formalsStart;
Expr * body; Expr * body;
DocComment docComment; DocComment docComment;
ExprLambda(PosIdx pos, Symbol arg, Formals * formals, Expr * body) ExprLambda(
std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, const Formals & formals, Expr * body)
: pos(pos) : pos(pos)
, arg(arg) , arg(arg)
, formals(formals) , ellipsis(formals.ellipsis)
, body(body) {}; , nFormals(formals.formals.size())
, formalsStart(alloc.allocate_object<Formal>(nFormals))
ExprLambda(PosIdx pos, Formals * formals, Expr * body)
: pos(pos)
, formals(formals)
, body(body) , body(body)
{ {
std::ranges::copy(formals.formals, formalsStart);
};
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, Expr * body)
: pos(pos)
, arg(arg)
, nFormals(0)
, formalsStart(nullptr)
, body(body) {};
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Formals formals, Expr * body)
: ExprLambda(alloc, pos, Symbol(), formals, body) {};
bool hasFormal(Symbol arg) const
{
auto formals = getFormals();
auto it = std::lower_bound(
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg;
} }
void setName(Symbol name) override; void setName(Symbol name) override;
@ -520,7 +531,17 @@ struct ExprLambda : Expr
inline bool hasFormals() const inline bool hasFormals() const
{ {
return formals != nullptr; return nFormals > 0;
}
std::vector<Formal> getFormalsLexicographic(const SymbolTable & symbols) const
{
std::vector<Formal> result(getFormals().begin(), getFormals().end());
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
std::string_view sa = symbols[a.name], sb = symbols[b.name];
return sa < sb;
});
return result;
} }
PosIdx getPos() const override PosIdx getPos() const override
@ -528,6 +549,11 @@ struct ExprLambda : Expr
return pos; return pos;
} }
std::span<Formal> getFormals() const
{
return {formalsStart, nFormals};
}
virtual void setDocComment(DocComment docComment) override; virtual void setDocComment(DocComment docComment) override;
COMMON_METHODS COMMON_METHODS
}; };

View file

@ -93,7 +93,7 @@ struct ParserState
void addAttr( void addAttr(
ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc); ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc);
void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def); void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def);
Formals * validateFormals(Formals * formals, PosIdx pos = noPos, Symbol arg = {}); void validateFormals(Formals & formals, PosIdx pos = noPos, Symbol arg = {});
Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es); Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es);
PosIdx at(const ParserLocation & loc); PosIdx at(const ParserLocation & loc);
}; };
@ -213,17 +213,17 @@ ParserState::addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symb
} }
} }
inline Formals * ParserState::validateFormals(Formals * formals, PosIdx pos, Symbol arg) inline void ParserState::validateFormals(Formals & formals, PosIdx pos, Symbol arg)
{ {
std::sort(formals->formals.begin(), formals->formals.end(), [](const auto & a, const auto & b) { std::sort(formals.formals.begin(), formals.formals.end(), [](const auto & a, const auto & b) {
return std::tie(a.name, a.pos) < std::tie(b.name, b.pos); return std::tie(a.name, a.pos) < std::tie(b.name, b.pos);
}); });
std::optional<std::pair<Symbol, PosIdx>> duplicate; std::optional<std::pair<Symbol, PosIdx>> duplicate;
for (size_t i = 0; i + 1 < formals->formals.size(); i++) { for (size_t i = 0; i + 1 < formals.formals.size(); i++) {
if (formals->formals[i].name != formals->formals[i + 1].name) if (formals.formals[i].name != formals.formals[i + 1].name)
continue; continue;
std::pair thisDup{formals->formals[i].name, formals->formals[i + 1].pos}; std::pair thisDup{formals.formals[i].name, formals.formals[i + 1].pos};
duplicate = std::min(thisDup, duplicate.value_or(thisDup)); duplicate = std::min(thisDup, duplicate.value_or(thisDup));
} }
if (duplicate) if (duplicate)
@ -231,11 +231,9 @@ inline Formals * ParserState::validateFormals(Formals * formals, PosIdx pos, Sym
{.msg = HintFmt("duplicate formal function argument '%1%'", symbols[duplicate->first]), {.msg = HintFmt("duplicate formal function argument '%1%'", symbols[duplicate->first]),
.pos = positions[duplicate->second]}); .pos = positions[duplicate->second]});
if (arg && formals->has(arg)) if (arg && formals.has(arg))
throw ParseError( throw ParseError(
{.msg = HintFmt("duplicate formal function argument '%1%'", symbols[arg]), .pos = positions[pos]}); {.msg = HintFmt("duplicate formal function argument '%1%'", symbols[arg]), .pos = positions[pos]});
return formals;
} }
inline Expr * inline Expr *

View file

@ -160,7 +160,7 @@ void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
// the natural Symbol ordering is by creation time, which can lead to the // the natural Symbol ordering is by creation time, which can lead to the
// same expression being printed in two different ways depending on its // same expression being printed in two different ways depending on its
// context. always use lexicographic ordering to avoid this. // context. always use lexicographic ordering to avoid this.
for (auto & i : formals->lexicographicOrder(symbols)) { for (auto & i : getFormalsLexicographic(symbols)) {
if (first) if (first)
first = false; first = false;
else else
@ -171,7 +171,7 @@ void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
i.def->show(symbols, str); i.def->show(symbols, str);
} }
} }
if (formals->ellipsis) { if (ellipsis) {
if (!first) if (!first)
str << ", "; str << ", ";
str << "..."; str << "...";
@ -451,8 +451,7 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
if (es.debugRepl) if (es.debugRepl)
es.exprEnvs.insert(std::make_pair(this, env)); es.exprEnvs.insert(std::make_pair(this, env));
auto newEnv = auto newEnv = std::make_shared<StaticEnv>(nullptr, env, nFormals + (!arg ? 0 : 1));
std::make_shared<StaticEnv>(nullptr, env, (hasFormals() ? formals->formals.size() : 0) + (!arg ? 0 : 1));
Displacement displ = 0; Displacement displ = 0;
@ -460,12 +459,12 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
newEnv->vars.emplace_back(arg, displ++); newEnv->vars.emplace_back(arg, displ++);
if (hasFormals()) { if (hasFormals()) {
for (auto & i : formals->formals) for (auto & i : getFormals())
newEnv->vars.emplace_back(i.name, displ++); newEnv->vars.emplace_back(i.name, displ++);
newEnv->sort(); newEnv->sort();
for (auto & i : formals->formals) for (auto & i : getFormals())
if (i.def) if (i.def)
i.def->bindVars(es, newEnv); i.def->bindVars(es, newEnv);
} }

View file

@ -131,7 +131,7 @@ static Expr * makeCall(PosIdx pos, Expr * fn, Expr * arg) {
%type <nix::Expr *> expr_pipe_from expr_pipe_into %type <nix::Expr *> expr_pipe_from expr_pipe_into
%type <std::vector<Expr *>> list %type <std::vector<Expr *>> list
%type <nix::ExprAttrs *> binds binds1 %type <nix::ExprAttrs *> binds binds1
%type <nix::Formals *> formals formal_set %type <nix::Formals> formals formal_set
%type <nix::Formal> formal %type <nix::Formal> formal
%type <std::vector<nix::AttrName>> attrpath %type <std::vector<nix::AttrName>> attrpath
%type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs %type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs
@ -179,26 +179,30 @@ expr: expr_function;
expr_function expr_function
: ID ':' expr_function : ID ':' expr_function
{ auto me = new ExprLambda(CUR_POS, state->symbols.create($1), 0, $3); { auto me = new ExprLambda(state->alloc, CUR_POS, state->symbols.create($1), $3);
$$ = me; $$ = me;
SET_DOC_POS(me, @1); SET_DOC_POS(me, @1);
} }
| formal_set ':' expr_function[body] | formal_set ':' expr_function[body]
{ auto me = new ExprLambda(CUR_POS, state->validateFormals($formal_set), $body); {
state->validateFormals($formal_set);
auto me = new ExprLambda(state->alloc, CUR_POS, std::move($formal_set), $body);
$$ = me; $$ = me;
SET_DOC_POS(me, @1); SET_DOC_POS(me, @1);
} }
| formal_set '@' ID ':' expr_function[body] | formal_set '@' ID ':' expr_function[body]
{ {
auto arg = state->symbols.create($ID); auto arg = state->symbols.create($ID);
auto me = new ExprLambda(CUR_POS, arg, state->validateFormals($formal_set, CUR_POS, arg), $body); state->validateFormals($formal_set, CUR_POS, arg);
auto me = new ExprLambda(state->alloc, CUR_POS, arg, std::move($formal_set), $body);
$$ = me; $$ = me;
SET_DOC_POS(me, @1); SET_DOC_POS(me, @1);
} }
| ID '@' formal_set ':' expr_function[body] | ID '@' formal_set ':' expr_function[body]
{ {
auto arg = state->symbols.create($ID); auto arg = state->symbols.create($ID);
auto me = new ExprLambda(CUR_POS, arg, state->validateFormals($formal_set, CUR_POS, arg), $body); state->validateFormals($formal_set, CUR_POS, arg);
auto me = new ExprLambda(state->alloc, CUR_POS, arg, std::move($formal_set), $body);
$$ = me; $$ = me;
SET_DOC_POS(me, @1); SET_DOC_POS(me, @1);
} }
@ -490,18 +494,18 @@ list
; ;
formal_set formal_set
: '{' formals ',' ELLIPSIS '}' { $$ = $formals; $$->ellipsis = true; } : '{' formals ',' ELLIPSIS '}' { $$ = std::move($formals); $$.ellipsis = true; }
| '{' ELLIPSIS '}' { $$ = new Formals; $$->ellipsis = true; } | '{' ELLIPSIS '}' { $$.ellipsis = true; }
| '{' formals ',' '}' { $$ = $formals; $$->ellipsis = false; } | '{' formals ',' '}' { $$ = std::move($formals); $$.ellipsis = false; }
| '{' formals '}' { $$ = $formals; $$->ellipsis = false; } | '{' formals '}' { $$ = std::move($formals); $$.ellipsis = false; }
| '{' '}' { $$ = new Formals; $$->ellipsis = false; } | '{' '}' { $$.ellipsis = false; }
; ;
formals formals
: formals[accum] ',' formal : formals[accum] ',' formal
{ $$ = $accum; $$->formals.emplace_back(std::move($formal)); } { $$ = std::move($accum); $$.formals.emplace_back(std::move($formal)); }
| formal | formal
{ $$ = new Formals; $$->formals.emplace_back(std::move($formal)); } { $$.formals.emplace_back(std::move($formal)); }
; ;
formal formal

View file

@ -3368,7 +3368,7 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
return; return;
} }
const auto & formals = args[0]->lambda().fun->formals->formals; const auto & formals = args[0]->lambda().fun->getFormals();
auto attrs = state.buildBindings(formals.size()); auto attrs = state.buildBindings(formals.size());
for (auto & i : formals) for (auto & i : formals)
attrs.insert(i.name, state.getBool(i.def), i.pos); attrs.insert(i.name, state.getBool(i.def), i.pos);

View file

@ -149,10 +149,10 @@ static void printValueAsXML(
XMLAttrs attrs; XMLAttrs attrs;
if (v.lambda().fun->arg) if (v.lambda().fun->arg)
attrs["name"] = state.symbols[v.lambda().fun->arg]; attrs["name"] = state.symbols[v.lambda().fun->arg];
if (v.lambda().fun->formals->ellipsis) if (v.lambda().fun->ellipsis)
attrs["ellipsis"] = "1"; attrs["ellipsis"] = "1";
XMLOpenElement _(doc, "attrspat", attrs); XMLOpenElement _(doc, "attrspat", attrs);
for (auto & i : v.lambda().fun->formals->lexicographicOrder(state.symbols)) for (auto & i : v.lambda().fun->getFormalsLexicographic(state.symbols))
doc.writeEmptyElement("attr", singletonAttrs("name", state.symbols[i.name])); doc.writeEmptyElement("attr", singletonAttrs("name", state.symbols[i.name]));
} else } else
doc.writeEmptyElement("varpat", singletonAttrs("name", state.symbols[v.lambda().fun->arg])); doc.writeEmptyElement("varpat", singletonAttrs("name", state.symbols[v.lambda().fun->arg]));

View file

@ -282,7 +282,7 @@ static Flake readFlake(
expectType(state, nFunction, *outputs->value, outputs->pos); expectType(state, nFunction, *outputs->value, outputs->pos);
if (outputs->value->isLambda() && outputs->value->lambda().fun->hasFormals()) { if (outputs->value->isLambda() && outputs->value->lambda().fun->hasFormals()) {
for (auto & formal : outputs->value->lambda().fun->formals->formals) { for (auto & formal : outputs->value->lambda().fun->getFormals()) {
if (formal.name != state.s.self) if (formal.name != state.s.self)
flake.inputs.emplace( flake.inputs.emplace(
state.symbols[formal.name], state.symbols[formal.name],

View file

@ -416,7 +416,7 @@ static void main_nix_build(int argc, char ** argv)
} }
bool add = false; bool add = false;
if (v.type() == nFunction && v.lambda().fun->hasFormals()) { if (v.type() == nFunction && v.lambda().fun->hasFormals()) {
for (auto & i : v.lambda().fun->formals->formals) { for (auto & i : v.lambda().fun->getFormals()) {
if (state->symbols[i.name] == "inNixShell") { if (state->symbols[i.name] == "inNixShell") {
add = true; add = true;
break; break;