From 96539e43b599c70e79d1c05828c8748c178093dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20D=C3=A4schle?= Date: Mon, 23 Nov 2020 16:16:13 +0100 Subject: [PATCH] all: nested sum types (#6913) --- vlib/v/ast/ast.v | 12 ++-- vlib/v/checker/checker.v | 46 ++++++++---- .../tests/sum_type_assign_non_variant_err.out | 6 ++ .../tests/sum_type_assign_non_variant_err.vv | 12 ++++ vlib/v/gen/cgen.v | 42 +++++++---- vlib/v/table/atypes.v | 12 ---- vlib/v/table/table.v | 3 - vlib/v/tests/union_sum_type_test.v | 71 ++++++++++++++++--- 8 files changed, 146 insertions(+), 58 deletions(-) create mode 100644 vlib/v/checker/tests/sum_type_assign_non_variant_err.out create mode 100644 vlib/v/checker/tests/sum_type_assign_non_variant_err.vv diff --git a/vlib/v/ast/ast.v b/vlib/v/ast/ast.v index 8c7187c03f..295b985244 100644 --- a/vlib/v/ast/ast.v +++ b/vlib/v/ast/ast.v @@ -365,7 +365,7 @@ pub: is_arg bool // fn args should not be autofreed pub mut: typ table.Type - sum_type_cast table.Type + sum_type_casts []table.Type // nested sum types require nested smart casting, for that a list of types is needed pos token.Position is_used bool is_changed bool // to detect mutable vars that are never changed @@ -375,11 +375,11 @@ pub mut: // struct fields change type in scopes pub struct ScopeStructField { pub: - struct_type table.Type // type of struct - name string - pos token.Position - typ table.Type - sum_type_cast table.Type + struct_type table.Type // type of struct + name string + pos token.Position + typ table.Type + sum_type_casts []table.Type // nested sum types require nested smart casting, for that a list of types is needed } pub struct GlobalField { diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index 55b7207c1a..08f2075ba7 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -1801,7 +1801,7 @@ pub fn (mut c Checker) selector_expr(mut selector_expr ast.SelectorExpr) table.T if !prevent_sum_type_unwrapping_once { scope := c.file.scope.innermost(selector_expr.pos.pos) if scope_field := scope.find_struct_field(utyp, field_name) { - return scope_field.sum_type_cast + return scope_field.sum_type_casts.last() } } } @@ -3234,9 +3234,9 @@ pub fn (mut c Checker) ident(mut ident ast.Ident) table.Type { c.error('undefined variable `$ident.name` (used before declaration)', ident.pos) } - is_sum_type_cast := obj.sum_type_cast != 0 && !c.prevent_sum_type_unwrapping_once + is_sum_type_cast := obj.sum_type_casts.len != 0 && !c.prevent_sum_type_unwrapping_once c.prevent_sum_type_unwrapping_once = false - mut typ := if is_sum_type_cast { obj.sum_type_cast } else { obj.typ } + mut typ := if is_sum_type_cast { obj.sum_type_casts.last() } else { obj.typ } if typ == 0 { if mut obj.expr is ast.Ident { if obj.expr.kind == .unresolved { @@ -3561,37 +3561,47 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, type_sym table.TypeSymbol mut scope := c.file.scope.innermost(branch.pos.pos) match mut node.cond { ast.SelectorExpr { + mut is_mut := false + mut sum_type_casts := []table.Type{} expr_sym := c.table.get_type_symbol(node.cond.expr_type) - field := c.table.struct_find_field(expr_sym, node.cond.field_name) or { - table.Field{} + if field := c.table.struct_find_field(expr_sym, node.cond.field_name) { + is_mut = field.is_mut + } + if field := scope.find_struct_field(node.cond.expr_type, node.cond.field_name) { + sum_type_casts << field.sum_type_casts } - is_mut := field.is_mut is_root_mut := scope.is_selector_root_mutable(c.table, node.cond) // smartcast either if the value is immutable or if the mut argument is explicitly given if (!is_root_mut && !is_mut) || node.is_mut { + sum_type_casts << expr_type scope.register_struct_field(ast.ScopeStructField{ struct_type: node.cond.expr_type name: node.cond.field_name typ: node.cond_type - sum_type_cast: expr_type + sum_type_casts: sum_type_casts pos: node.cond.pos }) } } ast.Ident { mut is_mut := false + mut sum_type_casts := []table.Type{} + mut is_already_casted := false if v := scope.find_var(node.cond.name) { is_mut = v.is_mut + sum_type_casts << v.sum_type_casts + is_already_casted = v.pos.pos == node.cond.pos.pos } // smartcast either if the value is immutable or if the mut argument is explicitly given - if !is_mut || node.is_mut { + if (!is_mut || node.is_mut) && !is_already_casted { + sum_type_casts << expr_type scope.register(node.var_name, ast.Var{ name: node.var_name typ: node.cond_type pos: node.cond.pos is_used: true is_mut: node.is_mut - sum_type_cast: expr_type + sum_type_casts: sum_type_casts }) } } @@ -3823,38 +3833,46 @@ pub fn (mut c Checker) if_expr(mut node ast.IfExpr) table.Type { mut is_mut := false mut scope := c.file.scope.innermost(branch.body_pos.pos) if mut infix.left is ast.Ident { + mut sum_type_casts := []table.Type{} if v := scope.find_var(infix.left.name) { is_mut = v.is_mut + sum_type_casts << v.sum_type_casts } // smartcast either if the value is immutable or if the mut argument is explicitly given if (!is_mut || branch.is_mut_name) && left_sym.kind == .union_sum_type { + sum_type_casts << right_expr.typ scope.register(branch.left_as_name, ast.Var{ name: branch.left_as_name typ: infix.left_type - sum_type_cast: right_expr.typ + sum_type_casts: sum_type_casts pos: infix.left.pos is_used: true is_mut: is_mut }) } } else if mut infix.left is ast.SelectorExpr { + mut sum_type_casts := []table.Type{} expr_sym := c.table.get_type_symbol(infix.left.expr_type) - field := c.table.struct_find_field(expr_sym, infix.left.field_name) or { - table.Field{} + if field := c.table.struct_find_field(expr_sym, infix.left.field_name) { + is_mut = field.is_mut + } + if field := scope.find_struct_field(infix.left.expr_type, + infix.left.field_name) { + sum_type_casts << field.sum_type_casts } - is_mut = field.is_mut is_root_mut := scope.is_selector_root_mutable(c.table, infix.left) // smartcast either if the value is immutable or if the mut argument is explicitly given if ((!is_root_mut && !is_mut) || branch.is_mut_name) && left_sym.kind == .union_sum_type { + sum_type_casts << right_expr.typ scope.register_struct_field(ast.ScopeStructField{ struct_type: infix.left.expr_type name: infix.left.field_name typ: infix.left_type - sum_type_cast: right_expr.typ + sum_type_casts: sum_type_casts pos: infix.left.pos }) } diff --git a/vlib/v/checker/tests/sum_type_assign_non_variant_err.out b/vlib/v/checker/tests/sum_type_assign_non_variant_err.out new file mode 100644 index 0000000000..b9a8079f72 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_assign_non_variant_err.out @@ -0,0 +1,6 @@ +vlib/v/checker/tests/sum_type_assign_non_variant_err.vv:11:6: error: cannot assign to `w`: expected `Stmt`, not `IfExpr` + 9 | fn main() { + 10 | mut w := Stmt{} + 11 | w = IfExpr{} + | ~~~~~~~~ + 12 | } \ No newline at end of file diff --git a/vlib/v/checker/tests/sum_type_assign_non_variant_err.vv b/vlib/v/checker/tests/sum_type_assign_non_variant_err.vv new file mode 100644 index 0000000000..cabf2fc306 --- /dev/null +++ b/vlib/v/checker/tests/sum_type_assign_non_variant_err.vv @@ -0,0 +1,12 @@ +__type Expr = IfExpr | CallExpr | MatchExpr +struct MatchExpr {} +struct IfExpr {} +struct CallExpr {} + +__type Stmt = Expr | AnotherThing +struct AnotherThing {} + +fn main() { + mut w := Stmt{} + w = IfExpr{} +} \ No newline at end of file diff --git a/vlib/v/gen/cgen.v b/vlib/v/gen/cgen.v index 0c26b84885..f3d7021556 100644 --- a/vlib/v/gen/cgen.v +++ b/vlib/v/gen/cgen.v @@ -1301,7 +1301,7 @@ fn (mut g Gen) union_expr_with_cast(expr ast.Expr, got_type table.Type, expected scope := g.file.scope.innermost(expr.position().pos) if expr is ast.Ident { if v := scope.find_var(expr.name) { - if v.sum_type_cast != 0 { + if v.sum_type_casts.len > 0 { is_already_sum_type = true } } @@ -2567,12 +2567,17 @@ fn (mut g Gen) expr(node ast.Expr) { scope := g.file.scope.innermost(node.pos.pos) if field := scope.find_struct_field(node.expr_type, node.field_name) { // union sum type deref - g.write('(*') - cast_sym := g.table.get_type_symbol(field.sum_type_cast) - if cast_sym.info is table.Aggregate as sym_info { - sum_type_deref_field = '_${sym_info.types[g.aggregate_type_idx]}' - } else { - sum_type_deref_field = '_$field.sum_type_cast' + for i, typ in field.sum_type_casts { + g.write('(*') + cast_sym := g.table.get_type_symbol(typ) + if i != 0 { + sum_type_deref_field += ').' + } + if cast_sym.info is table.Aggregate as sym_info { + sum_type_deref_field += '_${sym_info.types[g.aggregate_type_idx]}' + } else { + sum_type_deref_field += '_$typ' + } } } } @@ -3403,13 +3408,22 @@ fn (mut g Gen) ident(node ast.Ident) { } scope := g.file.scope.innermost(node.pos.pos) if v := scope.find_var(node.name) { - if v.sum_type_cast != 0 { + if v.sum_type_casts.len > 0 { if !prevent_sum_type_unwrapping_once { - sym := g.table.get_type_symbol(v.sum_type_cast) - if sym.info is table.Aggregate as sym_info { - g.write('(*${name}._${sym_info.types[g.aggregate_type_idx]})') - } else { - g.write('(*${name}._$v.sum_type_cast)') + for _ in v.sum_type_casts { + g.write('(*') + } + for i, typ in v.sum_type_casts { + cast_sym := g.table.get_type_symbol(typ) + if i == 0 { + g.write(name) + } + if cast_sym.info is table.Aggregate as sym_info { + g.write('._${sym_info.types[g.aggregate_type_idx]}') + } else { + g.write('._$typ') + } + g.write(')') } return } @@ -4645,7 +4659,7 @@ fn (mut g Gen) write_types(types []table.TypeSymbol) { } g.type_definitions.writeln('typedef struct {') g.type_definitions.writeln(' union {') - for variant in g.table.get_union_sum_type_variants(it) { + for variant in it.variants { g.type_definitions.writeln(' ${g.typ(variant.to_ptr())} _$variant.idx();') } g.type_definitions.writeln(' };') diff --git a/vlib/v/table/atypes.v b/vlib/v/table/atypes.v index 4b56529abc..8600c2be8e 100644 --- a/vlib/v/table/atypes.v +++ b/vlib/v/table/atypes.v @@ -853,18 +853,6 @@ pub: variants []Type } -pub fn (table &Table) get_union_sum_type_variants(sum_type UnionSumType) []Type { - mut variants := []Type{} - for variant in sum_type.variants { - sym := table.get_type_symbol(variant) - if sym.info is UnionSumType as sym_info { - variants << table.get_union_sum_type_variants(sym_info) - } - variants << variant - } - return variants -} - pub fn (table &Table) type_to_str(t Type) string { sym := table.get_type_symbol(t) mut res := sym.source_name diff --git a/vlib/v/table/table.v b/vlib/v/table/table.v index 5461e6a233..6ee9434716 100644 --- a/vlib/v/table/table.v +++ b/vlib/v/table/table.v @@ -738,9 +738,6 @@ pub fn (table &Table) sumtype_has_variant(parent Type, variant Type) bool { if v.idx() == variant.idx() { return true } - if table.sumtype_has_variant(v, variant) { - return true - } } } return false diff --git a/vlib/v/tests/union_sum_type_test.v b/vlib/v/tests/union_sum_type_test.v index fb7fd5d1c5..9b6cef92d9 100644 --- a/vlib/v/tests/union_sum_type_test.v +++ b/vlib/v/tests/union_sum_type_test.v @@ -106,15 +106,15 @@ fn test_converting_down() { assert res[1].name == 'three' } -fn test_nested_sumtype() { - mut a := Node{} - mut b := Node{} - a = StructDecl{pos: 1} - b = IfExpr{pos: 1} - c := Node(Expr(IfExpr{pos:1})) - if c is Expr { - if c is IfExpr { - assert true +struct NodeWrapper { + node Node +} + +fn test_nested_sumtype_selector() { + c := NodeWrapper{Node(Expr(IfExpr{pos: 1}))} + if c.node is Expr { + if c.node is IfExpr { + assert c.node.pos == 1 } else { assert false @@ -125,6 +125,59 @@ fn test_nested_sumtype() { } } +fn test_nested_sumtype_match_selector() { + c := NodeWrapper{Node(Expr(IfExpr{pos: 1}))} + match union c.node { + Expr { + match union c.node { + IfExpr { + assert c.node.pos == 1 + } + else { + assert false + } + } + } + else { + assert false + } + } +} + +fn test_nested_sumtype() { + c := Node(Expr(IfExpr{pos:1})) + if c is Expr { + if c is IfExpr { + assert c.pos == 1 + } + else { + assert false + } + } + else { + assert false + } +} + +fn test_nested_sumtype_match() { + c := Node(Expr(IfExpr{pos: 1})) + match union c { + Expr { + match union c { + IfExpr { + assert c.pos == 1 + } + else { + assert false + } + } + } + else { + assert false + } + } +} + __type Abc = int | string fn test_string_cast_to_sumtype() {