Skip to content

Commit

Permalink
Limited support for constants in procs.
Browse files Browse the repository at this point in the history
Proc-scoped constants cannot reference template parameters yet.

PiperOrigin-RevId: 677938979
  • Loading branch information
erinzmoore authored and copybara-github committed Sep 23, 2024
1 parent 83ce024 commit 4bb43f7
Show file tree
Hide file tree
Showing 11 changed files with 539 additions and 41 deletions.
65 changes: 38 additions & 27 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,33 @@ DocRef Fmt(const ConstAssert& n, const Comments& comments, DocArena& arena) {
});
}

static DocRef Fmt(const ConstantDef& n, const Comments& comments,
DocArena& arena) {
std::vector<DocRef> leader_pieces;
if (n.is_public()) {
leader_pieces.push_back(arena.Make(Keyword::kPub));
leader_pieces.push_back(arena.break1());
}
leader_pieces.push_back(arena.Make(Keyword::kConst));
leader_pieces.push_back(arena.break1());
leader_pieces.push_back(arena.MakeText(n.identifier()));
if (n.type_annotation() != nullptr) {
leader_pieces.push_back(arena.colon());
leader_pieces.push_back(arena.space());
leader_pieces.push_back(Fmt(*n.type_annotation(), comments, arena));
}
leader_pieces.push_back(arena.break1());
leader_pieces.push_back(arena.equals());
leader_pieces.push_back(arena.space());

DocRef lhs = ConcatNGroup(arena, leader_pieces);
DocRef rhs = ConcatNGroup(arena, {
Fmt(*n.value(), comments, arena),
arena.semi(),
});
return arena.MakeConcat(lhs, rhs);
}

DocRef Fmt(const TupleIndex& n, const Comments& comments, DocArena& arena) {
std::vector<DocRef> pieces;
if (WeakerThan(n.lhs()->GetPrecedence(), n.GetPrecedence())) {
Expand Down Expand Up @@ -1823,6 +1850,17 @@ static DocRef Fmt(const Proc& n, const Comments& comments, DocArena& arena) {
stmt_pieces.push_back(arena.hard_line());
last_stmt_limit = n->span().limit();
},
[&](const ConstantDef* n) {
if (std::optional<DocRef> maybe_doc =
EmitCommentsBetween(last_stmt_limit, n->span().start(),
comments, arena, nullptr)) {
stmt_pieces.push_back(
arena.MakeConcat(maybe_doc.value(), arena.hard_line()));
}
stmt_pieces.push_back(Fmt(*n, comments, arena));
stmt_pieces.push_back(arena.hard_line());
last_stmt_limit = n->span().limit();
},
[&](const TypeAlias* n) {
if (std::optional<DocRef> maybe_doc =
EmitCommentsBetween(last_stmt_limit, n->span().start(),
Expand Down Expand Up @@ -2081,33 +2119,6 @@ static DocRef Fmt(const StructDef& n, const Comments& comments,
return JoinWithAttr(attr, ConcatNGroup(arena, pieces), arena);
}

static DocRef Fmt(const ConstantDef& n, const Comments& comments,
DocArena& arena) {
std::vector<DocRef> leader_pieces;
if (n.is_public()) {
leader_pieces.push_back(arena.Make(Keyword::kPub));
leader_pieces.push_back(arena.break1());
}
leader_pieces.push_back(arena.Make(Keyword::kConst));
leader_pieces.push_back(arena.break1());
leader_pieces.push_back(arena.MakeText(n.identifier()));
if (n.type_annotation() != nullptr) {
leader_pieces.push_back(arena.colon());
leader_pieces.push_back(arena.space());
leader_pieces.push_back(Fmt(*n.type_annotation(), comments, arena));
}
leader_pieces.push_back(arena.break1());
leader_pieces.push_back(arena.equals());
leader_pieces.push_back(arena.space());

DocRef lhs = ConcatNGroup(arena, leader_pieces);
DocRef rhs = ConcatNGroup(arena, {
Fmt(*n.value(), comments, arena),
arena.semi(),
});
return arena.MakeConcat(lhs, rhs);
}

static DocRef FmtEnumMember(const EnumMember& n, const Comments& comments,
DocArena& arena) {
return ConcatNGroup(
Expand Down
17 changes: 17 additions & 0 deletions xls/dslx/fmt/ast_fmt_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,23 @@ TEST_F(ModuleFmtTest, SimpleProc) {
)");
}

TEST_F(ModuleFmtTest, SimpleProcWithConstant) {
Run(
R"(pub proc p {
// My constant.
const MY_CONST = u32:8;
// My second constant.
const ANOTHER_CONST = "second";
config() { () }
init { () }
next(state: ()) { () }
}
)");
}

TEST_F(ModuleFmtTest, SimpleProcWithComments) {
Run(
R"(// Proc-level comment.
Expand Down
5 changes: 5 additions & 0 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2485,6 +2485,11 @@ absl::StatusOr<T*> Parser::ParseProcLike(bool is_public,
XLS_ASSIGN_OR_RETURN(TypeAlias * type_alias,
ParseTypeAlias(/*is_public=*/false, proc_bindings));
proc_like_body.stmts.push_back(type_alias);
} else if (peek->IsKeyword(Keyword::kConst)) {
XLS_ASSIGN_OR_RETURN(
ConstantDef * constant,
ParseConstantDef(/*is_public=*/false, proc_bindings));
proc_like_body.stmts.push_back(constant);
} else if (peek->IsIdentifier("config")) {
XLS_RETURN_IF_ERROR(check_not_yet_specified(proc_like_body.config, peek));

Expand Down
118 changes: 118 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,124 @@ TEST_F(ParserTest, ParseSimpleProc) {
EXPECT_EQ(p->ToString(), text);
}

TEST_F(ParserTest, ParseProcWithConst) {
RoundTrip(R"(proc simple {
const MAX_X = u32:10;
x: u32;
config() {
()
}
init {
u32:0
}
next(addend: u32) {
if x > MAX_X { x } else { x + addend }
}
})");
}

TEST_F(ParserTest, ParseParametricProcWithConstant) {
RoundTrip(R"(proc MyProc<X: u32> {
const INC = u32:20;
c: chan<u32> in;
config(c: chan<u32> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + INC
}
})");
}

TEST_F(ParserTest, ParseParametricProcWithConstantRefParam) {
RoundTrip(R"(proc MyProc<X: u32> {
const DOUBLE_X = u32:2 * X;
c: chan<u32> in;
config(c: chan<u32> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + DOUBLE_X
}
})");
}

TEST_F(ParserTest, ParseProcWithConstantRefFunction) {
RoundTrip(R"(fn double(x: u32) -> u32 {
x * u32:2
}
proc MyProc<X: u32> {
const DOUBLE_X = double(X);
c: chan<u32> in;
config(c: chan<u32> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + DOUBLE_X
}
})");
}

TEST_F(ParserTest, ParseProcWithConstantRefGlobalConstant) {
RoundTrip(R"(const MY_VAL = u32:15;
proc MyProc {
const DOUBLE_VAL = u32:2 * MY_VAL;
c: chan<u32> in;
config(c: chan<u32> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + DOUBLE_VAL
}
})");
}

TEST_F(ParserTest, ParseProcWithConstantInType) {
RoundTrip(R"(const MY_VAL = u32:15;
proc MyProc {
const DOUBLE_VAL = u32:2 * MY_VAL;
c: chan<uN[DOUBLE_VAL]> in;
config(c: chan<uN[DOUBLE_VAL]> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + DOUBLE_VAL
}
})");
}

TEST_F(ParserTest, ParseProcWithConstantRefConstant) {
RoundTrip(R"(proc MyProc {
const MY_VAL = u32:15;
const DOUBLE_VAL = u32:2 * MY_VAL;
c: chan<u32> in;
config(c: chan<u32> in) {
(c,)
}
init {
u32:0
}
next(state: ()) {
state + DOUBLE_VAL
}
})");
}

TEST_F(ParserTest, ParseNextTooManyArgs) {
const char* text = R"(proc confused {
config() { () }
Expand Down
12 changes: 7 additions & 5 deletions xls/dslx/frontend/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class ProcMember : public AstNode {
};

// TODO(leary): 2024-02-09 Extend this to allow for constant definitions, etc.
using ProcStmt = std::variant<Function*, ProcMember*, TypeAlias*, ConstAssert*>;
using ProcStmt = std::variant<Function*, ProcMember*, TypeAlias*, ConstAssert*,
ConstantDef*>;

absl::StatusOr<ProcStmt> ToProcStmt(AstNode* n);

Expand Down Expand Up @@ -138,11 +139,12 @@ class ProcLike : public AstNode {
return result;
}

std::vector<const ConstAssert*> GetConstAssertStmts() const {
std::vector<const ConstAssert*> result;
template <typename T>
std::vector<const T*> GetStmtsOfType() const {
std::vector<const T*> result;
for (const ProcStmt& stmt : stmts()) {
if (std::holds_alternative<ConstAssert*>(stmt)) {
result.push_back(std::get<ConstAssert*>(stmt));
if (std::holds_alternative<T*>(stmt)) {
result.push_back(std::get<T*>(stmt));
}
}
return result;
Expand Down
12 changes: 12 additions & 0 deletions xls/dslx/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,18 @@ dslx_lang_test(
test_autofmt = False,
)

dslx_lang_test(
name = "parametric_proc_with_const",
# No meaningful entry function to convert.
convert_to_ir = False,
)

dslx_lang_test(
name = "proc_with_const",
# No normal 'function' entry point, it is a test proc.
convert_to_ir = False,
)

dslx_lang_test(
name = "proc_two_level",
# No normal 'function' entry point, it is a test proc.
Expand Down
57 changes: 57 additions & 0 deletions xls/dslx/tests/parametric_proc_with_const.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2020 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

proc parametric<N: u32, M: u32> {
const DOUBLE_M = u32:37;
c: chan<uN[DOUBLE_M]> in;
s: chan<uN[M]> out;

config(c: chan<uN[M]> in, s: chan<uN[M]> out) { (c, s) }

init { () }

next(state: ()) {
let (tok, input) = recv(join(), c);
let output = ((input as uN[DOUBLE_M]) * uN[DOUBLE_M]:2) as uN[M];
let tok = send(tok, s, output);
}
}

#[test_proc]
proc test_proc {
terminator: chan<bool> out;
output_c: chan<u37> in;
input_p: chan<u37> out;

config(terminator: chan<bool> out) {
let (input_p, input_c) = chan<u37>("input");
let (output_p, output_c) = chan<u37>("output");
spawn parametric<u32:32, u32:37>(input_c, output_p);
(terminator, output_c, input_p)
}

init { () }

next(state: ()) {
let tok = send(join(), input_p, u37:1);
let (tok, result) = recv(tok, output_c);
assert_eq(result, u37:2);

let tok = send(tok, input_p, u37:8);
let (tok, result) = recv(tok, output_c);
assert_eq(result, u37:16);

let tok = send(tok, terminator, true);
}
}
Loading

0 comments on commit 4bb43f7

Please sign in to comment.