1
1
Fork 0
mirror of https://github.com/NixOS/nix.git synced 2025-11-08 19:46:02 +01:00

safer interface for ExprLambda's formals

This commit is contained in:
Taeer Bar-Yam 2025-10-31 16:21:50 +01:00
parent e43888890f
commit 34f780d747
12 changed files with 105 additions and 82 deletions

View file

@ -771,7 +771,7 @@ TEST_F(PrimOpTest, derivation)
ASSERT_EQ(v.type(), nFunction);
ASSERT_TRUE(v.isLambda());
ASSERT_NE(v.lambda().fun, nullptr);
ASSERT_TRUE(v.lambda().fun->hasFormals);
ASSERT_TRUE(v.lambda().fun->getFormals());
}
TEST_F(PrimOpTest, currentTime)

View file

@ -110,9 +110,8 @@ TEST_F(ValuePrintingTests, vLambda)
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
auto posIdx = state.positions.add(origin, 0);
auto body = ExprInt(0);
auto formals = Formals{};
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
Value vLambda;
vLambda.mkLambda(&env, &eLambda);
@ -500,9 +499,8 @@ TEST_F(ValuePrintingTests, ansiColorsLambda)
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
auto posIdx = state.positions.add(origin, 0);
auto body = ExprInt(0);
auto formals = Formals{};
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), formals, &body);
ExprLambda eLambda(state.mem.exprs.alloc, posIdx, createSymbol("a"), &body);
Value vLambda;
vLambda.mkLambda(&env, &eLambda);

View file

@ -1496,15 +1496,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
ExprLambda & lambda(*vCur.lambda().fun);
auto size = (!lambda.arg ? 0 : 1) + (lambda.hasFormals ? lambda.getFormals().size() : 0);
auto size = (!lambda.arg ? 0 : 1) + (lambda.getFormals() ? lambda.getFormals()->formals.size() : 0);
Env & env2(mem.allocEnv(size));
env2.up = vCur.lambda().env;
Displacement displ = 0;
if (!lambda.hasFormals)
env2.values[displ++] = args[0];
else {
if (auto formals = lambda.getFormals()) {
try {
forceAttrs(*args[0], lambda.pos, "while evaluating the value passed for the lambda argument");
} catch (Error & e) {
@ -1520,7 +1518,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
there is no matching actual argument but the formal
argument has a default, use the default. */
size_t attrsUsed = 0;
for (auto & i : lambda.getFormals()) {
for (auto & i : formals->formals) {
auto j = args[0]->attrs()->get(i.name);
if (!j) {
if (!i.def) {
@ -1542,13 +1540,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
/* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */
if (!lambda.ellipsis && attrsUsed != args[0]->attrs()->size()) {
if (!formals->ellipsis && attrsUsed != args[0]->attrs()->size()) {
/* Nope, so show the first unexpected argument to the
user. */
for (auto & i : *args[0]->attrs())
if (!lambda.hasFormal(i.name)) {
if (!formals->has(i.name)) {
StringSet formalNames;
for (auto & formal : lambda.getFormals())
for (auto & formal : formals->formals)
formalNames.insert(std::string(symbols[formal.name]));
auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]);
error<TypeError>(
@ -1563,6 +1561,8 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
}
unreachable();
}
} else {
env2.values[displ++] = args[0];
}
nrFunctionCalls++;
@ -1747,14 +1747,15 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
}
}
if (!fun.isLambda() || !fun.lambda().fun->hasFormals) {
if (!fun.isLambda() || !fun.lambda().fun->getFormals()) {
res = fun;
return;
}
auto formals = fun.lambda().fun->getFormals();
auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->nFormals), args.size()));
auto attrs = buildBindings(std::max(static_cast<uint32_t>(formals->formals.size()), args.size()));
if (fun.lambda().fun->ellipsis) {
if (formals->ellipsis) {
// If the formals have an ellipsis (eg the function accepts extra args) pass
// all available automatic arguments (which includes arguments specified on
// the command line via --arg/--argstr)
@ -1762,7 +1763,7 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
attrs.insert(v);
} else {
// Otherwise, only pass the arguments that the function accepts
for (auto & i : fun.lambda().fun->getFormals()) {
for (auto & i : formals->formals) {
auto j = args.get(i.name);
if (j) {
attrs.insert(*j);

View file

@ -466,7 +466,7 @@ struct Formal
Expr * def;
};
struct Formals
struct FormalsBuilder
{
typedef std::vector<Formal> Formals_;
/**
@ -483,26 +483,67 @@ struct Formals
}
};
struct Formals
{
std::span<Formal> formals;
bool ellipsis;
Formals(std::span<Formal> formals, bool ellipsis)
: formals(formals)
, ellipsis(ellipsis) {};
bool has(Symbol arg) const
{
auto it = std::lower_bound(
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg;
}
std::vector<Formal> lexicographicOrder(const SymbolTable & symbols) const
{
std::vector<Formal> result(formals.begin(), formals.end());
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
std::string_view sa = symbols[a.name], sb = symbols[b.name];
return sa < sb;
});
return result;
}
};
struct ExprLambda : Expr
{
PosIdx pos;
Symbol name;
Symbol arg;
bool ellipsis;
private:
bool hasFormals;
bool ellipsis;
uint16_t nFormals;
Formal * formalsStart;
public:
std::optional<Formals> getFormals() const
{
if (hasFormals)
return Formals{{formalsStart, nFormals}, ellipsis};
else
return std::nullopt;
}
Expr * body;
DocComment docComment;
ExprLambda(
std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Symbol arg, const Formals & formals, Expr * body)
std::pmr::polymorphic_allocator<char> & alloc,
PosIdx pos,
Symbol arg,
const FormalsBuilder & formals,
Expr * body)
: pos(pos)
, arg(arg)
, ellipsis(formals.ellipsis)
, hasFormals(true)
, ellipsis(formals.ellipsis)
, nFormals(formals.formals.size())
, formalsStart(alloc.allocate_object<Formal>(nFormals))
, body(body)
@ -514,44 +555,22 @@ struct ExprLambda : Expr
: pos(pos)
, arg(arg)
, hasFormals(false)
, ellipsis(false)
, nFormals(0)
, formalsStart(nullptr)
, body(body) {};
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, Formals formals, Expr * body)
ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, FormalsBuilder formals, Expr * body)
: ExprLambda(alloc, pos, Symbol(), formals, body) {};
bool hasFormal(Symbol arg) const
{
auto formals = getFormals();
auto it = std::lower_bound(
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg;
}
void setName(Symbol name) override;
std::string showNamePos(const EvalState & state) const;
std::vector<Formal> getFormalsLexicographic(const SymbolTable & symbols) const
{
std::vector<Formal> result(getFormals().begin(), getFormals().end());
std::sort(result.begin(), result.end(), [&](const Formal & a, const Formal & b) {
std::string_view sa = symbols[a.name], sb = symbols[b.name];
return sa < sb;
});
return result;
}
PosIdx getPos() const override
{
return pos;
}
std::span<Formal> getFormals() const
{
assert(hasFormals);
return {formalsStart, nFormals};
}
virtual void setDocComment(DocComment docComment) override;
COMMON_METHODS
};

View file

@ -93,7 +93,7 @@ struct ParserState
void addAttr(
ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc);
void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def);
void validateFormals(Formals & formals, PosIdx pos = noPos, Symbol arg = {});
void validateFormals(FormalsBuilder & formals, PosIdx pos = noPos, Symbol arg = {});
Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es);
PosIdx at(const ParserLocation & loc);
};
@ -213,7 +213,7 @@ ParserState::addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symb
}
}
inline void ParserState::validateFormals(Formals & formals, PosIdx pos, Symbol arg)
inline void ParserState::validateFormals(FormalsBuilder & formals, PosIdx pos, Symbol arg)
{
std::sort(formals.formals.begin(), formals.formals.end(), [](const auto & a, const auto & b) {
return std::tie(a.name, a.pos) < std::tie(b.name, b.pos);

View file

@ -154,13 +154,13 @@ void ExprList::show(const SymbolTable & symbols, std::ostream & str) const
void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
{
str << "(";
if (hasFormals) {
if (auto formals = getFormals()) {
str << "{ ";
bool first = true;
// the natural Symbol ordering is by creation time, which can lead to the
// same expression being printed in two different ways depending on its
// context. always use lexicographic ordering to avoid this.
for (auto & i : getFormalsLexicographic(symbols)) {
for (auto & i : formals->lexicographicOrder(symbols)) {
if (first)
first = false;
else
@ -451,20 +451,21 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
if (es.debugRepl)
es.exprEnvs.insert(std::make_pair(this, env));
auto newEnv = std::make_shared<StaticEnv>(nullptr, env, (hasFormals ? getFormals().size() : 0) + (!arg ? 0 : 1));
auto newEnv =
std::make_shared<StaticEnv>(nullptr, env, (getFormals() ? getFormals()->formals.size() : 0) + (!arg ? 0 : 1));
Displacement displ = 0;
if (arg)
newEnv->vars.emplace_back(arg, displ++);
if (hasFormals) {
for (auto & i : getFormals())
if (auto formals = getFormals()) {
for (auto & i : formals->formals)
newEnv->vars.emplace_back(i.name, displ++);
newEnv->sort();
for (auto & i : getFormals())
for (auto & i : formals->formals)
if (i.def)
i.def->bindVars(es, newEnv);
}

View file

@ -131,7 +131,7 @@ static Expr * makeCall(PosIdx pos, Expr * fn, Expr * arg) {
%type <nix::Expr *> expr_pipe_from expr_pipe_into
%type <std::vector<Expr *>> list
%type <nix::ExprAttrs *> binds binds1
%type <nix::Formals> formals formal_set
%type <nix::FormalsBuilder> formals formal_set
%type <nix::Formal> formal
%type <std::vector<nix::AttrName>> attrpath
%type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs

View file

@ -3363,14 +3363,9 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
if (!args[0]->isLambda())
state.error<TypeError>("'functionArgs' requires a function").atPos(pos).debugThrow();
if (!args[0]->lambda().fun->hasFormals) {
v.mkAttrs(&Bindings::emptyBindings);
return;
}
const auto & formals = args[0]->lambda().fun->getFormals();
auto attrs = state.buildBindings(formals.size());
for (auto & i : formals)
if (const auto & formals = args[0]->lambda().fun->getFormals()) {
auto attrs = state.buildBindings(formals->formals.size());
for (auto & i : formals->formals)
attrs.insert(i.name, state.getBool(i.def), i.pos);
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
@ -3378,6 +3373,10 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
.*/
v.mkAttrs(attrs.alreadySorted());
} else {
v.mkAttrs(&Bindings::emptyBindings);
return;
}
}
static RegisterPrimOp primop_functionArgs({

View file

@ -145,14 +145,14 @@ static void printValueAsXML(
posToXML(state, xmlAttrs, state.positions[v.lambda().fun->pos]);
XMLOpenElement _(doc, "function", xmlAttrs);
if (v.lambda().fun->hasFormals) {
if (auto formals = v.lambda().fun->getFormals()) {
XMLAttrs attrs;
if (v.lambda().fun->arg)
attrs["name"] = state.symbols[v.lambda().fun->arg];
if (v.lambda().fun->ellipsis)
if (formals->ellipsis)
attrs["ellipsis"] = "1";
XMLOpenElement _(doc, "attrspat", attrs);
for (auto & i : v.lambda().fun->getFormalsLexicographic(state.symbols))
for (auto & i : formals->lexicographicOrder(state.symbols))
doc.writeEmptyElement("attr", singletonAttrs("name", state.symbols[i.name]));
} else
doc.writeEmptyElement("varpat", singletonAttrs("name", state.symbols[v.lambda().fun->arg]));

View file

@ -281,12 +281,15 @@ static Flake readFlake(
if (auto outputs = vInfo.attrs()->get(sOutputs)) {
expectType(state, nFunction, *outputs->value, outputs->pos);
if (outputs->value->isLambda() && outputs->value->lambda().fun->hasFormals) {
for (auto & formal : outputs->value->lambda().fun->getFormals()) {
if (outputs->value->isLambda()) {
if (auto formals = outputs->value->lambda().fun->getFormals()) {
for (auto & formal : formals->formals) {
if (formal.name != state.s.self)
flake.inputs.emplace(
state.symbols[formal.name],
FlakeInput{.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
FlakeInput{
.ref = parseFlakeRef(state.fetchSettings, std::string(state.symbols[formal.name]))});
}
}
}

View file

@ -468,7 +468,7 @@ struct CmdFlakeCheck : FlakeCommand
if (!v.isLambda()) {
throw Error("overlay is not a function, but %s instead", showType(v));
}
if (v.lambda().fun->hasFormals || !argHasName(v.lambda().fun->arg, "final"))
if (v.lambda().fun->getFormals() || !argHasName(v.lambda().fun->arg, "final"))
throw Error("overlay does not take an argument named 'final'");
// FIXME: if we have a 'nixpkgs' input, use it to
// evaluate the overlay.

View file

@ -415,14 +415,16 @@ static void main_nix_build(int argc, char ** argv)
return false;
}
bool add = false;
if (v.type() == nFunction && v.lambda().fun->hasFormals) {
for (auto & i : v.lambda().fun->getFormals()) {
if (v.type() == nFunction) {
if (auto formals = v.lambda().fun->getFormals()) {
for (auto & i : formals->formals) {
if (state->symbols[i.name] == "inNixShell") {
add = true;
break;
}
}
}
}
return add;
};