checker, cgen: fix match with complex sumtype exprs (#10215)
parent
2abbbcc02d
commit
7c0cd2f41d
|
@ -5358,16 +5358,23 @@ pub fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type {
|
||||||
}
|
}
|
||||||
expr_type := c.expr(stmt.expr)
|
expr_type := c.expr(stmt.expr)
|
||||||
if ret_type == ast.void_type {
|
if ret_type == ast.void_type {
|
||||||
|
if node.is_expr
|
||||||
|
&& c.table.get_type_symbol(node.expected_type).kind == .sum_type {
|
||||||
|
ret_type = node.expected_type
|
||||||
|
} else {
|
||||||
ret_type = expr_type
|
ret_type = expr_type
|
||||||
stmt.typ = ret_type
|
}
|
||||||
|
stmt.typ = expr_type
|
||||||
} else if node.is_expr && ret_type != expr_type {
|
} else if node.is_expr && ret_type != expr_type {
|
||||||
if !c.check_types(ret_type, expr_type) {
|
if !c.check_types(ret_type, expr_type) {
|
||||||
ret_sym := c.table.get_type_symbol(ret_type)
|
ret_sym := c.table.get_type_symbol(ret_type)
|
||||||
|
if !(node.is_expr && ret_sym.kind == .sum_type) {
|
||||||
c.error('return type mismatch, it should be `$ret_sym.name`',
|
c.error('return type mismatch, it should be `$ret_sym.name`',
|
||||||
stmt.expr.position())
|
stmt.expr.position())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
// TODO: ask alex about this
|
// TODO: ask alex about this
|
||||||
// typ := c.expr(stmt.expr)
|
// typ := c.expr(stmt.expr)
|
||||||
|
@ -5995,6 +6002,10 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if node.is_expr
|
||||||
|
&& c.table.get_type_symbol(former_expected_type).kind == .sum_type {
|
||||||
|
continue
|
||||||
|
}
|
||||||
c.error('mismatched types `${c.table.type_to_str(node.typ)}` and `${c.table.type_to_str(last_expr.typ)}`',
|
c.error('mismatched types `${c.table.type_to_str(node.typ)}` and `${c.table.type_to_str(last_expr.typ)}`',
|
||||||
node.pos)
|
node.pos)
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,6 +173,7 @@ mut:
|
||||||
as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast)
|
as_cast_type_names map[string]string // table for type name lookup in runtime (for __as_cast)
|
||||||
obf_table map[string]string
|
obf_table map[string]string
|
||||||
// main_fn_decl_node ast.FnDecl
|
// main_fn_decl_node ast.FnDecl
|
||||||
|
expected_cast_type ast.Type // for match expr of sumtypes
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
|
pub fn gen(files []&ast.File, table &ast.Table, pref &pref.Preferences) string {
|
||||||
|
@ -1223,7 +1224,11 @@ fn (mut g Gen) stmt(node ast.Stmt) {
|
||||||
// }
|
// }
|
||||||
old_is_void_expr_stmt := g.is_void_expr_stmt
|
old_is_void_expr_stmt := g.is_void_expr_stmt
|
||||||
g.is_void_expr_stmt = !node.is_expr
|
g.is_void_expr_stmt = !node.is_expr
|
||||||
|
if node.typ != ast.void_type && g.expected_cast_type != 0 {
|
||||||
|
g.expr_with_cast(node.expr, node.typ, g.expected_cast_type)
|
||||||
|
} else {
|
||||||
g.expr(node.expr)
|
g.expr(node.expr)
|
||||||
|
}
|
||||||
g.is_void_expr_stmt = old_is_void_expr_stmt
|
g.is_void_expr_stmt = old_is_void_expr_stmt
|
||||||
// if af {
|
// if af {
|
||||||
// g.autofree_call_postgen()
|
// g.autofree_call_postgen()
|
||||||
|
@ -4131,7 +4136,12 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str
|
||||||
g.writeln(') {')
|
g.writeln(') {')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if is_expr && tmp_var.len > 0
|
||||||
|
&& g.table.get_type_symbol(node.return_type).kind == .sum_type {
|
||||||
|
g.expected_cast_type = node.return_type
|
||||||
|
}
|
||||||
g.stmts_with_tmp_var(branch.stmts, tmp_var)
|
g.stmts_with_tmp_var(branch.stmts, tmp_var)
|
||||||
|
g.expected_cast_type = 0
|
||||||
if g.inside_ternary == 0 {
|
if g.inside_ternary == 0 {
|
||||||
g.write('}')
|
g.write('}')
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,150 @@
|
||||||
|
struct Empty {}
|
||||||
|
|
||||||
|
struct Node {
|
||||||
|
value f64
|
||||||
|
left Tree
|
||||||
|
right Tree
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tree = Empty | Node
|
||||||
|
|
||||||
|
// return size(number of nodes) of BST
|
||||||
|
fn size(tree Tree) int {
|
||||||
|
return match tree {
|
||||||
|
// TODO: remove int() once match gets smarter
|
||||||
|
Empty { int(0) }
|
||||||
|
Node { 1 + size(tree.left) + size(tree.right) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert a value to BST
|
||||||
|
fn insert(tree Tree, x f64) Tree {
|
||||||
|
return match tree {
|
||||||
|
Empty {
|
||||||
|
Node{x, tree, tree}
|
||||||
|
}
|
||||||
|
Node {
|
||||||
|
if x == tree.value {
|
||||||
|
tree
|
||||||
|
} else if x < tree.value {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
left: insert(tree.left, x)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
right: insert(tree.right, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// whether able to find a value in BST
|
||||||
|
fn search(tree Tree, x f64) bool {
|
||||||
|
return match tree {
|
||||||
|
Empty {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Node {
|
||||||
|
if x == tree.value {
|
||||||
|
true
|
||||||
|
} else if x < tree.value {
|
||||||
|
search(tree.left, x)
|
||||||
|
} else {
|
||||||
|
search(tree.right, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the minimal value of a BST
|
||||||
|
fn min(tree Tree) f64 {
|
||||||
|
return match tree {
|
||||||
|
Empty {
|
||||||
|
1e100
|
||||||
|
}
|
||||||
|
Node {
|
||||||
|
if tree.value < min(tree.left) {
|
||||||
|
tree.value
|
||||||
|
} else {
|
||||||
|
min(tree.left)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete a value in BST (if nonexistant do nothing)
|
||||||
|
fn delete(tree Tree, x f64) Tree {
|
||||||
|
return match tree {
|
||||||
|
Empty {
|
||||||
|
tree
|
||||||
|
}
|
||||||
|
Node {
|
||||||
|
if tree.left is Node && tree.right is Node {
|
||||||
|
if x < tree.value {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
left: delete(tree.left, x)
|
||||||
|
}
|
||||||
|
} else if x > tree.value {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
right: delete(tree.right, x)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
value: min(tree.right)
|
||||||
|
right: delete(tree.right, min(tree.right))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if tree.left is Node {
|
||||||
|
if x == tree.value {
|
||||||
|
tree.left
|
||||||
|
} else {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
left: delete(tree.left, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if x == tree.value {
|
||||||
|
tree.right
|
||||||
|
} else {
|
||||||
|
Node{
|
||||||
|
...tree
|
||||||
|
right: delete(tree.right, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_match_with_complex_sumtype_exprs() {
|
||||||
|
mut tree := Tree(Empty{})
|
||||||
|
input := [0.3, 0.2, 0.5, 0.0, 0.6, 0.8, 0.9, 1.0, 0.1, 0.4, 0.7]
|
||||||
|
for i in input {
|
||||||
|
tree = insert(tree, i)
|
||||||
|
}
|
||||||
|
print('[1] after insertion tree size is ') // 11
|
||||||
|
println(size(tree))
|
||||||
|
del := [-0.3, 0.0, 0.3, 0.6, 1.0, 1.5]
|
||||||
|
for i in del {
|
||||||
|
tree = delete(tree, i)
|
||||||
|
}
|
||||||
|
print('[2] after deletion tree size is ') // 7
|
||||||
|
print(size(tree))
|
||||||
|
print(', and these elements were deleted: ') // 0.0 0.3 0.6 1.0
|
||||||
|
assert size(tree) == 7
|
||||||
|
for i in input {
|
||||||
|
if !search(tree, i) {
|
||||||
|
print(i)
|
||||||
|
print(' ')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println('')
|
||||||
|
assert true
|
||||||
|
}
|
Loading…
Reference in New Issue